卧薪尝胆,厚积薄发。
Date: Tue Nov 27 08:34:48 CST 2018 In Category: NoCategory

Description:

给定你 $n$ 个字符串,询问每个字符串有多少子串是所有 $n$ 个字符串中至少 $k$ 个字符串的子串。
$1\leqslant n\leqslant 10^5$

Solution:

建出广义后缀自动机,为每个点标一下他属于哪个串,然后在 $Parent$ 树上 $set$ 合并,对 $siz\geqslant k$ 的点的权值标成 $s[i].maxl-s[s[i].par].maxl$ ,然后在 $Parent$ 树上从上往下递推出每个位置的值,最后对于每个串跑一边统计一下就行了。
突然发现好像不用 $set$ 合并,直接枚举每个字符串把他经过的所有位置的祖先都加一就行了?

Code:


#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cmath>
#include<set>
#include<cctype>
#include<cstring>
using namespace std;
int n,m;
#define MAXN 200010
struct node
{
int maxl,par;
set<int> s;
int tr[26];
}s[MAXN << 1];
int root = 1,ptr = 1,last = 1;
int newnode(int l){int k = ++ptr;s[k].maxl = l;return k;}
void extend(int k)
{
int p = last;
int np = newnode(s[p].maxl + 1);
for(;p && s[p].tr[k] == 0;p = s[p].par)s[p].tr[k] = np;
if(p == 0)s[np].par = root;
else
{
int q = s[p].tr[k];
if(s[p].maxl + 1 == s[q].maxl)s[np].par = q;
else
{
int nq = newnode(s[p].maxl + 1);
memcpy(s[nq].tr,s[q].tr,sizeof(s[q].tr));
s[nq].par = s[q].par;
s[q].par = s[np].par = nq;
for(;p && s[p].tr[k] == q;p = s[p].par)s[p].tr[k] = nq;
}
}
last = np;
return;
}
string str[MAXN];
struct edge
{
int to,nxt;
}e[MAXN];
int edgenum = 0;
int lin[MAXN] = {0};
void add(int a,int b)
{
e[++edgenum] = (edge){b,lin[a]};lin[a] = edgenum;
return;
}
int val[MAXN];
void dfs(int k)
{
for(int i = lin[k];i != 0;i = e[i].nxt)
{
dfs(e[i].to);
for(set<int>::iterator it = s[e[i].to].s.begin();it != s[e[i].to].s.end();++it)
{
s[k].s.insert(*it);
}
}
if(s[k].s.size() >= m)val[k] = s[k].maxl - s[s[k].par].maxl;
return;
}
void dfs2(int k)
{
for(int i = lin[k];i != 0;i = e[i].nxt)
{
val[e[i].to] += val[k];
dfs2(e[i].to);
}
return;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i = 1;i <= n;++i)
{
cin >> str[i];
int l = str[i].length();
last = 1;
for(int x = 0;x < l;++x)extend(str[i][x] - 'a');
}
for(int i = 1;i <= n;++i)
{
int l = str[i].length(),cur = root;
for(int x = 0;x < l;++x)
{
cur = s[cur].tr[str[i][x] - 'a'];
s[cur].s.insert(i);
}
}
for(int i = 1;i <= ptr;++i)add(s[i].par,i);
dfs(1);
dfs2(1);
for(int i = 1;i <= n;++i)
{
int l = str[i].length(),cur = root;
int ans = 0;
for(int x = 0;x < l;++x)
{
cur = s[cur].tr[str[i][x] - 'a'];
ans += val[cur];
}
printf("%d ",ans);
}
cout << endl;
return 0;
}
Copyright © 2020 wjh15101051
ღゝ◡╹)ノ♡