卧薪尝胆,厚积薄发。
Mas的仙人掌
Date: Sat Mar 09 15:20:19 CST 2019 In Category: NoCategory

Description:

给一棵树和 $M$ 条非树边,每条边有 $p_i$ 的概率会掉落,问期望有多少条边只在一个简单环中。
$1\leqslant n\leqslant 10^6$

Solution1:

利用期望的线性性可以发现答案是和这条非树边对应的树链相交的边的概率之积乘上 $1-$ 这个树链的概率求和。
考虑一个非常暴力的分类讨论,因为原题要求边相交非常不好做,所以我们考虑转化,首先把链拆开两个直上直下的链,去掉 $LCA$ ,那么两条链相交当且仅当一个链的最高点在另一个链上,这样我们就成功把边相交转化成了点相交,然后分类讨论,假如当前询问的是链 $k$ ,一种情况是另一条链的最高点在 $k$ 的最低点到最高点之间,另一种情况是另一条链经过当前链的最高点,两种情况都可以树上差分做,注意可能两条链在 $LCA$ 处相交,因此再用一个 $map$ 记录每条链从哪两个子树中伸出来,然后每次把这部分贡献扣掉就行了。

Code1:


#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cmath>
#include<map>
#include<vector>
#include<cctype>
#include<cstring>
using namespace std;
inline int rd()
{
register int res = 0,f = 1;register char c = getchar();
while(!isdigit(c)){if(c == '-')f = -1;c = getchar();}
while(isdigit(c))res = (res << 1) + (res << 3) + c - '0',c = getchar();
return res * f;
}
int n,m;
#define MAXN 1000010
struct edge
{
int to,nxt;
}e[MAXN << 1];
int edgenum = 0;
int lin[MAXN] = {0};
void add(int a,int b)
{
e[++edgenum] = (edge){b,lin[a]};lin[a] = edgenum;
e[++edgenum] = (edge){a,lin[b]};lin[b] = edgenum;
return;
}
int f[MAXN][21],dep[MAXN];
void dfs(int k,int depth)
{
dep[k] = depth;
for(int i = lin[k];i != 0;i = e[i].nxt)
{
if(e[i].to == f[k][0])continue;
f[e[i].to][0] = k;
dfs(e[i].to,depth + 1);
}
return;
}
int skip_to_dep(int a,int depth)
{
for(int k = 20;k >= 0;--k)if(dep[f[a][k]] >= depth)a = f[a][k];
return a;
}
int LCA(int a,int b)
{
if(dep[a] < dep[b])swap(a,b);
a = skip_to_dep(a,dep[b]);
if(a == b)return a;
for(int k = 20;k >= 0;--k)
if(f[a][k] != f[b][k])a = f[a][k],b = f[b][k];
return f[a][0];
}
struct edges
{
int u,v,w;
}es[MAXN];
#define MOD 998244353
int power(int a,int b)
{
int res = 1;
for(;b > 0;b = b >> 1,a = 1ll * a * a % MOD)if(b & 1)res = 1ll * res * a % MOD;
return res;
}
int inver(int a){return power(a,MOD - 2);}
struct inte
{
int val,cnt;
inte(int v = 0,int c = 0){val = c;cnt = c;}
int calc(){return (cnt > 0 ? 0 : val);}
}val1[MAXN],val2[MAXN],val3[MAXN],val4[MAXN];
inte operator * (inte a,inte b){a.val = 1ll * a.val * b.val % MOD;a.cnt += b.cnt;return a;}
inte operator * (inte a,int b)
{
if(b == 0){++a.cnt;return a;}
else{a.val = 1ll * a.val * b % MOD;return a;};
}
inte operator / (inte a,int b)
{
if(b == 0){--a.cnt;return a;}
else{a.val = 1ll * a.val * inver(b) % MOD;return a;};
}
inte operator / (inte a,inte b)
{
a.val = 1ll * a.val * inver(b.val) % MOD;a.cnt -= b.cnt;
return a;
}
#define pii pair<int,int>
#define smp(a,b) make_pair(min(a,b),max(a,b))
map<pii,inte> p;
void push(int k)
{
for(int i = lin[k];i != 0;i = e[i].nxt)
{
if(e[i].to == f[k][0])continue;
val3[e[i].to] = val3[e[i].to] * val3[k];
val4[e[i].to] = val4[e[i].to] * val4[k];
push(e[i].to);
val1[k] = val1[k] * val1[e[i].to];
val2[k] = val2[k] * val2[e[i].to];
}
return;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i = 1;i < n;++i)add(rd(),rd());
dfs(1,1);
for(int k = 1;k <= 20;++k)
for(int i = 1;i <= n;++i)
f[i][k] = f[f[i][k - 1]][k - 1];
for(int i = 1;i <= n;++i)val1[i].val = val2[i].val = val3[i].val = val4[i].val = 1;
inte sigma;sigma.val = 1;
for(int i = 1;i <= m;++i)
{
es[i].u = rd();es[i].v = rd();es[i].w = rd();
sigma = sigma * es[i].w;
int lca = LCA(es[i].u,es[i].v);
int s1 = skip_to_dep(es[i].u,dep[lca] + 1);
int s2 = skip_to_dep(es[i].v,dep[lca] + 1);
if(es[i].u != lca)
{
val1[s1] = val1[s1] * es[i].w;
val2[es[i].u] = val2[es[i].u] * es[i].w;
val3[es[i].u] = val3[es[i].u] * es[i].w;
val4[s1] = val4[s1] * es[i].w;
}
if(es[i].v != lca)
{
val1[s2] = val1[s2] * es[i].w;
val2[es[i].v] = val2[es[i].v] * es[i].w;
val3[es[i].v] = val3[es[i].v] * es[i].w;
val4[s2] = val4[s2] * es[i].w;
}
if(es[i].u != lca && es[i].v != lca)
{
if(p.find(smp(s1,s2)) == p.end())
{
if(es[i].w != 0)p[smp(s1,s2)].val = es[i].w;
else p[smp(s1,s2)].val = 1,p[smp(s1,s2)].cnt = 1;
}
else p[smp(s1,s2)] = p[smp(s1,s2)] * es[i].w;
}
}
push(1);
int ans = 0;
for(int i = 1;i <= m;++i)
{
inte sum1,sum2;sum1.val = sum2.val = 1;
int lca = LCA(es[i].u,es[i].v);
int s1 = skip_to_dep(es[i].u,dep[lca] + 1);
int s2 = skip_to_dep(es[i].v,dep[lca] + 1);
if(es[i].u != lca)sum1 = val4[es[i].u] / val4[lca] * val2[s1] / val1[s1];
if(es[i].v != lca)sum2 = val4[es[i].v] / val4[lca] * val2[s2] / val1[s2];
inte sum;sum.val = 1;
if(es[i].u != lca && es[i].v != lca)sum = sum / p[smp(s1,s2)];
sum = sum * sum1 * sum2;
sum = sum / es[i].w;
sum = sum * (1 - es[i].w + MOD);
ans = (ans + sum.calc()) % MOD;
}
cout << ans << endl;
return 0;
}
Copyright © 2020 wjh15101051
ღゝ◡╹)ノ♡