Time Limit: 1000 ms Memory Limit: 256 MB
description
吐槽
fft除了模板以外的第一题!!高兴ovo
题目好长啊。。
好吧中途脑子秀逗了一直在想怎么用fft求卷积。。。服了我自己了qwq
正题
首先看到说回文字符串,那就。。。马拉车?但是马拉车只能求连续的回文串呀。。
于是我们就想能不能先求出没有任何限制的回文串的数量(记为\(ans1\))然后再用马拉车求出连续的回文串的数量(记为\(ans2\)),那么最终的\(ans = ans1 - ans2\)
现在问题就变成了怎么求没有任何限制的回文串的数量了
我们用\(f_i\)表示以\(i\)为中心的对称字符的对数,那么可以得到这样的一条式子:
\[
f_i = \lfloor\frac{1+\sum\limits_{j}[s_{i-j}=s_{i+j}]}{2}\rfloor
\]
这样一来显然\(ans1 = \sum\limits_{i=1}^{len}2^{f_i}-1\)(总共有\(f_i\)对,每对可以选或者不选,最后再把空的那个(就是全部都不选)减掉)
然后这题有个很妙的限制:一个字符串中只有两种字符\(a\)或者\(b\)
那么我们可以分别考虑\(a\)和\(b\)的贡献,然后相加得到总的贡献(也就是那坨sigma)
我们先考虑怎么算\(a\)的贡献(\(b\)的话其实一样的所以就只讨论一种情况了)
用一个数组\(A\)来记录字符串每一位的\(a\)的贡献
对于字符串的第\(i\)位,如果说是\(a\),我们将\(A_i\)赋为\(1\),否则为\(0\)
那么第\(i\)位和第\(j\)位这对字符串的贡献就为\(A_i * A_j\)
所以,我们如果把\(A\)看做一个多项式的系数矩阵,\(A * A\)算出来的就是\(a\)的贡献
多项式乘法,那显然用FFT就好了ovo
接着用同样的方式算出\(b\)的贡献,再和\(a\)的贡献相加,得到上面式子中的sigma,然后再算\(f\)就好啦
然后这题就很愉快地做完啦ovo
(不得不说用多项式乘法求贡献那步很妙啊ovo)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#define MOD 1000000007
#define ll long long
using namespace std;
const double pi=acos(-1);
const int MAXN=4*(1e5)+10;
struct cmplx
{
double a,b;
cmplx(){}
cmplx(double x,double y){a=x,b=y;}
friend cmplx operator + (cmplx x,cmplx y)
{return cmplx(x.a+y.a,x.b+y.b);}
friend cmplx operator - (cmplx x,cmplx y)
{return cmplx(x.a-y.a,x.b-y.b);}
friend cmplx operator * (cmplx x,cmplx y)
{return cmplx(x.a*y.a-x.b*y.b,x.a*y.b+x.b*y.a);}
}a[MAXN],b[MAXN];
char s[MAXN],s1[MAXN];
int f[MAXN],mx[MAXN],two[MAXN];
int rev[MAXN];
int n,m,tot,k;//k=mx
ll ans;
int fft(cmplx *a,int op);
int manacher(char *s1,int lens);
int main(){
#ifndef ONLINE_JUDGE
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
#endif
scanf("%s",s);
int len=strlen(s);
for (int i=0;i<len;++i){
if (s[i]=='a') a[i]=cmplx(1,0);
else a[i]=cmplx(0,0);
}
k=1; two[0]=1;
for (int i=1;i<=2*len;++i)
two[i]=two[i-1]*2%MOD;
while (k<2*len) k<<=1;
fft(a,1);
for (int i=0;i<k;++i) b[i]=a[i]*a[i];
memset(a,0,sizeof(a));
for (int i=0;i<len;++i){
if (s[i]=='b') a[i]=cmplx(1,0);
else a[i]=cmplx(0,0);
}
fft(a,1);
for (int i=0;i<k;++i) b[i]=b[i]+a[i]*a[i];
fft(b,-1);
ans=0;
for (int i=2;i<=2*len;++i){
f[i]+=(ll)(b[i-2].a/k+0.5);
f[i]=f[i]+1>>1;
}
for (int i=2;i<=2*len;++i) ans=(ans+two[f[i]]-1)%MOD;
printf("%lld\n",(ans+MOD-manacher(s,len))%MOD);
}
int fft(cmplx *a,int op)
{
int step,bit=0;
cmplx w_n,w,t,u;
for (int i=1;i<k;i<<=1,++bit);
rev[0]=0;
for (int i=0;i<k;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
for (int i=0;i<k;++i)
if (i<rev[i]) swap(a[i],a[rev[i]]);
for (int step=2;step<=k;step<<=1)
{
w_n=cmplx(cos(2*pi/step),op*sin(2*pi/step));
for (int st=0;st<k;st+=step)
{
w=cmplx(1,0);
for (int i=0;i<(step>>1);++i)
{
t=a[st+i+(step>>1)]*w;
u=a[st+i];
a[st+i]=u+t;
a[st+i+(step>>1)]=u-t;
w=w*w_n;
}
}
}
}
int manacher(char *s,int lens){
int k,p=0,ret=0,len=1;
s1[0]='$'; s1[1]='#';
for (int i=0;i<lens;++i)
s1[++len]=s[i],s1[++len]='#';
memset(mx,0,sizeof(mx));
for (int i=2;i<len;++i){
if (p>i)
mx[i]=min(mx[k*2-i],p-i);
else
mx[i]=1;
while (s1[i+mx[i]]==s1[i-mx[i]]) ++mx[i];
if (i+mx[i]>p) p=i+mx[i],k=i;
ret=(ret+mx[i]/2)%MOD;
}
return ret;
}