卧薪尝胆,厚积薄发。
NOI2018 情报中心
Date: Wed Feb 13 08:15:57 CST 2019 In Category: NoCategory

Description:

给一棵树还有树上的若干条链,每条链有权值,边有边权,最大化链并的边权之和减去链的权值和。
$1\leqslant n\leqslant 1000233$

Solution:

分情况讨论,两条链 $LCA$ 相同和不同。
先考虑 $LCA$ 不同的情况,那么如果他们有交,一定是直上直下的一条链:
考虑这样对答案的贡献: $$ res=len[x]+len[y]-val[x]-val[y]-(dist[a]-dist[b=LCA(x)]) $$ 那么我们可以把每条链存在他的两个端点,然后从下往上线段树合并,线段树的下标为 $LCA$ 的深度,记: $$ \begin{align} fir[x]&=len[x]-val[x]\\ sec[x]&=len[x]-val[x]+dist[LCA(x)] \end{align} $$ 可以在 $dfs$ 到 $a$ 的时候计算,因此 $dist[a]$ 是常数,那么贡献就是: $$ \begin{align} res=&\\ &fir[x]+sec[y]&dep[LCA(x)]<dep[LCA(y)]\\ &sec[x]+fir[y]&dep[LCA(x)]>dep[LCA(y)] \end{align} $$ 由于线段树的下标是 $LCA$ 的深度,所以直接左右子树交叉统计答案就可以了。 因为 $a$ 和 $b$ 不能相等,而线段树的下标又恰好是 $dep[LCA]$ ,因此每次清空当前 $dep$ 这一位就行了。
再考虑 $LCA$ 相同的情况,它们的交形状不再是直上直下的一条链:
如果我们把 $c$ 和 $b$ 之间, $d$ 和 $e$ 之间的链补满,那么会发现就恰好是链并的两倍。
于是我们把每条链存在两个端点处,那么贡献就是: $$ \begin{align} res=&\frac{len[x]+len[y]+dist[c]+dist[d]-2\times dist[a]+dist(d,e)}2-val[x]-val[y]\\ =&\frac{(len[x]-2\times val[x]+dist[c])+(len[y]-2\times val[y]+dist[b])+dist(d,e)}2-dist[a] \end{align} $$ 那么我们可以枚举点 $a$ ,然后发现这其实就是一个 $w[a]+w[b]+dist(a,b)-$ 常数的形式,于是像通道那题在第二棵树上所做的那样树形 $DP$ 一下就可以了。但是我们还没有要求链的 $LCA$ 必须相同,于是我们把所有 $LCA$ 相同的链找出来,对每个建虚树就可以了。
因为要求链至少有一个公共边,所以 $dep[LCA]$ 必须小于 $dep[a]$ 。

Code:


#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cmath>
#include<cctype>
#include<cstring>
using namespace std;
typedef long long ll;
inline ll rd()
{
register ll res = 0,f = 1;register char c = getchar();
while(!isdigit(c)){if(c == '-')f = -1;c = getchar();}
while(isdigit(c)){res = (res << 1) + (res << 3) + c - '0';c = getchar();}
return res * f;
}
int n,m;
#define MAXN 50010
#define MAXM 100010
struct edge{int to,nxt,v;}e[MAXN << 1];
int edgenum = 0,lin[MAXN] = {0};
void add(int a,int b,int c)
{
e[++edgenum] = (edge){b,lin[a],c};lin[a] = edgenum;
e[++edgenum] = (edge){a,lin[b],c};lin[b] = edgenum;
return;
}
ll d[MAXN];
int top[MAXN],dep[MAXN],son[MAXN],siz[MAXN],fa[MAXN],rnk[MAXN],tot = 0;
void dfs1(int k,int depth)
{
dep[k] = depth;siz[k] = 1;
for(int i = lin[k];i != 0;i = e[i].nxt)
{
if(e[i].to == fa[k])continue;
d[e[i].to] = d[k] + e[i].v;fa[e[i].to] = k;
dfs1(e[i].to,depth + 1);siz[k] += siz[e[i].to];
if(son[k] == 0 || siz[e[i].to] > siz[son[k]])son[k] = e[i].to;
}
return;
}
void dfs2(int k,int tp)
{
top[k] = tp;rnk[k] = ++tot;
if(son[k] == 0)return;
dfs2(son[k],tp);
for(int i = lin[k];i != 0;i = e[i].nxt)
if(e[i].to != fa[k] && e[i].to != son[k])dfs2(e[i].to,e[i].to);
return;
}
int LCA(int a,int b)
{
while(top[a] != top[b]){if(dep[top[a]] < dep[top[b]])swap(a,b);a = fa[top[a]];}
return (dep[a] < dep[b] ? a : b);
}
ll dis(int a,int b){return d[a] + d[b] - 2 * d[LCA(a,b)];}
struct chain{int a,b,lca;ll l,v,fir,sec;}s[MAXM];
#define INF 0x3f3f3f3f3f3f3f3f
ll ans = -INF;
namespace SOLVE1
{
struct node{int lc,rc;ll fir,sec;node(){fir = sec = -INF;}}t[MAXN * 60];
int ptr = 0;
int newnode(){int k = ++ptr;t[k].lc = t[k].rc = 0;t[k].fir = t[k].sec = -INF;return k;}
int root[MAXN];
#define mid ((l + r) >> 1)
void insert(int &rt,int p,ll fir,ll sec,ll v,int l,int r)
{
if(rt == 0)rt = newnode();
if(l == r){t[rt].fir = max(t[rt].fir,fir);t[rt].sec = max(t[rt].sec,sec);return;}
if(p <= mid)insert(t[rt].lc,p,fir,sec,v,l,mid);
else insert(t[rt].rc,p,fir,sec,v,mid + 1,r);
ans = max(ans,t[t[rt].lc].fir + t[t[rt].rc].sec - v);
t[rt].fir = max(t[t[rt].lc].fir,t[t[rt].rc].fir);
t[rt].sec = max(t[t[rt].lc].sec,t[t[rt].rc].sec);
return;
}
int merge(int x,int y,ll v)
{
if(x == 0 || y == 0)return x + y;
ans = max(ans,t[t[x].lc].fir + t[t[y].rc].sec - v);
ans = max(ans,t[t[x].rc].sec + t[t[y].lc].fir - v);
t[x].fir = max(t[x].fir,t[y].fir);t[x].sec = max(t[x].sec,t[y].sec);
t[x].lc = merge(t[x].lc,t[y].lc,v);t[x].rc = merge(t[x].rc,t[y].rc,v);
return x;
}
void reset(int rt,int p,int l,int r)
{
if(rt == 0)return;
if(l == r){t[rt].fir = t[rt].sec = -INF;return;}
if(p <= mid)reset(t[rt].lc,p,l,mid);
else reset(t[rt].rc,p,mid + 1,r);
t[rt].fir = max(t[t[rt].lc].fir,t[t[rt].rc].fir);
t[rt].sec = max(t[t[rt].lc].sec,t[t[rt].rc].sec);
return;
}
void dfs(int k,int fa)
{
reset(root[k],dep[k],1,n);
for(int i = lin[k];i != 0;i = e[i].nxt)if(e[i].to != fa){dfs(e[i].to,k);reset(root[e[i].to],dep[k],1,n);}
for(int i = lin[k];i != 0;i = e[i].nxt)if(e[i].to != fa)root[k] = merge(root[k],root[e[i].to],d[k]);
return;
}
void solve()
{
ptr = 0;
for(int i = 1;i <= n;++i)root[i] = 0;
for(int i = 1;i <= m;++i)
{
if(s[i].a != s[i].lca)insert(root[s[i].a],dep[s[i].lca],s[i].fir,s[i].sec,d[s[i].a],1,n);
if(s[i].b != s[i].lca)insert(root[s[i].b],dep[s[i].lca],s[i].fir,s[i].sec,d[s[i].b],1,n);
}
dfs(1,0);
return;
}
}
namespace SOLVE2
{
struct query{int id,nxt;}q[MAXM];
int head[MAXN] = {0},qnum = 0;
void addq(int p,int a){q[++qnum] = (query){a,head[p]};head[p] = qnum;return;}
int v[MAXN];
struct dat
{
ll w;int i;
dat(ll w_ = 0,int i_ = 0){w = w_;i = i_;}
void init(){w = i = 0;}
};
pair<dat,dat> f[MAXN];
ll calc(dat a,dat b){if(a.i == 0 || b.i == 0)return -INF;else return a.w + b.w + dis(a.i,b.i);}
#define fi first
#define se second
void insert(pair<dat,dat> &f,dat k)
{
if(f.fi.i == 0){f.fi = k;return;}if(f.se.i == 0){f.se = k;return;}
pair<dat,dat> res = f;
if(calc(f.fi,k) > calc(res.fi,res.se))res = make_pair(f.fi,k);
if(calc(f.se,k) > calc(res.fi,res.se))res = make_pair(f.se,k);
f = res;
return;
}
pair<dat,dat> merge(pair<dat,dat> a,pair<dat,dat> b)
{
if(a.fi.i == 0)return b;if(b.fi.i == 0)return a;
ll v = -INF,l;pair<dat,dat> res;
if((l = calc(a.fi,a.se)) > v)res = make_pair(a.fi,a.se),v = l;
if((l = calc(b.fi,b.se)) > v)res = make_pair(b.fi,b.se),v = l;
if((l = calc(a.fi,b.fi)) > v)res = make_pair(a.fi,b.fi),v = l;
if((l = calc(a.fi,b.se)) > v)res = make_pair(a.fi,b.se),v = l;
if((l = calc(a.se,b.fi)) > v)res = make_pair(a.se,b.fi),v = l;
if((l = calc(a.se,b.se)) > v)res = make_pair(a.se,b.se),v = l;
return res;
}
ll calc(pair<dat,dat> a,pair<dat,dat> b)
{
ll v = -INF,l;
if((l = calc(a.fi,b.fi)) > v)v = l;
if((l = calc(a.fi,b.se)) > v)v = l;
if((l = calc(a.se,b.fi)) > v)v = l;
if((l = calc(a.se,b.se)) > v)v = l;
return v;
}
int p[MAXM * 2];
bool cmp_rnk(int a,int b){return rnk[a] < rnk[b];}
int stack[MAXN],top = 0;
int cursolve;
struct edge{int to,nxt;}e[MAXN];
int lin[MAXN] = {0},edgenum;
void add(int a,int b){e[++edgenum] = (edge){b,lin[a]};lin[a] = edgenum;return;}
int t[MAXN];
void dp(int k)
{
t[++t[0]] = k;ll tp = ans;
if(f[k].se.i != 0 && k != cursolve)ans = max(ans,calc(f[k].fi,f[k].se) / 2 - d[k]);
for(int i = lin[k];i != 0;i = e[i].nxt)
{
dp(e[i].to);
if(k != cursolve)ans = max(ans,calc(f[k],f[e[i].to]) / 2 - d[k]);
f[k] = merge(f[k],f[e[i].to]);
}
return;
}
void solve(int k)
{
t[0] = 0;
cursolve = k;p[0] = 0;
int cnt = 0;
for(int i = head[k];i != 0;i = q[i].nxt)++cnt;
if(cnt == 0 || cnt == 1)return;
for(int i = head[k];i != 0;i = q[i].nxt)
{
int v = q[i].id;
insert(f[s[v].a],(dat){d[s[v].a] + s[v].l - 2 * s[v].v,s[v].b});
insert(f[s[v].b],(dat){d[s[v].b] + s[v].l - 2 * s[v].v,s[v].a});
p[++p[0]] = s[v].a;p[++p[0]] = s[v].b;
}
sort(p + 1,p + 1 + p[0]);p[0] = unique(p + 1,p + 1 + p[0]) - p - 1;
sort(p + 1,p + 1 + p[0],cmp_rnk);
top = 0;stack[++top] = p[1];
for(int i = 2;i <= p[0];++i)
{
int lca = LCA(p[i],stack[top]);
if(lca == stack[top]){stack[++top] = p[i];continue;}
while(top >= 2)
{
if(rnk[lca] < rnk[stack[top - 1]])add(stack[top - 1],stack[top]),--top;
else{add(lca,stack[top]);--top;break;}
}
if(top == 1 && rnk[lca] < rnk[stack[top]]){add(lca,stack[top]);stack[top] = lca;}
if(stack[top] != lca)stack[++top] = lca;stack[++top] = p[i];
}
while(top >= 2){add(stack[top - 1],stack[top]);--top;}
dp(stack[1]);
for(int i = 1;i <= t[0];++i)f[t[i]].fi.init(),f[t[i]].se.init();
edgenum = 0;for(int i = 1;i <= t[0];++i)lin[t[i]] = 0;
return;
}
void solve()
{
for(int i = 1;i <= n;++i)head[i] = 0;
qnum = 0;
for(int i = 1;i <= m;++i)if(s[i].a != s[i].b)addq(s[i].lca,i);
for(int i = 1;i <= n;++i)solve(i);
return;
}
}
void work()
{
scanf("%d",&n);
edgenum = 0;
for(int i = 1;i <= n;++i)son[i] = lin[i] = 0;
int a,b,c;
for(int i = 1;i < n;++i){a = rd();b = rd();c = rd();add(a,b,c);}
dfs1(1,1);dfs2(1,1);
scanf("%d",&m);
for(int i = 1;i <= m;++i)
{
s[i].a = rd();s[i].b = rd();s[i].v = rd();
s[i].lca = LCA(s[i].a,s[i].b);s[i].l = d[s[i].a] + d[s[i].b] - 2 * d[s[i].lca];
s[i].fir = s[i].l - s[i].v;s[i].sec = s[i].l - s[i].v + d[s[i].lca];
}
ans = -INF;
SOLVE1::solve();SOLVE2::solve();
if(ans >= -0x3f3f3f3f3f3f3f3f / 2)printf("%lld\n",ans);
else puts("F");
return;
}
int main()
{
int testcases;
scanf("%d",&testcases);
while(testcases--)work();
return 0;
}
Copyright © 2020 wjh15101051
ღゝ◡╹)ノ♡