卧薪尝胆,厚积薄发。
Surprise me!
Date: Wed Jan 16 11:57:42 CST 2019 In Category: NoCategory

Description:

给定一棵 $n$ 个节点的树,每个点 $i$ 有一个权值 $a_i$ ,保证 $a_1,a_2,\dots,a_n$ 是 $1\sim n$ 的一个排列。 求: $$ \frac{1}{n(n-1)}\sum_{i=1}^n\sum_{j=1}^n\varphi(a_i\times a_j)\times dist(i,j) $$ 其中 $dist(i,j)$ 为树上两点 $i,j$ 之间的距离(边数),答案对 $10^9+7$ 取模。数据范围: $1\leqslant n\leqslant 2\times 10^5$ 。

Solution:

首先有结论: $$ \varphi(ab)=\frac{\varphi(a)\times\varphi(b)\gcd(a,b)}{\varphi(\gcd(a,b))} $$ (由 $\varphi(n)=n\prod_{p\mid n}(1-\frac1p)$ 可知 $\varphi(a)\varphi(b)$ 比 $\varphi(ab)$ 多乘了一次 $\prod_{p\mid\gcd(a,b)}(1-\frac1p)=\frac{\varphi(\gcd(a,b))}{\gcd(a,b)}$ 。)记 $p_x$ 为权值为 $x$ 的点(因为 $a$ 是排列,所以 $p$ 是良定义的,可以把对点的枚举换成对权值的枚举),于是: $$ \begin{align} ans&=\frac{1}{n(n-1)}\sum_{i=1}^n\sum_{j=1}^n\varphi(a_i\times a_j)\times dist(i,j)\\ &=\frac{1}{n(n-1)}\sum_{i=1}^n\sum_{j=1}^n\varphi(ij)\times dist(p_i,p_j)\\ &=\frac{1}{n(n-1)}\sum_{i=1}^n\sum_{j=1}^n\varphi(i)\varphi(j)\frac{\gcd(i,j)}{\varphi(\gcd(i,j))}dist(p_i,p_j)\\ &=\frac{1}{n(n-1)}\sum_{d=1}^n\frac d{\varphi(d)}\sum_{i=1}^n\sum_{j=1}^n\varphi(i)\varphi(j)dist(p_i,p_j)[\gcd(i,j)=d]\\ \end{align} $$ 然后我们只要求出: $$ \begin{align} f(d)&=\sum_{i=1}^n\sum_{j=1}^n\varphi(i)\varphi(j)dist(p_i,p_j)[\gcd(i,j)=d]\\ \end{align} $$ 就可以计算答案了。
考虑怎么计算 $f$ 。直接处理 $\gcd(i,j)=d$ 不方便,先计算只要求 $d\mid i,\ d\mid j$ 的 $g(d)$ :我们可以把所有 $a_i$ 是 $d$ 的倍数的点拿出来建虚树,然后在虚树上树形 $dp$ ,大概是(下面 $i,j$ 都只在 $1\sim n$ 中取值): $$ \begin{align} sum_d&=\sum_{d\mid i}\varphi(i)\\ g(d)&=\sum_{d\mid i}\sum_{d\mid j}\varphi(i)\varphi(j)dist(p_i,p_j)\\ &=\sum_{d\mid i}\sum_{d\mid j}\varphi(i)\varphi(j)(dep[p_i]+dep[p_j]-2\times dep[LCA(p_i,p_j)])\\ &=\sum_{d\mid i}\sum_{d\mid j}\left(\varphi(i)\varphi(j)dep[p_i]+\varphi(i)\varphi(j)dep[p_j]-2\times \varphi(i)\varphi(j)dep[LCA(p_i,p_j)]\right)\\ &=2\times sum_d\times\sum_{d\mid i}\varphi(i)dep[p_i]-2\times\sum_{d\mid i}\sum_{d\mid j}\varphi(i)\varphi(j)dep[LCA(p_i,p_j)]\\ \end{align} $$ 第一项可以直接算;第二项在虚树上算:在某个点合并它的所有子树,来自两个不同子树(或者该点本身与某个子树)的点对的 LCA 恰好是这个点,据此计算贡献就行了。
然后有 $f(d)=g(d)-\sum_{k\geqslant 2,\ kd\leqslant n}f(kd)$ ,从大到小倒序枚举 $d$ ,用 $g(d)$ 减掉它所有倍数的 $f$ 就可以计算出 $f$ 了。
复杂度 $O(n\log^2n)$ 。

Code:


#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cmath>
#include<vector>
#include<cctype>
#include<cstring>
using namespace std;
// Number of tree vertices (read in main).
int n;
// Capacity of every per-vertex / per-value array (n <= 2*10^5 plus slack).
#define MAXN 200010
// All arithmetic is done modulo this prime; inverses are taken with power(x, MOD - 2).
#define MOD 1000000007
// State filled by sieve(): isprime[x] for x >= 2, the primes in prime[1..pr],
// and Euler's totient phi[x] for 1 <= x <= n.
bool isprime[MAXN];
int prime[MAXN],pr = 0;
int phi[MAXN];
// Square-and-multiply modular exponentiation: returns a^b mod MOD.
// A non-positive exponent yields 1. Used with b = MOD - 2 to take
// modular inverses via Fermat's little theorem.
int power(int a,int b)
{
	long long base = a;
	long long result = 1;
	for(;b > 0;b >>= 1)
	{
		if(b & 1)result = result * base % MOD;
		base = base * base % MOD;
	}
	return static_cast<int>(result);
}
// Linear (Euler) sieve on [2, n]. It appends every prime to prime[1..pr],
// clears isprime[] for composites, and computes Euler's totient phi[1..n].
// Each composite m is produced exactly once, as m = x * q where q is its
// smallest prime factor (the inner loop stops once q divides x).
void sieve(int n)
{
	for(int x = 2;x <= n;++x)isprime[x] = true;
	phi[1] = 1;
	for(int x = 2;x <= n;++x)
	{
		if(isprime[x])
		{
			prime[++pr] = x;
			phi[x] = x - 1;
		}
		for(int j = 1;j <= pr;++j)
		{
			const int q = prime[j];
			if(x * q > n)break;
			const int m = x * q;
			isprime[m] = false;
			if(x % q == 0)
			{
				// q already divides x: phi(x * q) = phi(x) * q. Stop here so m is
				// generated only from its smallest prime factor.
				phi[m] = phi[x] * q;
				break;
			}
			// q is coprime to x: phi is multiplicative and phi(q) = q - 1.
			phi[m] = phi[x] * (q - 1);
		}
	}
}
// s[u] = value written on vertex u; p[x] = vertex whose value is x (inverse permutation, filled in main).
int s[MAXN],p[MAXN];
// Linked-list adjacency storage: e[id].to is the endpoint of arc id and
// e[id].nxt is the next arc leaving the same vertex (0 ends the list).
// Each undirected tree edge occupies two arcs, hence the doubled size.
struct edge
{
int to,nxt;
}e[MAXN << 1];
// Number of arcs stored so far; arc ids start at 1 so that 0 can act as "none".
int edgenum = 0;
// lin[u] = id of the first arc leaving u, or 0 if u has no arcs yet.
int lin[MAXN] = {0};
// Inserts the undirected edge {a, b}: one arc a -> b pushed onto a's list
// and one arc b -> a pushed onto b's list.
void add(int a,int b)
{
	e[++edgenum].to = b;
	e[edgenum].nxt = lin[a];
	lin[a] = edgenum;

	e[++edgenum].to = a;
	e[edgenum].nxt = lin[b];
	lin[b] = edgenum;
}
// Heavy-light decomposition state:
//   fa[u]  - parent of u (0 for the root),   dep[u] - depth (main starts the root at 1),
//   son[u] - child with the largest subtree (0 for a leaf),   siz[u] - subtree size,
//   top[u] - first vertex of u's heavy chain,   rnk[u] - DFS preorder index from dfs2,
//   tot    - running preorder counter used by dfs2.
int fa[MAXN],dep[MAXN],son[MAXN],siz[MAXN],top[MAXN],rnk[MAXN],tot = 0;
// First HLD pass over the subtree of k (k's parent must already be in fa[k]).
// Sets dep[] and siz[], records each child's parent, and picks son[k] as the
// first child with strictly maximal subtree size. Recursive: depth of the
// recursion equals the tree height.
void dfs1(int k,int depth)
{
	dep[k] = depth;
	siz[k] = 1;
	for(int id = lin[k];id != 0;id = e[id].nxt)
	{
		const int child = e[id].to;
		if(child == fa[k])continue;
		fa[child] = k;
		dfs1(child,depth + 1);
		siz[k] += siz[child];
		const bool heavier = (son[k] == 0) || (siz[child] > siz[son[k]]);
		if(heavier)son[k] = child;
	}
}
// Second HLD pass: assigns preorder indices rnk[] and chain heads top[].
// The heavy child is visited first so each heavy chain gets consecutive
// indices; every light child starts a new chain headed by itself.
void dfs2(int k,int tp)
{
	top[k] = tp;
	rnk[k] = ++tot;
	if(!son[k])return;
	dfs2(son[k],tp);
	for(int id = lin[k];id != 0;id = e[id].nxt)
	{
		const int child = e[id].to;
		if(child == fa[k] || child == son[k])continue;
		dfs2(child,child);
	}
}
// Lowest common ancestor via heavy-light decomposition, O(log n) chain jumps.
int LCA(int a,int b)
{
	// Jump off whichever chain has the deeper head until both lie on one chain.
	for(;top[a] != top[b];a = fa[top[a]])
	{
		if(dep[top[a]] < dep[top[b]])swap(a,b);
	}
	// On a common chain the shallower vertex is the answer.
	return dep[b] > dep[a] ? a : b;
}
// F[d]: first calc(d) (all value pairs that are both multiples of d); main then
// subtracts F of every larger multiple, leaving the sum restricted to gcd exactly d.
int F[MAXN];
// Scratch list for calc(): the selected vertices live in v[1..v[0]], v[0] is their count.
int v[MAXN];
// Sort predicate: orders vertices by DFS preorder index (required by the virtual-tree build).
bool cmp_dfn(int a,int b)
{
	return rnk[b] > rnk[a];
}
// st[1..tp]: stack used by calc() while building the virtual tree (current root-to-vertex chain).
int st[MAXN],tp = 0;
// g[u]: children of u in the current virtual tree; dp() clears them after use.
vector<int> g[MAXN];
// f[u]: total phi weight of the tagged vertices in u's virtual subtree, written by dp().
int f[MAXN];
// tag[u]: u is one of the vertices selected by calc() (as opposed to an LCA-only node).
bool tag[MAXN];
// Tree DP on the virtual tree rooted at k. The weight of a vertex u is
// w(u) = phi[s[u]] if tag[u], else 0.
// Returns, mod MOD, the sum over ordered pairs (x, y) of tagged vertices in k's
// virtual subtree (x == y included) of w(x) * w(y) * dep[LCA(x, y)].
// Side effects: sets f[k] to the total weight of k's subtree and clears g[k].
// Recursion depth equals the height of the virtual tree.
int dp(int k)
{
	const int self = tag[k] ? phi[s[k]] : 0;
	int below = 0;   // contribution of pairs whose LCA lies strictly below k
	int acc = self;  // weight of k plus the children merged so far
	int cross = 0;   // unordered pairs with LCA exactly k, from different parts
	for(size_t idx = 0;idx < g[k].size();++idx)
	{
		const int c = g[k][idx];
		below = (below + dp(c)) % MOD;  // must precede the reads of f[c]
		cross = (cross + 1ll * acc * f[c] % MOD) % MOD;
		acc = (acc + f[c]) % MOD;
	}
	g[k].clear();
	f[k] = acc;
	// Each unordered pair counts twice as ordered pairs; add the diagonal pair (k, k).
	long long pairs = 2ll * cross % MOD;
	if(tag[k])pairs = (pairs + 1ll * self * self) % MOD;
	return (below + pairs * dep[k] % MOD) % MOD;
}
// Returns, mod MOD, the sum over ordered pairs (x, y) of values that are
// multiples of k of phi(x) * phi(y) * dist(p[x], p[y]).
// Uses dist(u, w) = dep[u] + dep[w] - 2 * dep[LCA(u, w)]: the depth terms are
// summed directly, the LCA term comes from dp() over a virtual tree built on
// the selected vertices. Leaves tag[] and g[] cleared for the next call.
int calc(int k)
{
// Collect the vertices whose value is a multiple of k into v[1..v[0]].
v[0] = 0;
for(int i = k;i <= n;i += k)v[++v[0]] = p[i];
// sum = total phi weight of the selected vertices.
int sum = 0;
for(int i = 1;i <= v[0];++i)sum = (sum + phi[s[v[i]]]) % MOD;
// ans = 2 * sum * (sum of phi * dep): the dep[u] + dep[w] part over all ordered pairs.
int ans = 0;
for(int i = 1;i <= v[0];++i)ans = (ans + 1ll * phi[s[v[i]]] * dep[v[i]] % MOD * sum % MOD) % MOD;
ans = ans * 2 % MOD;
// Virtual-tree construction: visit vertices in DFS order, keeping the
// current root-to-vertex chain on st[1..tp].
sort(v + 1,v + 1 + v[0],cmp_dfn);
st[tp = 1] = v[1];
// Mark the selected vertices; LCA-only nodes stay untagged and carry weight 0 in dp().
for(int i = 1;i <= v[0];++i)tag[v[i]] = true;
for(int i = 2;i <= v[0];++i)
{
int lca = LCA(v[i],st[tp]);
// v[i] is inside the subtree of the stack top: simply extend the chain.
if(lca == st[tp])
{
st[++tp] = v[i];
continue;
}
// Pop chain nodes below lca, linking each popped node to its virtual parent.
while(tp >= 2)
{
if(rnk[lca] < rnk[st[tp - 1]])
{
// lca is above st[tp - 1], so st[tp - 1] is the virtual parent of st[tp].
g[st[tp - 1]].push_back(st[tp]);
--tp;
}
else
{
// lca sits between st[tp - 1] (inclusive) and st[tp]: it becomes st[tp]'s parent.
if(lca != st[tp])
{
g[lca].push_back(st[tp]);
}
--tp;
break;
}
}
// Only the bottom of the stack is left and lca lies above it: lca replaces it as the chain's base.
if(tp == 1 && rnk[lca] < rnk[st[tp]])
{
g[lca].push_back(st[tp]);
st[tp] = lca;
}
// Put lca on the chain (if not already there), then v[i] below it.
if(lca != st[tp])st[++tp] = lca;
st[++tp] = v[i];
}
// Link the remaining chain; afterwards st[1] is the virtual-tree root.
while(tp >= 2)
{
g[st[tp - 1]].push_back(st[tp]);
--tp;
}
// Subtract 2 * sum over ordered pairs of phi * phi * dep[LCA] (dp also clears g[]).
ans = (ans - 2 * dp(st[tp]) % MOD + MOD) % MOD;
// Reset the marks so the next call starts clean.
for(int i = 1;i <= v[0];++i)tag[v[i]] = false;
return ans;
}
// Reads the tree and the value permutation, computes
//   sum over d of d / phi(d) * (sum over pairs with gcd exactly d),
// divides by n * (n - 1) with modular inverses, and prints the result mod MOD.
int main()
{
	scanf("%d",&n);
	sieve(n);
	// s[u] is the value on vertex u; p[x] is the vertex carrying value x.
	for(int u = 1;u <= n;++u)scanf("%d",&s[u]);
	for(int u = 1;u <= n;++u)p[s[u]] = u;
	for(int edgesLeft = n - 1;edgesLeft > 0;--edgesLeft)
	{
		int u,w;
		scanf("%d%d",&u,&w);
		add(u,w);
	}
	dfs1(1,1);
	dfs2(1,1);
	for(int d = 1;d <= n;++d)F[d] = calc(d);
	// Turn "both multiples of d" sums into "gcd exactly d" sums, largest d first.
	for(int d = n;d >= 1;--d)
	{
		for(int m = 2 * d;m <= n;m += d)
		{
			F[d] = (F[d] - F[m] + MOD) % MOD;
		}
	}
	int total = 0;
	for(int d = 1;d <= n;++d)
	{
		const int weight = 1ll * d * power(phi[d],MOD - 2) % MOD;
		total = (total + 1ll * weight * F[d] % MOD) % MOD;
	}
	const int denom = 1ll * power(n,MOD - 2) * power(n - 1,MOD - 2) % MOD;
	total = 1ll * total * denom % MOD;
	cout << total << endl;
	return 0;
}
Copyright © 2020 wjh15101051
ღゝ◡╹)ノ♡