卧薪尝胆,厚积薄发。
ZJOI2018 历史
Date: Fri Nov 23 11:41:11 CST 2018 In Category: NoCategory

Description:

给一棵树和树上每个点 $access$ 操作的次数 $a[i]$ ,求出轻重链切换次数的最大值,支持给 $a[i]$ 增加一个正整数。
$1\leqslant n\leqslant4\times 10^5$

Solution:

首先我们可以把这些轻重链切换按他们所在的点分类,对于点 $k$ ,如果它的子树(包括它) $a[i]$ 之和为 $S$ ,最大值为 $mx$ ,注意由于只考虑在这个点的切换所以来自同一个子树的 $a[i]$ 都是等价的,也就是说这个最大值实际是子树的 $S$ 的最大值而不是单点的最大值,那么在这个点会产生的贡献最大是: $res[k]=\min(S-1,2\times(S-mx))$ ,也就是说如果不存在一个 $mx>\lfloor\frac S 2\rfloor$ 那么就可以把他们两两都不相邻,否则剩下的每个可以和它左边和右边产生贡献,这样就可以 $O(n)$ 求出整棵树的贡献,但是还要支持动态修改,但是并不用动态 $DP$ ,好像也不能用,仔细研究一下会发现如果某个子树 $i$ 满足 $S_i>\lfloor\frac{S_k}2\rfloor$ 那么子树 $i$ 内的单点增量是不影响 $k$ 的 $res$ 的,因为反正也是做差,那么我们可以借用 $LCT$ 的思想,把这样的一条链用一棵 $splay$ 来维护,不符合这样的用虚父亲也就是轻边来维护,那么我们就可以跳过大段的重边,也就是说一次单点增量只有轻边的父亲的 $res$ 会改变,还有一个更优秀的性质是按照这样类似轻重链剖分的做法也是满足树链剖分从某个点到根轻边个数不超过 $\log$ 个这个性质的,因为轻边是一定满足 $S_i\leqslant\lfloor\frac{S_k}2\rfloor$ ,那么我们就可以先一次 $dfs$ 求出初始答案,然后每次考虑答案的增量,也就是像 $access$ 那样不断往上跳,同时维护轻重链切换的关系以及答案的变化,由糖水不等式 $\frac{a+c}{b+c}>\frac a b$ 得这条链上是重儿子的还一定是重儿子,但可能别的链在虚边那里切换成了轻儿子,同时顺便用 $splay$ 来维护当前每个点的 $S_k$ ,这个只要子树加打标记即可,实现的时候和 $LCT$ 略有不同因为不能简单跳虚边,轻重边切换时一定要用它在原树上的儿子而不能用辅助树上的儿子。

Code:


#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cmath>
#include<stack>
#include<cctype>
#include<cstring>
using namespace std;
inline int rd()
{
register int res = 0,f = 1;
register char c = getchar();
while(!isdigit(c))
{
if(c == '-')f = -1;
c = getchar();
}
while(isdigit(c))
{
res = (res << 1) + (res << 3) + c - '0';
c = getchar();
}
return res * f;
}
int n,m;
#define MAXN 400010
long long a[MAXN];
struct edge
{
int to,nxt;
}e[MAXN << 1];
int edgenum = 0;
int lin[MAXN] = {0};
void add(int a,int b)
{
e[++edgenum] = (edge){b,lin[a]};lin[a] = edgenum;
e[++edgenum] = (edge){a,lin[b]};lin[b] = edgenum;
return;
}
typedef long long ll;
struct node
{
int c[2],fa;
ll val,tag;
}t[MAXN];
bool isroot(int k){return (t[t[k].fa].c[0] != k && t[t[k].fa].c[1] != k);}
int id(int k){return (t[t[k].fa].c[0] == k ? 0 : 1);}
void connect(int k,int f,int p){t[k].fa = f;t[f].c[p] = k;return;}
void pushdown(int rt)
{
if(t[rt].tag != 0)
{
if(t[rt].c[0] != 0){t[t[rt].c[0]].tag += t[rt].tag;t[t[rt].c[0]].val += t[rt].tag;}
if(t[rt].c[1] != 0){t[t[rt].c[1]].tag += t[rt].tag;t[t[rt].c[1]].val += t[rt].tag;}
t[rt].tag = 0;
}
return;
}
void rotate(int x)
{
int y = t[x].fa,z = t[y].fa,fx = id(x),fy = id(y);
if(!isroot(y))t[z].c[fy] = x;
t[x].fa = z;
connect(t[x].c[fx ^ 1],y,fx);
connect(y,x,fx ^ 1);
return;
}
stack<int> s;
void splay(int x)
{
s.push(x);
for(int i = x;!isroot(i);i = t[i].fa)s.push(t[i].fa);
while(!s.empty()){pushdown(s.top());s.pop();}
while(!isroot(x))
{
int y = t[x].fa;
if(isroot(y)){rotate(x);break;}
if(id(x) == id(y)){rotate(y);rotate(x);}
else{rotate(x);rotate(x);}
}
return;
}
void set(int rt,ll k)
{
if(rt == 0)return;
t[rt].val += k;t[rt].tag += k;
return;
}
ll ans = 0;
bool vis[MAXN];
ll sum[MAXN];
ll res[MAXN];
void dfs(int k)
{
vis[k] = true;
sum[k] = a[k];
for(int i = lin[k];i != 0;i = e[i].nxt)
{
if(vis[e[i].to])continue;
dfs(e[i].to);
sum[k] += sum[e[i].to];
}
t[k].val = sum[k];
bool tag = false;
for(int i = lin[k];i != 0;i = e[i].nxt)
{
if(vis[e[i].to])continue;
if(sum[e[i].to] > sum[k] / 2)
{
splay(k);splay(e[i].to);
connect(e[i].to,k,1);
tag = true;
res[k] = (sum[k] - sum[e[i].to]) * 2;
}
else
{
t[e[i].to].fa = k;
}
}
if(a[k] > sum[k] / 2)
{
tag = true;
res[k] = (sum[k] - a[k]) * 2;
}
if(!tag)res[k] = sum[k] - 1;
ans += res[k];
vis[k] = false;
return;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i = 1;i <= n;++i)a[i] = rd();
for(int i = 1;i < n;++i)add(rd(),rd());
dfs(1);
cout << ans << endl;
int x,y;
for(int i = 1;i <= m;++i)
{
x = rd();y = rd();
for(int i = x;i;i = t[i].fa)
{
splay(i);pushdown(i);set(t[i].c[0],y);t[i].val += y;
}
a[x] += y;
while(x)
{
splay(x);pushdown(x);
ans -= res[x];
int nxt = t[x].c[1];
while(nxt != 0 && t[nxt].c[0] != 0)
{
pushdown(nxt);
nxt = t[nxt].c[0];
}
if(nxt != 0 && t[nxt].val <= t[x].val / 2)nxt = t[x].c[1] = 0;
if(nxt != 0)res[x] = (t[x].val - t[nxt].val) * 2;
else if(a[x] > t[x].val / 2)res[x] = (t[x].val - a[x]) * 2;
else res[x] = t[x].val - 1;
ans += res[x];
while(t[x].c[0] != 0)x = t[x].c[0];
splay(x);
int f = t[x].fa;
if(f == 0)break;
splay(f);
if(t[x].val > t[f].val / 2)t[f].c[1] = x;
x = f;
}
printf("%lld\n",ans);
}
return 0;
}
Copyright © 2020 wjh15101051
ღゝ◡╹)ノ♡