[SDOI2015]序列统计
很有趣的一道题目,很巧妙。
显然是一个dp,考虑最朴素的dp,f[i][j]表示选i个乘起来,%m为j的方案数为多少。转移也很简单。
然而乘法的转移并不能进行什么优化,于是考虑设法将其转为加法。
我们可以通过求出m的原根,因为原根G,G^i %m(1<=i<m)是一一对应1<=x<m的所以我们将原本S集合中的元素由"x"替换为"i",这样就完成了由乘法向加法的转换,但是需要注意的是,此时的模数并不再是m而是m-1,。下文中的m均为输入的m-1后的值。
那么问题就转换为了从一堆数中选n个加起来,最终%m为j的方案数为多少,如果我们用g数组来表示某一个元素是否存在,我们发现f[i+1][j]是由f[i][k]*g[(j-k+m)%m](0<=k<m)转移而来,我们发现他很像卷积的形式,但是%m怎么处理呢?
这里get了一个新技能,好像叫循环卷积。
就是计算f和g的卷积后,对于m<=i<l这一段,将他的值加到i%m上去。
f数组自然是可以重复使用的。
那么我们重复卷积n次,得到的即为最终答案。
然而n太大了。
卷积是具有交换律和结合律的,所以可以用快速幂来加速这一过程。
整体复杂度mlogmlogn
1 #include<bits/stdc++.h> 2 using namespace std; 3 const int inf=4e4+10; 4 const int mod=1004535809; 5 int n,m,x,s; 6 int a[inf]; 7 int fast(int x,int y,int z){ 8 int ans=1; 9 while(y){ 10 if(y&1)ans=1ll*ans*x%z; 11 x=1ll*x*x%z; 12 y>>=1; 13 } 14 return ans; 15 } 16 namespace RT{ 17 int tmp[inf],cnt; 18 bool check(int x,int y){ 19 for(int i=1;i<=cnt;i++){ 20 if(::fast(x,(y-1)/tmp[i],y)==1)return 0; 21 } 22 return 1; 23 } 24 int get_rt(int x){ 25 int u=x; 26 x--; 27 for(int i=2;i*i<=x;i++){ 28 if(x%i)continue; 29 tmp[++cnt]=i; 30 while(x%i==0)x/=i; 31 } 32 if(x>1)tmp[++cnt]=x; 33 int now=2; 34 while(1){ 35 if(check(now,u))return now; 36 now++; 37 } 38 } 39 } 40 int hs[inf]; 41 int g[inf],f[inf]; 42 int r[inf]; 43 void ntt(int *a,int l,int type){ 44 for(int i=0;i<l;i++) 45 if(i<r[i])swap(a[i],a[r[i]]); 46 for(int i=2;i<=l;i<<=1){ 47 int wn=fast(3,(type*(mod-1)/i+(mod-1))%(mod-1),mod); 48 for(int j=0;j<l;j+=i){ 49 int w=1; 50 for(int k=j;k<j+i/2;k++,w=1ll*w*wn%mod){ 51 int u=a[k],v=1ll*w*a[k+i/2]%mod; 52 a[k]=(u+v)%mod; 53 a[k+i/2]=(u-v+mod)%mod; 54 } 55 } 56 } 57 } 58 int main() 59 { 60 freopen("sdoi2015_sequence.in","r",stdin); 61 freopen("sdoi2015_sequence.out","w",stdout); 62 scanf("%d%d%d%d",&n,&m,&x,&s); 63 for(int i=1;i<=s;i++)scanf("%d",&a[i]); 64 int G=RT::get_rt(m); 65 for(int i=1;i<m;i++)hs[fast(G,i,m)]=i; 66 m--; 67 for(int i=1;i<=s;i++) 68 if(a[i])g[hs[a[i]]%m]=1; 69 int l=1,h=0; 70 while(l<m*2-1)l<<=1,h++; 71 for(int i=0;i<l;i++) 72 r[i]=(r[i>>1]>>1)+((i&1)<<(h-1)); 73 int inv=fast(l,mod-2,mod); 74 f[0]=1; 75 while(n){ 76 ntt(g,l,1); 77 if(n&1){ 78 ntt(f,l,1); 79 for(int i=0;i<l;i++)f[i]=1ll*f[i]*g[i]%mod; 80 ntt(f,l,-1); 81 for(int i=0;i<l;i++)f[i]=1ll*f[i]*inv%mod; 82 for(int i=m;i<l;i++)f[i%m]=(f[i%m]+f[i])%mod,f[i]=0; 83 } 84 for(int i=0;i<l;i++)g[i]=1ll*g[i]*g[i]%mod; 85 ntt(g,l,-1); 86 for(int i=0;i<l;i++)g[i]=1ll*g[i]*inv%mod; 87 for(int i=m;i<l;i++)g[i%m]=(g[i%m]+g[i])%mod,g[i]=0; 88 n>>=1; 89 } 90 printf("%d\n",f[hs[x]%m]); 91 return 0; 92 }