link : https://loj.ac/problem/6261
一看就是一个已经退役的大佬出的题。。。
我一开始还是too young了,忘了看时限,,以为NTT+多项式快速幂就能水过的。。。
于是就写了个我代码里被我注释掉的东西。。。。。
后来被卡了之后才想起来全是1的多项式的N次方的各个项的系数是可以直接用组合数算出来的。。。。
具体的说,设 A = {1,x,x^2,,,,,,,} ^ N 。
A中的 x^i 的系数就是 C(N+i-1,i) ,因为这就是可重组合的定义式吧23333。
而根据组合数的通项公式,我们可以很容易的从 x^i 的系数来递推 x^(i+1) 的系数。
O(N) 计算出 A 之后,直接一遍NTT 让 a卷一下它,得到结果。
写的时候犯了很多SB错误,想想都想打自己23333.
1.一开始求补0之后序列长度的逆元求成补0之前的了,,,虽然这个样例并不能看出来因为样例 4=2^2 2333.
2. 数组大小一定要开到 > 卷积后最高次数 的2^k 。
3.K比较大,,,乘积的时候得先%ha之后再运算233
#include<bits/stdc++.h> #define ll long long #define maxn 400005 using namespace std; const int ha=998244353; const int root=3; const int inv=ha/3+1; ll k; int a[maxn],b[maxn],n,N; int INV,e[maxn],r[maxn],l; int object[2][maxn]; int ni[maxn]; inline int ksm(int x,int y){ int an=1; for(;y;y>>=1,x=x*(ll)x%ha) if(y&1) an=an*(ll)x%ha; return an; } inline int add(int x,int y){ x+=y; return x>=ha?x-ha:x; } inline void NTT(int *c,int f){ for(int i=0;i<N;i++) if(i<r[i]) swap(c[i],c[r[i]]); for(int i=1,o=1;i<N;i<<=1,o++){ int omega=object[f==-1][o]; for(int p=i<<1,j=0;j<N;j+=p){ int now=1; for(int u=0;u<i;u++,now=now*(ll)omega%ha){ int x=c[j+u],y=c[j+u+i]*(ll)now%ha; c[j+u]=add(x,y); c[j+u+i]=add(x,ha-y); } } } if(f==-1) for(int i=0;i<N;i++) c[i]=c[i]*(ll)INV%ha; } inline void calc(){ b[0]=1; for(int i=1;i<n;i++) b[i]=b[i-1]*((ll)(k+i-1)%ha)%ha*(ll)ni[i]%ha; } inline void solve(){ int len=(n-1)<<1; for(N=1,l=0;N<=len;N<<=1) l++; for(int i=0;i<N;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1)); for(int i=1;i<=l;i++){ object[0][i]=ksm(root,(ha-1)/(1<<i)); object[1][i]=ksm(inv,(ha-1)/(1<<i)); } INV=ksm(N,ha-2); ni[1]=1; for(int i=2;i<=n;i++) ni[i]=-ni[ha%i]*(ll)(ha/i)%ha+ha; calc(); /* while(k){ if(k&1){ NTT(a,1),NTT(b,1); for(int i=0;i<N;i++) a[i]=a[i]*(ll)b[i]%ha; NTT(a,-1),NTT(b,-1); fill(a+n,a+N,0); } NTT(b,1); for(int i=0;i<N;i++) b[i]=b[i]*(ll)b[i]%ha; NTT(b,-1); fill(b+n,b+N,0); k>>=1; } */ NTT(a,1),NTT(b,1); for(int i=0;i<N;i++) a[i]=a[i]*(ll)b[i]%ha; NTT(a,-1); } int main(){ scanf("%d%lld",&n,&k); for(int i=0;i<n;i++){ scanf("%d",a+i); } solve(); for(int i=0;i<n;i++) printf("%d\n",a[i]); return 0; }