卧薪尝胆,厚积薄发。
CTSC2010 珠宝商
Date: Tue Mar 26 15:39:10 CST 2019 In Category: NoCategory

Description:

有一棵 $n$ 个节点的树和一个长度为 $m$ 的字符串 $S$,树上每个节点有一个字符。对所有有序点对 $(x,y)$(允许 $x=y$),求从 $x$ 到 $y$ 的路径上的字符依次组成的字符串在 $S$ 中出现次数之和。
$1\leqslant n,m\leqslant50000$

Solution:

首先考虑两个暴力:
1、从每个点开始 $dfs$,同时在 $S$ 的 $SAM$ 上沿转移边匹配,每到达一个状态就把该状态的出现次数($endpos$ 集合大小)累加进答案,匹配失败时停止。对大小为 $s$ 的连通块,复杂度为 $O(s^2)$。
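2、考虑所有经过点 $k$ 的路径,把路径拆成 $x\to k$ 和 $k\to y$ 两段(两段都包含 $k$)。对 $S$ 的正串和反串分别建立后缀树,从 $k$ 出发 $dfs$,每走一步就在后缀树上多匹配一个字符,并在匹配到的位置打一个标记;最后 $dfs$ 整棵后缀树,把标记下推到每个前缀/后缀对应的节点,再对 $S$ 的每个位置,把以该位置结尾的 $x\to k$ 标记数与从该位置开始的 $k\to y$ 标记数相乘,累加到答案里。取 $k$ 为点分治的分治中心,并减去 $x,y$ 落在 $k$ 的同一棵子树内的错误贡献(容斥),就得到点分治做法;每次统计的代价与 $m$ 同阶。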
然后进行根号分治:点分治处理每个连通块之前先统计它的大小,若不超过阈值 $B$(代码中取 $B=7\sqrt n$)就对整个连通块用第一种暴力,否则用第二种方法。

Code:


// luogu-judger-enable-o2
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <cmath>
#include <algorithm>
#include <iostream>
using namespace std;
// Reads one (optionally negative) decimal integer from stdin, skipping any
// leading non-digit characters; a '-' among them makes the result negative.
// Returns 0 if EOF is reached before any digit (instead of looping forever).
inline int rd()
{
    int res = 0,f = 1;
    // int, not char: getchar() may return EOF, and passing a negative char
    // value to isdigit() is undefined behaviour.
    int c = getchar();
    while(c != EOF && !isdigit(c))
    {
        if(c == '-')f = -1;
        c = getchar();
    }
    while(c != EOF && isdigit(c))
    {
        res = res * 10 + (c - '0');
        c = getchar();
    }
    return res * f;
}
#define I inline
// `register` is ill-formed since C++17 (clang rejects it by default). Keep the
// macro so existing `R int x` declarations still compile, but expand it to nothing.
#define R
// Returns the next lowercase letter from stdin, skipping every other character.
// Returns '\0' if EOF is reached first instead of spinning forever.
inline char getc()
{
    // int, not char: keeps EOF distinguishable and gives islower() a valid argument.
    int c = getchar();
    while(c != EOF && !islower(c))c = getchar();
    return c == EOF ? '\0' : static_cast<char>(c);
}
int n,m;  // number of tree vertices, length of the pattern string
#define MAXN 50010
// Undirected tree kept as array-based adjacency lists: lin[u] is the index of
// the newest arc leaving u, e[i].nxt the index of the next-older arc of the
// same vertex, and 0 terminates a list. Arc indices start at 1.
struct edge
{
    int to,nxt;
}e[MAXN << 1];
int edgenum = 0;
int lin[MAXN] = {0};
// Inserts the undirected edge a-b as the two arcs a->b and then b->a.
inline void add(int a,int b)
{
    const int tail[2] = {a,b};
    const int head[2] = {b,a};
    for(int j = 0;j < 2;++j)
    {
        const int id = ++edgenum;
        e[id].to = head[j];
        e[id].nxt = lin[tail[j]];
        lin[tail[j]] = id;
    }
}
struct SAM
{
struct node
{
int tr[26],par,maxl;
}s[MAXN << 1];
int root,last,ptr;
I int newnode(int l){int k = ++ptr;s[k].maxl = l;return k;}
SAM(){ptr = last = root = 1;}
int siz[MAXN << 1],rig[MAXN << 1];
char str[MAXN];
int loc[MAXN];
I void extend(int k,int pos)
{
str[pos] = k;
R int p = last;
R int np = newnode(s[p].maxl + 1);
loc[pos] = np;
++siz[np];rig[np] = pos;
for(;p && s[p].tr[k] == 0;p = s[p].par)s[p].tr[k] = np;
if(p == 0)s[np].par = root;
else
{
R int q = s[p].tr[k];
if(s[p].maxl + 1 == s[q].maxl)s[np].par = q;
else
{
R int nq = newnode(s[p].maxl + 1);
memcpy(s[nq].tr,s[q].tr,sizeof(s[q].tr));
s[nq].par = s[q].par;
s[q].par = s[np].par = nq;
for(;p && s[p].tr[k] == q;p = s[p].par)s[p].tr[k] = nq;
}
}
last = np;
return;
}
node t[MAXN << 1];
int c[MAXN << 1],p[MAXN << 1];
int tag[MAXN << 1];
I void buildsuffixtree()
{
for(R int i = 1;i <= ptr;++i)++c[s[i].maxl];
for(R int i = 1;i <= m;++i)c[i] += c[i - 1];
for(R int i = ptr;i >= 1;--i)p[c[s[i].maxl]--] = i;
for(R int i = ptr;i >= 1;--i)
{
R int k = p[i],fa = s[k].par;
siz[fa] += siz[k];
rig[fa] = max(rig[fa],rig[k]);
t[fa].tr[str[rig[k] - s[fa].maxl]] = k;
}
for(R int i = 1;i <= ptr;++i)t[i].par = s[i].par,t[i].maxl = s[i].maxl;
return;
}
I int go(int cur,char c,int pos)
{
if(pos > s[cur].maxl)cur = t[cur].tr[c - 'a'];
else if(c - 'a' != str[rig[cur] - pos + 1])cur = 0;
return cur;
}
void pushtag(int rt)
{
for(R int i = 0;i < 26;++i)
{
if(t[rt].tr[i] == 0)continue;
tag[t[rt].tr[i]] += tag[rt];
pushtag(t[rt].tr[i]);
}
return;
}
}S,S1,S2;
// Letter written on each tree vertex (1-based).
char c[MAXN];
// The pattern string S (1-based).
char s[MAXN];
// Centroid search state: best centroid so far, size of the component being
// searched, subtree sizes, and the largest piece left if a vertex were removed.
int root,si,siz[MAXN],d[MAXN];
// Vertices already used as centroids or already finished by brute force.
bool vis[MAXN];
#define INF 0x3f3f3f3f
void getroot(int k,int fa = 0)
{
siz[k] = 1;d[k] = 0;
for(R int i = lin[k];i != 0;i = e[i].nxt)
{
if(vis[e[i].to] || e[i].to == fa)continue;
getroot(e[i].to,k);
siz[k] += siz[e[i].to];
d[k] = max(d[k],siz[e[i].to]);
}
d[k] = max(d[k],si - siz[k]);
if(d[k] < d[root])root = k;
return;
}
int getsiz(int k,int fa = 0)
{
siz[k] = 1;
for(R int i = lin[k];i != 0;i = e[i].nxt)
{
if(vis[e[i].to] || e[i].to == fa)continue;
getsiz(e[i].to,k);
siz[k] += siz[e[i].to];
}
return siz[k];
}
long long ans = 0;
int B;
int v[MAXN];
void getpoint(int k,int fa = 0)
{
v[++v[0]] = k;
for(R int i = lin[k];i != 0;i = e[i].nxt)
{
if(vis[e[i].to] || e[i].to == fa)continue;
getpoint(e[i].to,k);
}
return;
}
void calc(int k,int cur,int fa)
{
cur = S.s[cur].tr[c[k] - 'a'];
if(cur == 0)return;
ans += S.siz[cur];
for(R int i = lin[k];i != 0;i = e[i].nxt)
{
if(!vis[e[i].to] && e[i].to != fa)calc(e[i].to,cur,k);
}
return;
}
void dfs(int k,int fa,int cur1,int cur2,int dep)
{
if(cur1)cur1 = S1.go(cur1,c[k],dep + 1);
if(cur2)cur2 = S2.go(cur2,c[k],dep + 1);
if(cur1 == 0 && cur2 == 0)return;
if(cur1)++S1.tag[cur1];if(cur2)++S2.tag[cur2];
for(R int i = lin[k];i != 0;i = e[i].nxt)
{
if(vis[e[i].to] || e[i].to == fa)continue;
dfs(e[i].to,k,cur1,cur2,dep + 1);
}
return;
}
// Suffix-tree counting for all paths starting at k.
// 1. Tags every such path via dfs(), continuing from matches cur1/cur2 of length d.
// 2. Pushes the tags down both trees.
// 3. For every position pos, multiplies S1's tag at loc[pos] by S2's tag at
//    loc[m - pos + 1], sums the products and adds f times the sum to ans.
// 4. Clears all tags afterwards so the next call starts from zero.
I void solve2(int k,int cur1,int cur2,int d,int f)
{
    dfs(k,0,cur1,cur2,d);
    S1.pushtag(1);
    S2.pushtag(1);
    long long pairs = 0;
    for(int pos = 1;pos <= m;++pos)
        pairs += 1ll * S1.tag[S1.loc[pos]] * S2.tag[S2.loc[m - pos + 1]];
    ans += f * pairs;
    for(int i = S1.ptr;i >= 1;--i)S1.tag[i] = 0;
    for(int i = S2.ptr;i >= 1;--i)S2.tag[i] = 0;
}
void divide(int k)
{
R int sizk = getsiz(k);
if(sizk <= B)
{
v[0] = 0;
getpoint(k);
for(R int i = 1;i <= v[0];++i)calc(v[i],1,0);
for(R int i = 1;i <= v[0];++i)vis[v[i]] = true;
return;
}
else
{
vis[k] = true;
solve2(k,1,1,0,1);
for(R int i = lin[k];i != 0;i = e[i].nxt)
{
if(vis[e[i].to])continue;
solve2(e[i].to,S1.go(1,c[k],1),S2.go(1,c[k],1),1,-1);
}
for(R int i = lin[k];i != 0;i = e[i].nxt)
{
if(vis[e[i].to])continue;
root = 0;si = getsiz(e[i].to);
getroot(e[i].to);
divide(root);
}
}
return;
}
// Reads the input in this order:
//   1. n and m;
//   2. the n - 1 tree edges;
//   3. the n vertex letters;
//   4. the m letters of S.
// Then builds the three automata, runs the centroid decomposition and prints the total.
int main()
{
    scanf("%d%d",&n,&m);
    B = sqrt(n) * 7;  // component-size threshold for the brute-force method
    for(int i = 1;i < n;++i)
    {
        // Read into separate statements: the evaluation order of function
        // arguments is unspecified, so add(rd(), rd()) does not fix which number
        // is read first.
        const int a = rd();
        const int b = rd();
        add(a,b);
    }
    for(int i = 1;i <= n;++i)c[i] = getc();
    for(int i = 1;i <= m;++i)s[i] = getc();
    // S and S2 index the text as given; S1 indexes it reversed (s[i] sits at m - i + 1).
    for(int i = 1;i <= m;++i)S.extend(s[i] - 'a',i);
    for(int i = 1;i <= m;++i)S2.extend(s[i] - 'a',i);
    for(int i = m;i >= 1;--i)S1.extend(s[i] - 'a',m - i + 1);
    S.buildsuffixtree();
    S1.buildsuffixtree();
    S2.buildsuffixtree();
    root = 0;si = n;d[0] = INF;
    getroot(1);
    divide(root);
    cout << ans << endl;
    return 0;
}
Copyright © 2020 wjh15101051
ღゝ◡╹)ノ♡