卧薪尝胆,厚积薄发。
树上四次求和
Date: Mon Mar 11 14:33:41 CST 2019 In Category: NoCategory

Description:

给一棵树和一个排列 $p$ ,每次询问给出 $k$ ,求: $$ \sum_{l=1}^k\sum_{r=l}^k\sum_{i=l}^r\sum_{j=i}^rdis(p[i],p[j]) $$ $1\leqslant n\leqslant 10^5$

Solution:

$$ \begin{align} ans[k]&=\sum_{l=1}^k\sum_{r=l}^k\sum_{i=l}^r\sum_{j=i}^rdis(p[i],p[j])\\ &=\sum_{i=1}^k\sum_{j=i}^kdis(p[i],p[j])\times i\times (k-j+1)\\ &=\sum_{i=1}^{k-1}\sum_{j=i}^kdis(p[i],p[j])\times i\times (k-j+1)\\ &=\sum_{i=1}^{k-1}\sum_{j=i}^{k-1}dis(p[i],p[j])\times i\times (k-j+1)+\sum_{i=1}^{k-1}dis(p[i],p[k])\times i\\ &=\sum_{i=1}^{k-1}\sum_{j=i}^{k-1}dis(p[i],p[j])\times i\times (k-j)+\sum_{i=1}^{k-1}\sum_{j=i}^{k-1}dis(p[i],p[j])\times i+\sum_{i=1}^{k-1}dis(p[i],p[k])\times i\\ &=ans[k-1]+\sum_{i=1}^{k-1}\sum_{j=i}^{k-1}dis(p[i],p[j])\times i+\sum_{i=1}^{k-1}dis(p[i],p[k])\times i\\ \end{align} $$
设 $sum[k]$ 和 $sums[k]$ 表示: $$ sum[k]=\sum_{i=1}^{k-1}dis(p[i],p[k])\times i,\quad sums[k]=\sum_{i=1}^ksum[i] $$ 那么: $$ ans[k]=ans[k-1]+sums[k-1]+sum[k] $$ 于是我们只要求 $sum[k]$ 即可。由 $dis(u,v)=dep[u]+dep[v]-2\,dep[LCA(u,v)]$ 可得: $$ \begin{align} sum[k]&=\sum_{i=1}^{k-1}dis(p[i],p[k])\times i\\ &=\sum_{i=1}^{k-1}dep[p[i]]\times i+dep[p[k]]\sum_{i=1}^{k-1}i-2\sum_{i=1}^{k-1}dep[LCA(p[i],p[k])]\times i \end{align} $$ 前两项都很好维护,重点是最后一项。发现最后一项的形式非常像 $LNOI\;LCA$ ,于是我们只要把每次 $+1$ 改成 $+i$ 就可以了。

Code:


#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cmath>
#include<cctype>
#include<cstring>
using namespace std;
// n: number of tree vertices (also the length of the permutation p);
// q: number of queries. Both are read in main().
int n,q;
// Capacity of every per-vertex array; vertices are indexed starting at 1.
#define MAXN 100010
// Modulus applied to every accumulated value in the segment tree and in main().
#define MOD 998244353
// One directed record of the linked adjacency lists.
struct edge
{
    int to;   // vertex this record points at
    int nxt;  // index in e[] of the next record of the same list, 0 ends the list
};
// Record pool; add(a,b) consumes two records per undirected tree edge.
edge e[MAXN << 1];
// Number of records of e[] in use. Records are numbered from 1, so the value 0
// is free to act as the "end of list" marker.
int edgenum = 0;
// lin[v]: index in e[] of the first adjacency record of vertex v, 0 if v has none.
int lin[MAXN] = {0};
// Registers the undirected edge a-b: one directed record is pushed onto the
// front of a's adjacency list and one onto the front of b's.
//
// The records are filled by plain member assignment. The previous
// `(edge){b,lin[a]}` form is a compound literal, which is a GNU extension
// rather than ISO C++ and is rejected by strictly conforming compilers.
void add(int a,int b)
{
    // a -> b
    ++edgenum;
    e[edgenum].to = b;
    e[edgenum].nxt = lin[a];
    lin[a] = edgenum;

    // b -> a
    ++edgenum;
    e[edgenum].to = a;
    e[edgenum].nxt = lin[b];
    lin[b] = edgenum;
}
// Heavy-light decomposition data, filled by dfs1() and dfs2():
//   dep[v] - depth of v (main() starts the root at depth 1)
//   top[v] - topmost vertex of the heavy chain containing v
//   siz[v] - number of vertices in the subtree of v
//   son[v] - child of v with the largest subtree, 0 if v is a leaf
//   fa[v]  - parent of v, 0 for the root
int dep[MAXN],top[MAXN],siz[MAXN],son[MAXN],fa[MAXN];
// rnk[v]: position of v in the dfs2() visiting order (1-based);
// th[i]: vertex visited i-th (inverse of rnk); tot: positions handed out so far.
int rnk[MAXN],th[MAXN],tot = 0;
// First decomposition pass over the subtree of k.
// Records dep[k] = depth, the parent of every child, siz[k], and picks as
// son[k] the child with the strictly largest subtree (first one wins on ties).
void dfs1(int k,int depth)
{
    dep[k] = depth;
    siz[k] = 1;

    int i = lin[k];
    while(i != 0)
    {
        const int child = e[i].to;
        if(child != fa[k])  // do not walk back up to the parent
        {
            fa[child] = k;
            dfs1(child,depth + 1);
            siz[k] += siz[child];

            const bool heavier = (son[k] == 0) || (siz[child] > siz[son[k]]);
            if(heavier)son[k] = child;
        }
        i = e[i].nxt;
    }
}
// Second decomposition pass: numbers the vertices and assigns chain heads.
// k gets the next position in rnk/th and the chain head tp. The heavy child is
// visited first so that each heavy chain occupies consecutive positions; every
// other child starts a new chain headed by itself.
void dfs2(int k,int tp)
{
    ++tot;
    rnk[k] = tot;
    th[tot] = k;
    top[k] = tp;

    const int heavy = son[k];
    if(heavy != 0)
    {
        dfs2(heavy,tp);
        for(int i = lin[k];i != 0;i = e[i].nxt)
        {
            const int child = e[i].to;
            if(child != fa[k] && child != heavy)dfs2(child,child);
        }
    }
}
// Lowest common ancestor of a and b by chain jumping.
// While the two vertices sit on different chains, the one whose chain head is
// deeper jumps to the parent of that head. Once they share a chain, the
// shallower vertex is the answer.
int LCA(int a,int b)
{
    for(;;)
    {
        const int headA = top[a];
        const int headB = top[b];
        if(headA == headB)break;
        // keep the vertex with the deeper chain head in `a`, then lift it
        if(dep[headA] < dep[headB])std::swap(a,b);
        a = fa[top[a]];
    }
    if(dep[a] < dep[b])return a;
    return b;
}
// Segment-tree node for range add / range sum, values kept modulo MOD.
//   lc, rc - indices in t[] of the left / right child, 0 when absent
//   sum    - sum of the values in this node's segment
//   tag    - pending per-element increment not yet pushed to the children
//
// The constructor now initialises every member. Previously lc and rc were set
// only by the zero-initialisation of the static array, so any automatic or
// heap instance would have carried indeterminate child indices.
struct node
{
    int lc,rc;
    int sum,tag;
    node() : lc(0), rc(0), sum(0), tag(0) {}
}t[MAXN << 1];
// Number of segment-tree nodes handed out so far; index 0 is never allocated.
int ptr = 0;
// Reserves the next unused node index and returns it (the first call yields 1).
int newnode()
{
    ptr += 1;
    return ptr;
}
// Index in t[] of the segment-tree root; assigned by build() in main().
int root;
// Midpoint of the current segment. Expands textually, so it is only usable in
// functions that have variables named l and r in scope.
#define mid ((l + r) >> 1)
// Allocates the subtree covering positions [l, r] and stores its node index
// in rt. Leaves (l == r) get no children; inner nodes split at the midpoint.
void build(int &rt,int l,int r)
{
    rt = newnode();
    if(l != r)
    {
        const int split = (l + r) >> 1;
        build(t[rt].lc,l,split);
        build(t[rt].rc,split + 1,r);
    }
}
void pushdown(int rt,int l,int r)
{
if(t[rt].tag == 0)return;
int ll = l,lr = mid,rl = mid + 1,rr = r;
t[t[rt].lc].sum = (t[t[rt].lc].sum + 1ll * (lr - ll + 1) * t[rt].tag % MOD) % MOD;
t[t[rt].lc].tag = (t[t[rt].lc].tag + t[rt].tag) % MOD;
t[t[rt].rc].sum = (t[t[rt].rc].sum + 1ll * (rr - rl + 1) * t[rt].tag % MOD) % MOD;
t[t[rt].rc].tag = (t[t[rt].rc].tag + t[rt].tag) % MOD;
t[rt].tag = 0;
return;
}
void add(int rt,int L,int R,int k,int l,int r)
{
if(L <= l && r <= R)
{
t[rt].sum = (t[rt].sum + 1ll * (r - l + 1) * k % MOD) % MOD;
t[rt].tag = (t[rt].tag + k) % MOD;
return;
}
pushdown(rt,l,r);
if(L <= mid)add(t[rt].lc,L,R,k,l,mid);
if(R > mid)add(t[rt].rc,L,R,k,mid + 1,r);
t[rt].sum = (t[t[rt].lc].sum + t[t[rt].rc].sum) % MOD;
return;
}
// Range sum: returns the sum of positions [L, R] modulo MOD.
// rt is the current node and [l, r] the segment it covers.
int query(int rt,int L,int R,int l,int r)
{
    const bool covered = (l >= L) && (R >= r);
    if(covered)return t[rt].sum;

    pushdown(rt,l,r);  // children must be up to date before descending
    const int split = (l + r) >> 1;
    int total = 0;
    if(L <= split)
    {
        total = (total + query(t[rt].lc,L,R,l,split)) % MOD;
    }
    if(split < R)
    {
        total = (total + query(t[rt].rc,L,R,split + 1,r)) % MOD;
    }
    return total;
}
// Adds v to every vertex on the path from k up to the root.
// The path is processed one heavy chain at a time; each chain piece is a
// contiguous rnk range, updated with a single segment-tree range add.
void change(int k,int v)
{
    for(int cur = k;cur != 0;cur = fa[top[cur]])
    {
        const int head = top[cur];
        add(root,rnk[head],rnk[cur],v,1,n);
    }
}
// Returns, modulo MOD, the sum of the values stored on every vertex of the
// path from k up to the root, gathered one heavy chain at a time.
int query(int k)
{
    int total = 0;
    for(int cur = k;cur != 0;cur = fa[top[cur]])
    {
        const int head = top[cur];
        total = (total + query(root,rnk[head],rnk[cur],1,n)) % MOD;
    }
    return total;
}
// ans[k]: answer for query parameter k, modulo MOD (ans[0] stays 0).
int ans[MAXN];
// p[1..n]: the permutation of vertices read from the input.
int p[MAXN];
// sum[k] = sum over i < k of i * dis(p[i], p[k]), modulo MOD.
int sum[MAXN];
// sums[k] = sum[1] + ... + sum[k], modulo MOD.
int sums[MAXN];
// sumd[k] = sum over i <= k of i * dep[p[i]], modulo MOD.
int sumd[MAXN];
int main()
{
freopen("1.in","r",stdin);
freopen("1.out","w",stdout);
scanf("%d%d",&n,&q);
int a,b;
for(int i = 1;i < n;++i)
{
scanf("%d%d",&a,&b);
add(a,b);
}
for(int i = 1;i <= n;++i)scanf("%d",&p[i]);
dfs1(1,1);dfs2(1,1);
build(root,1,n);
int sumid = 0;
for(int k = 1;k <= n;++k)
{
ans[k] = (ans[k - 1] + sums[k - 1]) % MOD;
sum[k] = query(p[k]);
sum[k] = 1ll * (MOD - 2) * sum[k] % MOD;
sum[k] = (sum[k] + 1ll * sumid * dep[p[k]] % MOD) % MOD;
sumid = (sumid + k) % MOD;
sum[k] = (sum[k] + sumd[k - 1]) % MOD;
ans[k] = (ans[k] + sum[k]) % MOD;
sums[k] = (sums[k - 1] + sum[k]) % MOD;
sumd[k] = (sumd[k - 1] + 1ll * dep[p[k]] * k % MOD) % MOD;
change(p[k],k);
}
for(int i = 1;i <= q;++i)
{
scanf("%d",&a);
printf("%d\n",ans[a]);
}
fclose(stdin);
fclose(stdout);
return 0;
}
Copyright © 2020 wjh15101051
ღゝ◡╹)ノ♡