【传送门:BZOJ2434】
简要题意:
给出一个模式串
题解:
处理x字符串在y字符串出现的次数,很容易想到fail树
一开始想着把y字符串的结尾字符在trie树上的位置开始,往上找,找到的点的fail指针如果指向x字符串的结尾字符的话,ans就++
但是这样做的时间复杂度是O(mn),显然会超时
这时,就要想更快的离线的方法...
不会!!!!
果断膜题解(以下来自神犇)
发现可以利用fail树上的一个节点及其子树在dfs序中是连续的一段,那么我们可以用一个树状数组来维护x串末尾节点及其子树上有多少个属于y串的节点,那么我们可以得到一个离线算法:对fail树遍历一遍,得到一个dfs序,再维护一个树状数组,对原trie树进行遍历,每访问一个节点,就修改树状数组,对树状数组中该节点的dfs序起点的位置加上1,每往回走一步,就减去1。如果访问到了一个y字串的末尾节点,枚举询问中每个y串对应的x串,查询树状数组中x串末尾节点从dfs序中的起始位置到结束位置的和,并记录答案
参考代码:
#include<queue> #include<cstdio> #include<cstring> #include<algorithm> #include<cstdlib> #include<cmath> using namespace std; struct node { int c[27],fail,f; node() { fail=f=0; memset(c,-1,sizeof(c)); } }t[110000]; int tot,n; char st[110000]; int s[110000]; int end[110000]; void bt(int root,int z) { int x=root,len=strlen(st+1); for(int i=1;i<=len;i++) { int y=st[i]-‘a‘+1; if(t[x].c[y]==-1) { t[x].c[y]=++tot; } x=t[x].c[y];s[x]++; } end[z]=x; } struct edge { int x,y,next; }a[110000];int len,last[110000]; void ins(int x,int y) { len++; a[len].x=x;a[len].y=y; a[len].next=last[x];last[x]=len; } queue<int> q; void bfs() { int x; q.push(0); while(q.empty()==0) { x=q.front(); for(int i=1;i<=26;i++) { int son=t[x].c[i]; if(son==-1)continue; if(x==0) t[son].fail=0; else { int j=t[x].fail; while(j!=0&&t[j].c[i]==-1) j=t[j].fail; t[son].fail=max(t[j].c[i],0); } ins(t[son].fail,son); q.push(son); } q.pop(); } } int l[110000],r[110000],z; void dfs(int x) { l[x]=++z; for(int k=last[x];k;k=a[k].next) { int y=a[k].y; dfs(y); } r[x]=z; } int d[110000]; int lowbit(int x){return x&-x;} int getsum(int x) { int ans=0; while(x!=0) { ans+=d[x]; x-=lowbit(x); } return ans; } void change(int x,int c) { while(x<=z) { d[x]+=c; x+=lowbit(x); } } struct qn { int x,y,id; }p[110000]; bool cmp(qn n1,qn n2) { return n1.y<n2.y; } int ans[110000]; int main() { tot=0; scanf("%s",st+1); n=strlen(st+1); int x=0; for(int i=1;i<=n;i++) { if(st[i]==‘P‘) end[++len]=x; else if(st[i]==‘B‘) x=t[x].f; else { int y=st[i]-‘a‘+1; if(t[x].c[y]==-1) { t[x].c[y]=++tot; t[tot].f=x; s[x]++; } x=t[x].c[y]; } } len=0;memset(last,0,sizeof(last)); bfs(); z=0; dfs(0); int m;scanf("%d",&m); for(int i=1;i<=m;i++){scanf("%d%d",&p[i].x,&p[i].y);p[i].id=i;} sort(p+1,p+m+1,cmp); int k=1,cnt=0;x=0; for(int i=1;i<=n;i++) { if(st[i]==‘P‘) { cnt++; while(cnt==p[k].y) { ans[p[k].id]=getsum(r[end[p[k].x]])-getsum(l[end[p[k].x]]-1); k++; } } else if(st[i]==‘B‘) { change(l[x],-1); x=t[x].f; } else { x=t[x].c[st[i]-‘a‘+1]; change(l[x],1); } } for(int i=1;i<=m;i++) printf("%d\n",ans[i]); return 0; }