卧薪尝胆,厚积薄发。
九省联考2018 秘密袭击coat
Date: Thu Feb 14 08:35:32 CST 2019 In Category: NoCategory

Description:

给一颗有 $N$ 个点的树,点权在 $1\sim W$ 之间,求树的每一个联通块的第 $K$ 大点权之和。
$1\leqslant n\leqslant 1666$

Solution:

先转化题目: $$ \begin{align} ans&=\sum_{S}kth\ val\ of\ S\\ &=\sum_{i=1}^W\sum_S[kth\ val\ of\ S\geqslant i]\\ \end{align} $$ 再设 $cnt_S[i]$ 表示联通块 $S$ 的权值大于等于 $i$ 的点的个数,那么: $$ ans=\sum_{i=1}^W\sum_S[cnt_S[i]\geqslant K] $$ 那么就可以 $DP$ 了,设 $f[i][j][k]$ 表示以 $i$ 为根的所有联通块中,权值大于等于 $j$ 的点的个数恰好为 $k$ 的方案数,在最外面枚举一个 $j$ ,转移就是先分根节点是 $\geqslant j$ 还是 $<j$ 讨论,然后再枚举所有子树的 $k'$ 来计算: $$ \begin{align} &f[i][j][k]=\sum\prod_{v\in son[i]}f[v][j][k_v]&&(val[i]\geqslant j,\sum_{v\in son[i]}k_v=k-1)\\ &f[i][j][k]=\sum\prod_{v\in son[i]}f[v][j][k_v]&&(val[i]<j,\sum_{v\in son[i]}k_v=k) \end{align} $$ 显然最后一维就是一个背包,最后的答案就是: $$ ans=\sum_{i=1}^W\sum_{k=K}^n\sum_{x=1}^nf[x][i][k] $$ 这个暴力看上去是 $O(n^4)$ 的,但是可以用树形背包优化到 $O(n^3)$ 。
我们可以发现最后一维就是一个背包卷积,于是我们就可以想到生成函数还有多项式,考虑设: $$ F[i][j]=\sum_{i=0}^nf[i][j][k]\times x^k $$ 那么转移就可以改写成: $$ F[i][j]=\prod_{v\in son[i]}(F[v][j]+1)\times (1(val[i]<j):x(val[i]\geqslant j)) $$ 之所以要加一是因为当 $k=0$ 时还要计算一种不选子树的方案。
由于最后的求和还是 $O(n^3)$ 的,所以再设一个: $$ G[i][j]=\sum_{v\in son[i]}G[v][j]+F[i][j] $$ 就可以在最后统计答案时省去枚举 $i$ 。
由于多项式卷积是 $O(n^2)$ 的很慢,而模数又不是 $NTT$ 模数,于是我们可以用拉格朗日插值法把多项式转化成点值,然后就可以 $O(n)$ 卷积,又由于多项式是线性变换,因此可以在最后把 $G$ 插出来就可以求解了。
然后我们可以考虑利用整体 $DP$ 的思想,每个点用一棵线段树维护,线段树的下标是 $j$ ,线段树每个叶子节点维护这一个多项式代入 $x=v$ 时的值,考虑我们需要支持的操作:初始化 $F=1\times x^0$ ,给前缀 $[1,val[i]]$ 乘上多项式 $x$ ,给后缀 $[val[i]+1,W]$ 乘上多项式 $1$ ,给所有多项式 $+1$ 还有多项式对应相乘,对应相乘可以在和儿子节点的线段树合并的时候做。具体来说,就是:
$(F,G)=(1,0)/(x,0)$ :线段树区间覆盖。
$(F,G)=(F\times(F_v+1),G+G_v)$ :线段树合并
$(F,G)=(F\times x,G)$ :线段树区间打标记
$(F,G)=(F,G+F)$ :线段树区间打标记
发现每次变换只会有 $(F,G)\to(F\times a+b,G+c+d\times F)$ ,因此我们只要用一个 $(a,b,c,d)$ 的 $tag$ 就行了。
$tag$ 的合并: $$ (a_1,b_1,c_1,d_1)+(a_2,b_2,c_2,d_2)\to(a_1\times a_2,b_1\times a_2+b_2,c_1+c_2+d_2\times b_1,d_1+d_2\times a_1) $$

Code:


#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cmath>
#include<cctype>
#include<cstring>
using namespace std;
int n,m,W;
#define MAXN 1700
#define MOD 64123
typedef unsigned int uint;
int val[MAXN];
struct edge
{
int to,nxt;
}e[MAXN << 1];
int edgenum = 0;
int lin[MAXN] = {0};
void add(int a,int b)
{
e[++edgenum] = (edge){b,lin[a]};lin[a] = edgenum;
e[++edgenum] = (edge){a,lin[b]};lin[b] = edgenum;
return;
}
int siz[MAXN];
#define MOD 64123
uint power(uint a,int b)
{
uint res = 1;
while(b > 0)
{
if(b & 1)res = res * a % MOD;
a = a * a % MOD;
b = b >> 1;
}
return res;
}
uint inv(uint k){return power(k,MOD - 2);}
uint res[MAXN];
uint w[MAXN],a[MAXN];
void calc(uint res[MAXN],uint val[MAXN],int n)
{
for(int i = 0;i <= n;++i)res[i] = w[i] = a[i] = 0;
w[0] = 1;
for(int i = 1;i <= n;++i)
{
a[i] = val[i];
for(int j = 1;j <= n;++j)
if(i != j)a[i] = a[i] * inv((i - j + MOD) % MOD) % MOD;
}
for(int i = 1;i <= n;++i)
for(int j = n;j >= 0;--j)w[j] = (w[j - 1] - i * w[j] % MOD + MOD) % MOD;
uint tmp[MAXN];
for(int i = 1;i <= n;++i)
{
tmp[0] = (-w[0] + MOD) % MOD * inv(i) % MOD;
for(int j = 1;j < n;++j)tmp[j] = (tmp[j - 1] - w[j] + MOD) % MOD * inv(i) % MOD;
for(int j = 0;j < n;++j)tmp[j] = tmp[j] * a[i] % MOD;
for(int j = 0;j < n;++j)res[j] = (res[j] + tmp[j]) % MOD;
}
return;
}
struct data
{
uint a,b,c,d;
data(uint a_ = 1,uint b_ = 0,uint c_ = 0,uint d_ = 0){a = a_;b = b_;c = c_;d = d_;}
void init(){a = 1;b = c = d = 0;}
friend data operator + (data a,data b)
{
data res;
res.a = a.a * b.a % MOD;
res.b = (a.b * b.a % MOD + b.b) % MOD;
res.c = (a.c + b.c + b.d * a.b % MOD) % MOD;
res.d = (a.d + b.d * a.a % MOD) % MOD;
return res;
}
};
struct node
{
int lc,rc;
data v;
}t[MAXN * 60];
int ptr = 0;
int newnode()
{
int k = ++ptr;
t[k].v.init();
t[k].lc = t[k].rc = 0;
return k;
}
int root[MAXN];
#define mid ((l + r) >> 1)
void pushdown(int rt)
{
if(t[rt].lc == 0)t[rt].lc = newnode();t[t[rt].lc].v = t[t[rt].lc].v + t[rt].v;
if(t[rt].rc == 0)t[rt].rc = newnode();t[t[rt].rc].v = t[t[rt].rc].v + t[rt].v;
t[rt].v.init();
return;
}
void add(int &rt,int L,int R,data k,int l,int r)
{
if(rt == 0)rt = newnode();
if(L <= l && r <= R){t[rt].v = t[rt].v + k;return;}
pushdown(rt);
if(L <= mid)add(t[rt].lc,L,R,k,l,mid);
if(R > mid)add(t[rt].rc,L,R,k,mid + 1,r);
return;
}
int merge(int x,int y)
{
if(x == 0 || y == 0)return x + y;
if(!t[x].lc && !t[x].rc)swap(x,y);
if(!t[y].lc && !t[y].rc)
{
t[x].v = t[x].v + (data){t[y].v.b,0,0,0};
t[x].v = t[x].v + (data){1,0,t[y].v.c,0};
return x;
}
pushdown(x);pushdown(y);
t[x].lc = merge(t[x].lc,t[y].lc);
t[x].rc = merge(t[x].rc,t[y].rc);
return x;
}
void dp(int k,int fa,uint v)
{
add(root[k],1,W,(data){0,1,0,0},1,W);
for(int i = lin[k];i != 0;i = e[i].nxt)
{
if(e[i].to == fa)continue;
dp(e[i].to,k,v);
root[k] = merge(root[k],root[e[i].to]);
}
add(root[k],1,val[k],(data){v,0,0,0},1,W);
add(root[k],1,W,(data){1,0,0,1},1,W);
add(root[k],1,W,(data){1,1,0,0},1,W);
return;
}
uint tans[MAXN];
uint query(int rt,int l,int r)
{
if(l == r)return t[rt].v.c;
pushdown(rt);
return (query(t[rt].lc,l,mid) + query(t[rt].rc,mid + 1,r)) % MOD;
}
int main()
{
scanf("%d%d%d",&n,&m,&W);
for(int i = 1;i <= n;++i)scanf("%d",&val[i]);
int a,b;
for(int i = 1;i < n;++i)
{
scanf("%d%d",&a,&b);
add(a,b);
}
uint ans = 0;
for(int i = 1;i <= n + 1;++i)
{
ptr = 0;memset(root,0,sizeof(root));
dp(1,0,i);
tans[i] = query(root[1],1,W);
}
calc(res,tans,n + 1);
for(int x = m;x <= n;++x)ans = (ans + res[x]) % MOD;
cout << ans << endl;
return 0;
}
Copyright © 2020 wjh15101051
ღゝ◡╹)ノ♡