卧薪尝胆,厚积薄发。
WC2019 数树
Date: Fri Mar 08 00:27:52 CST 2019 In Category: NoCategory

Description:

对于两棵树,要求在两棵树中都连边的点颜色必须相同,分别对于给出两棵树/一棵树/零棵树的情况计算所有可能的树的 $k$ 种颜色的染色方案总个数。
$1\leqslant n\leqslant10^5$

Solution:

Task1:

两棵树都给出,直接把在两棵树中都连边的点连起来,答案是颜色的连通块个数次方。

Task2:

给出一棵树,先把题意转化为: $$ ans=\sum_{T_2}y^{n-|E_1\cap E_2|} $$ 那么我们实际上只要求每种树的方案的 $y^{-|E_1\cap E_2|}$ 即可。
然后有一个小技巧就是对于贡献是 $x^{|S|}$ 这类的,我们可以用二项式定理展开: $$ x^k=\sum_{i=0}^k\binom ki(x-1)^i $$ 组合意义就是枚举所有子集,大小为 $i$ 的子集有 $\binom ki$ 个,单个的贡献是 $(x-1)^i$ ,于是我们可以直接枚举所有子集,子集 $T$ 的贡献是 $(x-1)^{|T|}$ ,省去了子集反演,也就是容斥。
然后又根据 $prufer$ 序列我们知道 $m$ 个联通块的森林,每个联通块大小为 $a_i$ 的生成树的个数为: $$ n^{m-2}\prod_{i=1}^ma_i $$ 设 $c=y-1$ ,那么我们可以枚举所有边集 $E=|E_1\cap E_2|$ : $$ \begin{align} ans&=\sum_{E_2}y^{|E_1\cap E_2|}\\ &=\sum_{E_2}\sum_{E_2'\subseteq E_1\cap E_2}c^{|E_2'|}\\ &=\sum_{E_2'\subseteq E_1}c^{|E_2'|}\sum_{E_2'\subseteq E_2}1\\ &=\sum_{E_2'\subseteq E_1}c^{|E_2'|}n^{n-|E_2'|-2}\prod_{j=1}^{n-|E_2|}a_j\\ &=n^{n-2}\sum_{E_2'\subseteq E_1}\Bigl(\frac cn\Bigr)^{|E_2'|}\prod_{j=1}^{n-|E_2|}a_j\\ &=n^{n-2}\Bigl(\frac cn\Bigr)^n\sum_{E_2'\subseteq E_1}\Bigl(\frac nc\Bigr)^{n-|E_2'|}\prod_{j=1}^{n-|E_2|}a_j\\ &=n^{-2}c^n\sum_{E_2'\subseteq E_1}\prod_{j=1}^{n-|E_2|}\frac {na_j}c\\ \end{align} $$ 于是我们就可以 $DP$ 了,但是那个 $a_i$ 非常不好处理,不过我们还是能写一个 $O(n^2)$ 的 $DP$ 出来,设 $f[i][j]$ 表示以 $i$ 为根的子树 $i$ 所在联通块大小为 $j$ ,但是考虑 $\prod a_i$ 的组合意义,相当于从每个连通块中选一个点的方案数,因此可以设 $f[i][0/1]$ 表示以 $i$ 为根的子树根节点所在连通块选了还是没选,然后直接树形 $DP$ 即可。

Task3:

还是用上面提到的小技巧,对于每种边集计数。
$$ \begin{align} ans&=\sum_{E_1}\sum_{E_2}y^{n-|E_1\cap E_2|}\\ ans&=y^n\sum_{E_1}\sum_{E_2}y^{-|E_1\cap E_2|}=y^n\sum_{E_1}\sum_{E_2}x^{|E_1\cap E_2|}(x=y^{-1}) \end{align} $$
$$ \begin{align} &\sum_{E_1}\sum_{E_2}x^{|E_1\cap E_2|}=\sum_{E_1}\sum_{E_2}\sum_{E'\subseteq E_1\cap E_2}(x-1)^{|E'|} \end{align} $$
设 $c=x-1$ : $$ \begin{align} &\sum_{E_1}\sum_{E_2}\sum_{E'\subseteq E_1\cap E_2}c^{|E'|}\\ =&\sum_{E'}c^{|E'|}\sum_{E'\subseteq E_1}\sum_{E'\subseteq E_2}1\\ =&\sum_{E'}c^{|E'|}\Bigl(n^{n-|E'|-2}\prod_{j=1}^{n-|E'|}a_j\Bigr)^2\\ =&\sum_{k=1}^nc^{n-k}\sum_{a_1+a_2+\cdots+a_k=n}\frac{\binom{n}{a_1,a_2,\dots,a_k}}{k!}\prod_{i=1}^k\Big(a_i^{a_i-2}\Big)n^{2k-4}\Big(\prod_{j=1}^ka_j\Bigr)^2\\ =&\frac{n!}{n^4}\sum_{k=1}^n\frac{c^{n-k}n^{2k}}{k!}\sum_{a_1+a_2+\cdots+a_k=n}\prod_{i=1}^k\frac{a_i^{a_i}}{a_i!}\\ =&\frac{n!}{n^4}c^n\sum_{k=1}^n\frac1{k!}\sum_{a_1+a_2+\cdots+a_k=n}\prod_{i=1}^k\frac{n^2a_i^{a_i}}{ca_i!}\\ \end{align} $$ 设: $$ f_i=\frac{n^2i^i}{ci!},F(x)=\sum_{i\geqslant 0}f_ix^i $$
$$ ans=y^nc^n\frac{n!}{n^4}\sum_{k=1}^n\frac1{k!}[x^n]F(x)^k=c^ny^n\frac{n!}{n^4}[x^n]\sum_{k=1}^n\frac{F^k}{k!}=c^ny^n\frac{n!}{n^4}[x^n]e^F $$
于是多项式 $\exp$ 就行了。

Code:


#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cmath>
#include<map>
#include<cctype>
#include<cstring>
using namespace std;
int n,y,op;int x[100000],z[100000];
#define MAXN 400010
#define P 998244353
int power(int a,int b)
{
int res = 1;
for(;b > 0;b = b >> 1,a = 1ll * a * a % P)if(b & 1)res = 1ll * res * a % P;
return res;
}
int inver(int a){return power(a,P - 2);}
namespace SOLVE0
{
map<pair<int,int>,int> p;
void solve()
{
if(y == 1){cout << 1 << endl;return;}
int tot = n;
int a,b;
for(int i = 1;i < n;++i)
{
scanf("%d%d",&a,&b);
p[make_pair(min(a,b),max(a,b))] = 1;
}
for(int i = 1;i < n;++i)
{
scanf("%d%d",&a,&b);
if(p.find(make_pair(min(a,b),max(a,b))) != p.end())--tot;
}
cout << power(y,tot) << endl;
return;
}
}
namespace SOLVE1
{
struct edge
{
int to,nxt;
}e[MAXN << 1];
int edgenum = 0;
int lin[MAXN] = {0};
void add(int a,int b)
{
e[++edgenum] = (edge){b,lin[a]};lin[a] = edgenum;
e[++edgenum] = (edge){a,lin[b]};lin[b] = edgenum;
}
int f[MAXN][2];
int c;
void dp(int k,int fa)
{
f[k][0] = 1;f[k][1] = 1ll * n * inver(c) % P;
for(int i = lin[k];i != 0;i = e[i].nxt)
{
if(e[i].to == fa)continue;
dp(e[i].to,k);
int val0 = f[k][0],val1 = f[k][1];
f[k][0] = 1ll * val0 * (f[e[i].to][0] + f[e[i].to][1]) % P;
f[k][1] = (1ll * val1 * (f[e[i].to][0] + f[e[i].to][1]) % P + 1ll * val0 * f[e[i].to][1] % P) % P;
}
return;
}
void solve()
{
if(n == 1){cout << 1 << endl;return;}
if(y == 1){cout << power(n,n - 2) << endl;return;}
int a,b;
for(int i = 1;i < n;++i)
{
scanf("%d%d",&a,&b);
add(a,b);
}
c = inver(y) - 1;
dp(1,0);
int ans = 1ll * f[1][1] * power(y,n) % P * power(n,n - 2) % P * power(1ll * c * inver(n) % P,n) % P;
cout << ans << endl;
return;
}
}
namespace SOLVE2
{
int F[MAXN];
int fac[MAXN],inv[MAXN];
int G[MAXN];
int rev[MAXN];
int ww[MAXN << 1],*g = ww + MAXN;
int init(int n)
{
int l = 0,len = 1;
for(;len <= n;len = len << 1)++l;
for(int i = 0;i < len;++i)rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
g[0] = g[-len] = 1;g[1] = g[1 - len] = power(3,(P - 1) / len);
for(int i = 2;i < len;++i)g[i] = g[i - len] = 1ll * g[i - 1] * g[1] % P;
return len;
}
void NTT(int f[],int l,int type)
{
for(int i = 0;i < l;++i)
{
if(i < rev[i])swap(f[i],f[rev[i]]);
}
for(int i = 2;i <= l;i = i << 1)
{
int wn = g[type * l / i];
for(int j = 0;j < l;j += i)
{
int w = 1;
for(int k = j;k < j + i / 2;++k)
{
int u = f[k],t = 1ll * w * f[k + i / 2] % P;
f[k] = (u + t) % P;
f[k + i / 2] = (u - t + P) % P;
w = 1ll * w * wn % P;
}
}
}
if(type == -1)
{
int ni = power(l,P - 2);
for(int i = 0;i < l;++i)f[i] = 1ll * f[i] * ni % P;
}
return;
}
int invtmp[MAXN];
void poly_inv(int deg,int a[],int b[])
{
if(deg == 1)
{
b[0] = power(a[0],P - 2);
return;
}
poly_inv((deg + 1) >> 1,a,b);
int l = init(deg * 2);
for(int i = 0;i < deg;++i)invtmp[i] = a[i];
for(int i = deg;i < l;++i)invtmp[i] = 0;
NTT(invtmp,l,1);NTT(b,l,1);
for(int i = 0;i < l;++i)b[i] = (2ll * b[i] % P - 1ll * invtmp[i] * b[i] % P * b[i] % P + P) % P;
NTT(b,l,-1);
for(int i = 0;i < l;++i)invtmp[i] = 0;
for(int i = deg;i < l;++i)b[i] = 0;
return;
}
int lntmp[MAXN];
void poly_ln(int deg,int a[],int b[])
{
poly_inv(deg,a,b);
for(int i = 0;i < deg;++i)lntmp[i] = a[i];
for(int i = 0;i < deg;++i)lntmp[i] = 1ll * (i + 1) * lntmp[i + 1] % P;
int l = init(deg * 2);
NTT(b,l,1);NTT(lntmp,l,1);
for(int i = 0;i < l;++i)b[i] = 1ll * b[i] * lntmp[i] % P;
NTT(b,l,-1);
for(int i = deg - 1;i >= 1;--i)b[i] = 1ll * inver(i) * b[i - 1] % P;b[0] = 0;
for(int i = deg;i < l;++i)b[i] = 0;
for(int i = 0;i < l;++i)lntmp[i] = 0;
return;
}
int exptmp[MAXN];
void poly_exp(int deg,int a[],int b[])
{
if(deg == 1)
{
b[0] = 1;
return;
}
poly_exp((deg + 1) >> 1,a,b);
poly_ln(deg,b,exptmp);
int l = init(deg * 2);
for(int i = 0;i < deg;++i)exptmp[i] = (a[i] - exptmp[i] + P) % P;
exptmp[0] = (exptmp[0] + 1) % P;
NTT(exptmp,l,1);NTT(b,l,1);
for(int i = 0;i < l;++i)b[i] = 1ll * b[i] * exptmp[i] % P;
NTT(b,l,-1);
for(int i = deg;i < l;++i)b[i] = 0;
for(int i = 0;i < l;++i)exptmp[i] = 0;
return;
}
void solve()
{
if(n == 1){cout << 1 << endl;return;}
if(y == 1){cout << power(n,2 * n - 4) << endl;return;}
fac[0] = 1;for(int i = 1;i <= n;++i)fac[i] = 1ll * fac[i - 1] * i % P;
inv[n] = inver(fac[n]);for(int i = n - 1;i >= 0;--i)inv[i] = 1ll * inv[i + 1] * (i + 1) % P;
int c = inver(y) - 1;
for(int i = 1;i <= n;++i)F[i] = 1ll * n * n % P * power(i,i) % P * inver(c) % P * inv[i] % P;
poly_exp(n + 1,F,G);
cout << 1ll * power(c,n) % P * power(y,n) % P * fac[n] % P * inver(power(n,4)) % P * G[n] % P << endl;
return;
}
}
int main()
{
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
scanf("%d%d%d",&n,&y,&op);
if(op == 0)SOLVE0::solve();
else if(op == 1)SOLVE1::solve();
else SOLVE2::solve();
fclose(stdin);
fclose(stdout);
return 0;
}
Copyright © 2020 wjh15101051
ღゝ◡╹)ノ♡