卧薪尝胆,厚积薄发。
操作
Date: Sun Mar 17 16:10:37 CST 2019 In Category: NoCategory

Description:

有 $n$ 个操作和变量 $x=0$ ,第 $i$ 个操作为:以 $p_i$ 的概率给 $x$ 加上 $a_i$ , $1−p_i$ 的概率给 $x$ 乘上 $b_i$ 。随机生成一个长度为 $n$ 的排列 $C$ ,依次执行第 $C_1,C_2,\dots,C_n$ 个操作。求执行完所有操作后,变量 $x$ 的期望模 $998244353$ 的值。
$1\leqslant n\leqslant 10^5$

Solution:

发现这个问题很不好下手,因为操作是确定的,但是操作的执行顺序是不定的,不过题解给我们提供了一个很神奇的做法:
首先假如当前的值为 $x$ ,那么一次操作之后的结果为 $p_i\times (x+a_i)+(1-p_i)\times b_ix=(p_i+b_i-p_ib_i)x+p_ia_i$ ,我们可以把它看成是把 $x$ 带入了一个一次函数 $x'=xD_i+E_i$ ,那么我们可以转而计算 $E_i$ 被加的系数,如果操作 $j$ 在操作 $i$ 之后,那么 $E_i$ 就会被乘上 $D_j$ ,否则不变,那么我们计算 $\prod_{j\ne i}1+D_jx$ ,这个多项式的第 $k$ 次项系数乘上 $k!(n-k-1)!$ 就是在 $i$ 后面有 $k$ 个操作的所有方案被乘上的所有系数和,那么我们乘上 $E_i$ 就是这个操作的贡献,于是我们要求的就是所有这个多项式之和,可以用一个神奇的分治 $FFT$ ,具体来说就是每个区间维护两个多项式 $ans$ 和 $tmp$ , $ans$ 就是答案, $tmp=\prod (1+D_i)x$ ,那么显然 $tmp$ 就是两边卷起来, $ans=ans_L\times tmp_R+tmp_L\times ans_R$ 。 $$ \begin{align} ans_{[l,r]}&=\sum_{i=l}^rE_i\prod_{j=l}^r[j\ne i](1+D_jx)\\ tmp_{[l,r]}&=\prod_{i=l}^r(1+D_ix)\\ \end{align} $$

Code:


#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cmath>
#include<cctype>
#include<cstring>
using namespace std;
int n;
#define MAXN 400010
int p[MAXN],a[MAXN],b[MAXN];
int d[MAXN],e[MAXN];
#define P 998244353
int ans[20][MAXN],tmp[20][MAXN];
int fac[MAXN];
int rev[MAXN];
int ww[MAXN << 1],*g = ww + MAXN;
int power(int a,int b)
{
int res = 1;
for(;b > 0;b = b >> 1,a = 1ll * a * a % P)if(b & 1)res = 1ll * res * a % P;
return res;
}
int init(int n)
{
int l = 0,len = 1;
for(;len <= n;len = len << 1)++l;
for(int i = 0;i < len;++i)rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
g[0] = g[-len] = 1;g[1] = g[1 - len] = power(3,(P - 1) / len);
for(int i = 2;i < len;++i)g[i] = g[i - len] = 1ll * g[i - 1] * g[1] % P;
return len;
}
void NTT(int f[],int l,int type)
{
for(int i = 0;i < l;++i)
{
if(i < rev[i])swap(f[i],f[rev[i]]);
}
for(int i = 2;i <= l;i = i << 1)
{
int wn = g[type * l / i];
for(int j = 0;j < l;j += i)
{
int w = 1;
for(int k = j;k < j + i / 2;++k)
{
int u = f[k],t = 1ll * w * f[k + i / 2] % P;
f[k] = (u + t) % P;
f[k + i / 2] = (u - t + P) % P;
w = 1ll * w * wn % P;
}
}
}
if(type == -1)
{
int ni = power(l,P - 2);
for(int i = 0;i < l;++i)f[i] = 1ll * f[i] * ni % P;
}
return;
}
void solve(int s,int l,int r)
{
if(l == r)
{
tmp[s][0] = 1;tmp[s][1] = d[l];
ans[s][0] = e[l];
return;
}
int mid = ((l + r) >> 1);
solve(s + 1,l,mid);
for(int i = 0;i <= mid - l + 1;++i)tmp[s][i] = tmp[s + 1][i],ans[s][i] = ans[s + 1][i];
for(int i = 0;i <= mid - l + 1;++i)tmp[s + 1][i] = ans[s + 1][i] = 0;
solve(s + 1,mid + 1,r);
int len = init(r - l + 1);
NTT(tmp[s],len,1);NTT(tmp[s + 1],len,1);NTT(ans[s],len,1);NTT(ans[s + 1],len,1);
for(int i = 0;i < len;++i)ans[s][i] = (1ll * ans[s][i] * tmp[s + 1][i] % P + 1ll * ans[s + 1][i] * tmp[s][i] % P) % P;
for(int i = 0;i < len;++i)tmp[s][i] = 1ll * tmp[s][i] * tmp[s + 1][i] % P;
NTT(tmp[s],len,-1);NTT(ans[s],len,-1);
for(int i = 0;i < len;++i)tmp[s + 1][i] = ans[s + 1][i] = 0;
return;
}
int main()
{
scanf("%d",&n);
for(int i = 1;i <= n;++i)
{
scanf("%d%d%d",&p[i],&a[i],&b[i]);
d[i] = ((p[i] + b[i]) % P - 1ll * p[i] * b[i] % P + P) % P;
e[i] = 1ll * p[i] * a[i] % P;
}
fac[0] = 1;
for(int i = 1;i <= n;++i)fac[i] = 1ll * fac[i - 1] * i % P;
solve(0,1,n);
int res = 0;
for(int i = 0;i <= n;++i)res = (res + 1ll * ans[0][i] * fac[i] % P * fac[n - 1 - i] % P) % P;
for(int i = 1;i <= n;++i)res = 1ll * res * power(i,P - 2) % P;
cout << res << endl;
return 0;
}
Copyright © 2020 wjh15101051
ღゝ◡╹)ノ♡