解法一:容易得到递推式 f_0=1,f_i=2·Σ_{j<i} f_j/(i-j)!(答案为 Σ_{i=0}^{n} f_i·i!),可以用CDQ分治+FFT(模 998244353 意义下即 NTT)求解。
代码用时:1h 比较顺利,没有低级错误。
实现比较简单,11348ms
#include<cstdio>
#include<algorithm>
// Inclusive integer loop: i takes every value from l to r (both ends included).
#define rep(i,l,r) for (int i=l; i<=r; i++)
typedef long long ll;
using namespace std;
// N: capacity of every array below; P: the modulus used by all arithmetic;
// g: base whose powers are used as the roots of unity inside DFT.
const int N=(1<<18)+100,P=998244353,g=3;
// n: the input value read in main; rev[]: bit-reversal permutation table that
// cdq fills and DFT consumes.
int n,rev[N];
// inv / fac / facinv: modular inverses, factorials and inverse factorials (filled in main).
// f: the sequence computed by cdq; a, b: scratch buffers handed to DFT.
ll inv[N],fac[N],facinv[N],f[N],a[N],b[N];
// Modular exponentiation: returns a^b modulo P using binary (square-and-multiply)
// exponentiation.
//   a - base; it is reduced modulo P on entry, so any long long value is accepted
//   b - exponent; must be non-negative (the loop runs until b reaches 0)
// For every input on which the unreduced version did not overflow, the result is
// identical; reducing first merely removes the signed-overflow hazard in a*a for
// bases whose magnitude exceeds roughly 3e9.
ll ksm(ll a,ll b){
    a%=P;  // keeps |a| < P so that a*a and ans*a always fit in 64 bits
    ll ans=1;
    for (; b; b>>=1,a=a*a%P)
        if (b & 1) ans=ans*a%P;
    return ans;
}
// In-place transform of a[0..n-1] modulo P.
//   a - coefficient buffer; elements are read into int temporaries, so they are
//       expected to be already reduced into [0, P)
//   n - transform length; the caller must have filled rev[0..n-1] for this length
//   f - 1 uses the forward roots; any other value uses the inverse roots, and
//       f == -1 additionally multiplies every element by the inverse of n
void DFT(ll a[],int n,int f){
    // Reorder the input according to the bit-reversal table (each pair swapped once).
    for (int idx=0; idx<n; ++idx){
        if (idx<rev[idx]) std::swap(a[idx],a[rev[idx]]);
    }

    // Butterfly stages: `half` is the distance between the two inputs of a butterfly.
    for (int half=1; half<n; half<<=1){
        const int span=half<<1;
        int exponent=(P-1)/span;
        if (f!=1) exponent=(P-1)-exponent;  // inverse root for the backward direction
        const int root=ksm(g,exponent);

        for (int base=0; base<n; base+=span){
            int twiddle=1;
            for (int off=0; off<half; ++off){
                ll &lo=a[base+off];
                ll &hi=a[base+half+off];
                const int u=lo;
                const int v=1ll*twiddle*hi%P;
                lo=(u+v)%P;
                hi=(u-v+P)%P;
                twiddle=1ll*twiddle*root%P;
            }
        }
    }

    // The backward transform is completed by dividing by the length.
    if (f==-1){
        const int scale=ksm(n,P-2);
        for (int idx=0; idx<n; ++idx) a[idx]=1ll*a[idx]*scale%P;
    }
}
void cdq(int l,int r){
if (l==r) return;
int mid=(l+r)>>1,lim=r-l+1,n=1,L=0;
cdq(l,mid);
while (n<lim) n<<=1,L++;
rep(i,0,n-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
rep(i,0,n-1) a[i]=b[i]=0;
rep(i,l,mid) a[i-l]=f[i];
rep(i,0,r-l) b[i]=facinv[i];
DFT(a,n,1); DFT(b,n,1);
rep(i,0,n-1) a[i]=a[i]*b[i]%P;
DFT(a,n,-1);
rep(i,mid+1,r) f[i]=(f[i]+2*a[i-l])%P;
cdq(mid+1,r);
}
int main(){
freopen("bzoj4555.in","r",stdin);
freopen("bzoj4555.out","w",stdout);
scanf("%d",&n); inv[1]=1; fac[0]=facinv[0]=1;
rep(i,1,n){
if (i!=1) inv[i]=(P-P/i)*inv[P%i]%P;
fac[i]=fac[i-1]*i%P;
facinv[i]=facinv[i-1]*inv[i]%P;
}
f[0]=1; cdq(0,n); ll ans=0;
rep(i,0,n) ans=(ans+f[i]*fac[i]%P)%P;
if (ans<0) ans+=P;
printf("%lld\n",ans);
return 0;
}
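解法二:把答案化成一次卷积:令 a_i=(-1)^i/i!,b_0=1,b_1=n+1,b_i=(i^{n+1}-1)/((i-1)·i!)(i≥2),答案为 Σ_{i=0}^{n} 2^i·i!·(a*b)_i,只需做一次NTT卷积。
代码用时:1.5h,在 long long 上出了一点问题。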
整体上说还是比较简单的。
#include<cstdio>
#include<algorithm>
// Inclusive integer loop: i takes every value from l to r (both ends included).
#define rep(i,l,r) for (int i=l; i<=r; i++)
typedef long long ll;
using namespace std;
// N: capacity of every array below; P: the modulus used by all arithmetic;
// g: base whose powers are used as the roots of unity inside DFT.
const int N=(1<<18)+5,P=998244353,g=3;
// n: the input value read in main; rev[]: bit-reversal permutation table that
// main fills and DFT consumes.
int n,rev[N];
// ans: the accumulated answer printed by main.
// inv / fac / facinv: modular inverses, factorials and inverse factorials (filled in main).
// a, b: the two sequences that main convolves via DFT.
// f is declared but not referenced by the code visible in this section.
ll ans,inv[N],fac[N],facinv[N],f[N],a[N],b[N];
// Fast power: computes a raised to the b-th power, modulo P.
//   a - base of any magnitude; it is brought into (-P, P) before use
//   b - exponent, required to be >= 0
// Without the initial reduction a base beyond about 3e9 in magnitude would
// overflow 64-bit signed arithmetic in the squaring step. Results for inputs that
// did not overflow before are unchanged.
ll ksm(ll a,ll b){
    ll ans=1;
    a%=P;
    while (b){
        if (b & 1) ans=ans*a%P;  // this bit of the exponent contributes a^(2^k)
        a=a*a%P;
        b>>=1;
    }
    return ans;
}
// In-place transform of a[0..n-1] modulo P, using 64-bit temporaries throughout.
//   a - coefficient buffer, overwritten with the result
//   n - transform length; rev[0..n-1] must already describe the bit-reversal
//       permutation for this length
//   f - 1 selects the forward roots; any other value selects the inverse roots,
//       and f == -1 also multiplies the result by the modular inverse of n
void DFT(ll a[],int n,int f){
    // Apply the bit-reversal permutation, exchanging each pair exactly once.
    for (int i=0; i<n; i++){
        const int j=rev[i];
        if (i<j){
            const ll tmp=a[i];
            a[i]=a[j];
            a[j]=tmp;
        }
    }

    for (int half=1; half<n; half<<=1){
        const int step=half<<1;
        const int quotient=(P-1)/step;
        // Root of unity of order `step`, or its inverse for the backward direction.
        const ll wn=ksm(g,(f==1) ? quotient : (P-1)-quotient);

        for (int start=0; start<n; start+=step){
            ll w=1;
            for (int k=0; k<half; k++){
                const ll even=a[start+k];
                const ll odd=w*a[start+k+half]%P;
                a[start+k]=(even+odd)%P;
                a[start+k+half]=(even-odd+P)%P;
                w=w*wn%P;
            }
        }
    }

    if (f==-1){
        // Finish the backward transform by scaling with 1/n.
        const int nInv=ksm(n,P-2);
        for (int i=0; i<n; i++) a[i]=a[i]*nInv%P;
    }
}
int main(){
scanf("%d",&n); inv[1]=1; fac[0]=facinv[0]=1;
rep(i,1,n){
if (i!=1) inv[i]=(P-P/i)*inv[P%i]%P;
fac[i]=fac[i-1]*i%P;
facinv[i]=facinv[i-1]*inv[i]%P;
}
a[0]=1; b[0]=1; b[1]=n+1;
rep(i,1,n) a[i]=((i&1)?-1:1)*facinv[i];
rep(i,2,n) b[i]=(ksm(i,n+1)-1)*inv[i-1]%P*facinv[i]%P;
ll lim=n+n+1,nn=1,L=0;
while (nn<lim) nn<<=1,L++;
rep(i,0,nn-1) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
DFT(a,nn,1); DFT(b,nn,1);
rep(i,0,nn-1) a[i]=a[i]*b[i];
DFT(a,nn,-1);
rep(i,0,n) ans=(ans+ksm(2,i)*fac[i]%P*a[i]%P)%P;
if (ans<0) ans+=P;
printf("%lld\n",ans);
return 0;
}