AC自动机总结
Posted mikufun-hzoi-cpp
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了AC自动机总结相关的知识,希望对你有一定的参考价值。
用了将近一周的时间,总算把AC自动机后面四道dp做完了
先说一下总体感受:全是套路
AC自动机的题dp一般就是第一维表长度,第二维表节点,然后从父亲转移到儿子(当然偶尔有例外)
而且做完之后发现AC自动机建trie树完全没卵用,几乎都得用到trie图(trie树会各种re)
来说一下做这个专题的经历
前三道题就是模板,没啥好说的
然后第四道题卡了一会
D:病毒
题目大意:给定几个串,问是否存在一个无限长的串不包含其中任何一个串
题解:在AC自动机上跑dfs,在不经过危险节点的情况下,看能不能跑出来环
这里包含了一个重要性质:在AC自动机上跑dfs(不加限制),能遍历所有可能的串,这将成为接下来几乎所有题的突破口
#include <queue>
#include <cstdlib>
#include <cmath>
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>
#define SI 2
using namespace std;
int cnt;
struct node
node *fi;
node *tr[SI];
int c,is,are;
node()
fi=NULL;
is=1;are=0;
c=++cnt;
memset(tr,NULL,sizeof(tr));
;
void insert(char *s,node *root)
node *p=root;
int in,len=strlen(s+1);
for(int i=1;i<=len;i++)
in=s[i]-‘0‘;
if(p->tr[in]==NULL) p->tr[in]=new node();
p=p->tr[in];
p->are=1;
void getFail(node *root)
int be=0,en=0;
queue<node*>q;
for(int i=0;i<SI;i++)
if(root->tr[i]!=NULL)
root->tr[i]->fi=root;
q.push(root->tr[i]);
else root->tr[i]=root;
while(!q.empty())
node *now =q.front();q.pop();
for(int i=0;i<SI;i++)
if(now->tr[i]!=NULL)
now->tr[i]->fi=now->fi->tr[i];
if(now->fi->tr[i]->are) now->tr[i]->are=1;
q.push(now->tr[i]);
else
now->tr[i]=now->fi->tr[i];
int v[40010],vs[40010];
void dfs(node *root)
if(v[root->c])
printf("TAK");
exit(0);
if(root->are||vs[root->c]) return;
vs[root->c]=1;
v[root->c]=1;
for(int i=0;i<SI;i++)
if(root->tr[i]->c!=1)dfs(root->tr[i]);
v[root->c]=0;
int main()
int n;char s[30010];
int T;
node *root=new node();
scanf("%d",&n);
for(int i=0;i<n;i++)
scanf("%s",s+1);
insert(s,root);
getFail(root);
dfs(root);
printf("NIE");
return 0;
E:最短母串
题目大意:给定几个串,找到一个包含所有串的最短的串
题解:同样在AC自动机上跑dfs并用二进制数表示经过了哪几个串,最早找到的一个经历了所有串的串就是答案
弄起来很简单,就是输出串有点恶心
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#include<vector>
#define re register
#define null NULL
using namespace std;
int tot;
struct node
node *ch[26],*fi,*fa;
int val,c;char has;
node()
memset(ch,null,sizeof ch);
fi=null,val=0;c=++tot;
*rt=new node(),*lnode[5100][5000];
char s[15][100];
void insert(int x)
node *now=rt;int len=strlen(s[x]+1);
for(int i=1;i<=len;i++)
int in=s[x][i]-‘A‘;
if(now->ch[in]==null) now->ch[in]=new node();
now=now->ch[in];
now->has=in+‘A‘;
now->val|=(1<<(x-1));
inline void getfail()
queue<node*>q;
for(int i=0;i<26;i++)
if(rt->ch[i]!=null)
rt->ch[i]->fi=rt;
q.push(rt->ch[i]);
else rt->ch[i]=rt;
while(!q.empty())
node *now=q.front();q.pop();
for(int i=0;i<26;i++)
if(now->ch[i]!=null)
now->ch[i]->fi=now->fi->ch[i];
now->ch[i]->val|=now->ch[i]->fi->val;
q.push(now->ch[i]);
else now->ch[i]=now->fi->ch[i];
int f[5100][5000],lhas[5100][5000],top,maxn;char b[6000];
void bfs()
memset(f,0x3f,sizeof f);
queue<node*>q;
queue<int>p;
q.push(rt);p.push(0);
f[0][0]=0;
while(!q.empty())
node *now=q.front();int ans=p.front();
q.pop();p.pop();
if(ans==maxn)
node *is=now;
while(is!=rt)
b[++top]=is->has;
node *x=is;int a=ans;
is=lnode[x->c][a];ans=lhas[x->c][a];
for(int i=top;i>0;i--) putchar(b[i]);
return;
for(int i=0;i<26;i++)
int are=ans|now->ch[i]->val;
if(f[now->ch[i]->c][are]==0x3f3f3f3f)
//cout<<1<<endl;
q.push(now->ch[i]);p.push(are);
f[now->ch[i]->c][are]=f[now->c][ans]+1;
lnode[now->ch[i]->c][are]=now;
lhas[now->ch[i]->c][are]=ans;
int main()
// freopen("a.in","r",stdin);
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%s",s[i]+1),insert(i);
maxn=(1<<n)-1;
getfail();
bfs();
return 0;
F:文本编译器
题目大意:给定几个串,问包含至少一个串且长度为len的有多少种
题解:经典的AC自动机dp,dp[i][j]表示长度为i在第j个节点,不经过给定串的种数,答案26^len-dp[len][j]
#include<cstdio>
#include<cstring>
#include<iostream>
#include<queue>
#define SI 26
using namespace std;
const int mod=10007;
int cnt,dp[102][50000];
struct node
node *fi;
node *tr[SI];
int c,are;
node()
fi=NULL;
are=0;
c=cnt;
memset(tr,NULL,sizeof(tr));
*all[3000000];
void insert(char *s,node *root)
node *p=root;
int in,len=strlen(s+1);
for(int i=1;i<=len;i++)
in=s[i]-‘A‘;
if(p->tr[in]==NULL) all[++cnt]=p->tr[in]=new node();
p=p->tr[in];
p->are=1;
void getFail(node *root)
int be=0,en=0;
queue<node*>q;
for(int i=0;i<SI;i++)
if(root->tr[i]!=NULL)
root->tr[i]->fi=root;
q.push(root->tr[i]);
else root->tr[i]=root;
while(!q.empty())
node *now =q.front();q.pop();
for(int i=0;i<SI;i++)
if(now->tr[i]!=NULL)
now->tr[i]->fi=now->fi->tr[i];
if(now->fi->tr[i]->are) now->tr[i]->are=1;
q.push(now->tr[i]);
else
now->tr[i]=now->fi->tr[i];
inline int pow(int a,int b)
int ans=1;
for(;b;b>>=1)
if(b&1) (ans*=a)%=mod;
(a*=a)%=mod;
return ans;
int main()
int n,m;char s[100];
int T;
node *root=new node();
all[0]=root;
scanf("%d%d",&n,&m);
for(int i=0;i<n;i++)
scanf("%s",s+1);
insert(s,root);
getFail(root);
dp[0][0]=1;
for(int i=1;i<=m;i++)
for(int j=0;j<=cnt;j++)
for(int k=0;k<SI;k++)
if(!all[j]->tr[k]->are)
(dp[i][all[j]->tr[k]->c]+=dp[i-1][j])%=mod;
int ans=0;
for(int i=0;i<=cnt;i++)
(ans+=dp[m][i])%=mod;
printf("%d",(pow(26,m)-ans+mod)%mod);
return 0;
G:背单词
题目大意:给定一张包含N个单词的表,每个单词有个价值W。要求从中选出一个子序列使得其中的每个单词是后一个单词的子串,最大化子序列中W的和
题解: 首先仔细读题,要求的是子序列,所以必须按照给定顺序
我们很容易推出dp式子:dp[i]表示从后往前最后一个串选i时的最大和
显然dp[i]=max(dp[j])+sor[i](i是j的子串)
但是这样复杂度O(n^2),会TLE
我们先考虑卡常
用每一个串在AC自动机上匹配的时候,对于同一个节点的fail指针,当且仅当它在之前经过的节点后边时才入队
这样dp量大大减少,完全可以AC
然后是正解
对于每一个串,另一个串是它的子串当且仅当他上面的一个节点通过fail指针跳转可以跳到另一个串的结尾
这里我们引入一个重要性质:AC自动机上fail指针的反向构成了一棵树
所以另一个串是当前串的母串当且仅当另一个串上的点在当前串的子树上
则我们用fail树建dfs序,就可以用线段树维护区间最大值
每次修改的时候对这个串上的点暴力修改就可以,因为总长度不超过3×10^5
复杂度O(nlogn)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<vector>
#define null NULL
//const int L=1<<20|1;
//char buffer[L],*S,*T;
//#define getchar() ((S==T&&(T=(S=buffer)+fread(buffer,1,L,stdin),S==T))?EOF:*S++)
using namespace std;
string s[20010];
inline int read()
register int a(0);register char ch=getchar(),x=1;
while(ch<‘0‘||ch>‘9‘)
if(ch==‘-‘) x=-1;
ch=getchar();
while(ch>=‘0‘&&ch<=‘9‘)
a=(a<<3)+(a<<1)+ch-‘0‘;
ch=getchar();
return a*x;
inline void reads(int x)
char ch=getchar();s[x]="";
while(!((ch>=‘a‘)&&(ch<=‘z‘))) ch=getchar();
while((ch>=‘a‘)&&(ch<=‘z‘)) s[x]+=ch,ch=getchar();
struct node
node *fi,*ch[26];
int are,c;
vector<short>val;
node()
fi=null;
memset(ch,null,sizeof(ch));
are=c=0;
*root;
inline void insert(int x)
node *now=root;int len=s[x].length();
for(int i=0;i<len;i++)
int in=s[x][i]-‘a‘;
if(now->ch[in]==null) now->ch[in]=new node();
now=now->ch[in];
now->are=1;
now->val.push_back(x);
inline void getfail()
queue<node*>q;
for(int i=0;i<26;i++)
if(root->ch[i]!=null)
root->ch[i]->fi=root;
q.push(root->ch[i]);
else root->ch[i]=root;
while(!q.empty())
node *now=q.front();q.pop();
for(int i=0;i<26;i++)
if(now->ch[i]!=null)
now->ch[i]->fi=now->fi->ch[i];
q.push(now->ch[i]);
else now->ch[i]=now->fi->ch[i];
vector<short>p[20010];
inline void query(int x)
node *now=root;int len=s[x].length(),maxx=0;
for(int i=0;i<len;i++)
int out=s[x][i]-‘a‘;
if(now==null) now=root;
now=now->ch[out];
for(node *j=now;j!=null&&j->c!=x;j=j->fi)
maxx=0;
for(int k=0;k<j->val.size();k++)
if(j->val[k]<x&&j->val[k]>maxx) p[x].push_back(j->val[k]),maxx=j->val[k];
j->c=x;
int ans,dp[20010];
int sor[20010];
int main()
int T,n,ans;
T=read();
while(T--)
root=new node();
n=read();
for(int i=1;i<=n;i++)
reads(i);
scanf("%d",&sor[i]);
if(sor[i]>0) p[i].clear();
else i--,n--;
dp[i]=0;
for(int i=n;i>=1;i--) insert(i);
getfail();
for(int i=n;i>=1;i--) query(i);
ans=0;
for(int i=n;i>=1;i--)
dp[i]+=sor[i];ans=max(ans,dp[i]);
for(int j=0;j<p[i].size();j++)
dp[p[i][j]]=max(dp[p[i][j]],dp[i]);
printf("%d\n",ans);
return 0;
H:密码
题目大意:给定几个串,问包括所有串且长度为len的有多少种,当种数小于42时输出方案
题解:和最短母串很像,这次我们用dp
dp[i][j][s]表示长为i,在j节点,经过串方案为s的种数
dp[i][j][s]=dp[i-1][k][s^j->val]
但是方案不好输出因为如果方案太多肯定要MLE
我们可以在ans<=42时重新dp并求方案,用vector存就好了
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#include<vector>
#define re register
#define ll long long
#define null NULL
using namespace std;
int tot,maxn;
struct node
node *ch[26],*fi,*fa;
int val,c;char has;
node()
memset(ch,null,sizeof ch);
fi=null,val=0;c=tot;
*rt,*h[102];
struct tree
int a;char b;
ks;
vector<tree>v[26][101][1025];
char s[15][15];ll dp[26][101][1024];
string are[50];
void insert(int x)
node *now=rt;int len=strlen(s[x]+1);
for(int i=1;i<=len;i++)
int in=s[x][i]-‘a‘;
if(now->ch[in]==null) h[++tot]=now->ch[in]=new node();
now=now->ch[in];
now->has=in+‘a‘;
now->val|=(1<<(x-1));
inline void getfail()
queue<node*>q;
for(int i=0;i<26;i++)
if(rt->ch[i]!=null)
rt->ch[i]->fi=rt;
q.push(rt->ch[i]);
else rt->ch[i]=rt;
while(!q.empty())
node *now=q.front();q.pop();
for(int i=0;i<26;i++)
if(now->ch[i]!=null)
now->ch[i]->fi=now->fi->ch[i];
now->ch[i]->val|=now->ch[i]->fi->val;
q.push(now->ch[i]);
else now->ch[i]=now->fi->ch[i];
int sum;
void dfs(int x,int len,int now)
for(int i=0;i<v[len][x][now].size();i++)
int u=sum+1;
if(len!=1) dfs(v[len][x][now][i].a,len-1,now^h[x]->val);
else are[++sum]="";
for(int j=u;j<=sum;j++)
are[j]+=v[len][x][now][i].b;
int main()
//freopen("a.in","r",stdin);
int n,l,x;h[0]=rt=new node();
scanf("%d%d",&l,&n);
for(int i=1;i<=n;i++) scanf("%s",s[i]+1),insert(i);
maxn=(1<<n)-1;
getfail();
dp[0][0][0]=1;
for(int i=1;i<=l;i++)
for(int s=0;s<=maxn;s++)
for(int j=0;j<=tot;j++)
if(dp[i-1][j][s])
for(int k=0;k<26;k++)
int now=s|h[j]->ch[k]->val;
dp[i][h[j]->ch[k]->c][now]+=dp[i-1][j][s];
ll ans=0;
for(int i=0;i<=tot;i++) ans+=dp[l][i][maxn];
printf("%lld\n",ans);
if(ans<=42)
memset(dp,0,sizeof dp);
dp[0][0][0]=1;
for(int i=1;i<=l;i++)
for(int s=0;s<=maxn;s++)
for(int j=0;j<=tot;j++)
if(dp[i-1][j][s])
for(int k=0;k<26;k++)
int now=s|h[j]->ch[k]->val;
dp[i][h[j]->ch[k]->c][now]+=dp[i-1][j][s];
ks.a=j;ks.b=k+‘a‘;
v[i][h[j]->ch[k]->c][now].push_back(ks);
for(int i=0;i<=tot;i++)
if(dp[l][i][maxn])
dfs(i,l,maxn);
sort(are+1,are+sum+1);
for(int i=1;i<=sum;i++)
cout<<are[i]<<endl;
return 0;
I:禁忌
题目大意:给定几个串,问长度为len的串里期望有几个给定串,串和串不能重合
题解:看到N<=5,状压?(串不能重合,没法往下转移,pass)
看到len<=109,暴力?)(其实是我10^9哒)
其实dp和之前的差不多,dp[i][j]表示长度为i,在j这个节点时的期望伤害
有dp[i][j]=dp[i-1][k]*a[k][j] (a[k][j]表示从k走到j期望)
这个式子满足用矩阵快速幂转移的要求,复杂度问题解决
但是我们这样没法求出答案(起码你不知道输出哪一个)
接下来是核心:设一个超级节点,每个串的结尾都可以转移到它,同时每个串的结尾都可以转移到根
其它点直接往儿子转移就好了
答案为ans[rt][super]
复杂度O(n^3loglen)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#include<vector>
#define re register
#define ll long long
#define null NULL
using namespace std;
int tot,maxn,all;
struct node
node *ch[26],*fi,*fa;
int val,c;
node()
memset(ch,null,sizeof ch);
fi=null,val=0;c=tot;
*rt,*h[102];
char s[7][17];
void insert(int x)
node *now=rt;int len=strlen(s[x]+1);
for(int i=1;i<=len;i++)
int in=s[x][i]-‘a‘;
if(now->ch[in]==null) h[++tot]=now->ch[in]=new node();
now=now->ch[in];
if(now->val) return;
now->val=1;
struct m
double p[100][100];
m clear()
memset(p,0,sizeof p);
friend m operator *(m a,m b)
m c;c.clear();
for(int i=0;i<=tot;i++)
for(int j=0;j<=tot;j++)
for(int k=0;k<=tot;k++)
c.p[i][j]+=a.p[i][k]*b.p[k][j];
return c;
base,ans;
inline void getfail()
queue<node*>q;
for(int i=0;i<all;i++)
if(rt->ch[i]!=null)
rt->ch[i]->fi=rt;
q.push(rt->ch[i]);
else rt->ch[i]=rt;
while(!q.empty())
node *now=q.front();q.pop();
for(int i=0;i<all;i++)
if(now->ch[i]!=null)
now->ch[i]->fi=now->fi->ch[i];
now->ch[i]->val|=now->ch[i]->fi->val;
q.push(now->ch[i]);
else now->ch[i]=now->fi->ch[i];
inline void pow(int a)
for(int i=0;i<=tot;i++) ans.p[i][i]=1;
for(;a;a>>=1)
if(a&1) ans=ans*base;
base=base*base;
int main()
int n,len;
scanf("%d%d%d",&n,&len,&all);h[++tot]=rt=new node();
for(int i=1;i<=n;i++) scanf("%s",s[i]+1),insert(i);
getfail();base.clear();
double k=1.0/all;
for(int i=1;i<=tot;i++)
for(int j=0;j<all;j++)
if(h[i]->ch[j]->val)
base.p[i][0]+=k;
base.p[i][1]+=k;
else base.p[i][h[i]->ch[j]->c]+=k;
base.p[0][0]=1;
pow(len);
printf("%.8lf",ans.p[1][0]);
return 0;
以上是关于AC自动机总结的主要内容,如果未能解决你的问题,请参考以下文章