【CF932E】Team Work
题意:求$\sum\limits_{i=1}^nC_n^ii^k$,答案模$10^9+7$。$n\le 10^9,k\le 5000$。
【BZOJ5093】图的价值
题意:“简单无向图”是指无重边、无自环的无向图(不一定连通)。一个带标号的图的价值定义为每个点度数的k次方的和。给定n和k,请计算所有n个点的带标号的简单无向图的价值之和。因为答案很大,请对998244353取模输出。
$n\le 10^9,k\le 200000$
题解:对于第二道题我们显然可以将每个点的度数分开算,枚举这个点的度数显然可以得出$ans=n\times 2^{\frac {(n-1)(n-2)} 2}\sum\limits_{i=0}^{n-1}C_n^ii^k$。那么这两道题的关键都在于如何求出$\sum\limits_{i=0}^nC_ii^k$。
由于第一道题的k比较小,这里给出一种简单的递推方法:
$\sum\limits_{i=1}^nC_n^ii^k=\sum\limits_{i=1}^nC_n^i\times i\times i^{k-1}\\=n\sum\limits_{i=1}^nC_{n-1}^{i-1}\times i^{k-1}\\=n\sum\limits_{i=1}^{n-1}C_{n-1}^i\times (i+1)^{k-1}+1$
这提醒我们用$f[k][j]=\sum\limits_{i=1}^{n-j}C_{n-j}^i\times(i+j)^k$。不难得到递推式:
$f[k][j]=\sum\limits_{i=1}^{n-j}C_{n-j}^i(i+j)^k\\=\sum\limits_{i=1}^{n-j}C_{n-j}^ii(i+j)^{k-1}+j\sum\limits_{i-1}^{n-j}C_{n-j}^i(i+j)^{k-1}\\=(n-j)\sum\limits_{i=1}^{n-j}C_{n-j-1}^{i-1}(i+j)^{k-1}+j\cdot f[k-1][j]\\=(n-j)(\sum\limits_{i=1}^{n-j-1}C_{n-j-1}^i(i+j+1)^{k-1}+(j+1)^{k-1})+j\cdot f[k-1][j]\\=(n-j)(f[k-1][j+1]+(j+1)^{k-1})+j\cdot f[k-1][j]$
$O(n^2)$算即可。注意特判$n<k$的情况。
第二题这么算就过不去了,我们换一种思路,考虑用第二类斯特林数。
$S_n^m$为第二类斯特林数,表示把n个不同的球装进m个相同的箱子里,要求箱子不能为空的方案数。我们可以用容斥来得到第二类斯特林数的通项公式。先假设箱子是有标号的,则答案要除以$m!$。我们枚举至少有i个箱子是空的,球可以随便装进其余的箱子里,那么方案数为$C_m^i\times (m-i)^n$。所以$S_n^m=\frac 1 {m!}\sum\limits_{i=0}^m(-1)^iC_m^i(m-i)^n$。
将组合数拆开,得到$S_n^m=\sum\limits_{i=0}^m\frac {(-1)^i} {i!}\frac {(m-i)^n} {(m-i)!}$。发现这个东西符合卷积的形式,所以我们可以用NTT在$O(n\log n)$的时间里预处理斯特林数。
再来考虑一个组合问题,将n个不同的球装进m个不同的箱子里,允许箱子为空的方案数是多少?我们可以枚举非空的箱子的个数,由于是有标号的,所以还要乘一个排列数,那么答案就是$\sum\limits_{i=1}^mP_m^iS_n^i$。而小学生都知道这个问题的答案就是$m^n$。所以$m^n=\sum\limits_{i=1}^mP_m^iS_n^i$。我们下面就要利用这个恒等式。
$\sum\limits_{i=0}^nC_n^ii^k=\sum\limits_{i=0}^nC_n^i\sum\limits_{j=0}^iP_i^jS_k^j\\=\sum\limits_{i=0}^nC_n^i\sum\limits_{j=0}^iC_i^jj!S_k^j\\=\sum\limits_{j=0}^nj!S_k^j\sum\limits_{i=j}^nC_n^iC_n^j$
后面那个东西是什么?在n个物品里选择i个,再从i个物品里选j个的方案数。我们可以先选j个,然后其他物品可以选也可以不选,所以就是$C_n^j\times 2^{n-j}$。
于是就变成了$\sum\limits_{j=0}^nS_k^jj!C_n^j2^{n-j}=\sum\limits_{j=0}^kS_k^jj!C_n^j2^{n-j}$。
代码:
CF932E:
#include <cstdio> #include <cstring> #include <iostream> using namespace std; typedef long long ll; const ll P=1000000007; ll n,m; ll f[2][5010],g[5010]; inline ll pm(ll x,ll y) { ll z=1; while(y) { if(y&1) z=z*x%P; x=x*x%P,y>>=1; } return z; } int main() { scanf("%lld%lld",&n,&m); int i,j,d; if(n<=m) { for(i=0;i<=n;i++) { d=i&1; memset(f[d],0,sizeof(f[d])); f[d][0]=1; for(j=1;j<=i;j++) f[d][j]=(f[d^1][j-1]+f[d^1][j])%P; } ll ans=0; for(i=1;i<=n;i++) ans=(ans+f[n&1][i]*pm(i,m))%P; printf("%lld",ans); return 0; } ll tmp=pm(2,n-m); for(i=1;i<=m+1;i++) g[i]=1; for(i=m;i>=0;i--) f[0][i]=tmp-1,tmp=(tmp<<1)%P; for(i=1;i<=m;i++) { d=i&1; memset(f[d],0,sizeof(f[d])); for(j=0;j<=m-i;j++) f[d][j]=((n-j)*(f[d^1][j+1]+g[j+1])%P+j*f[d^1][j])%P,g[j+1]=g[j+1]*(j+1)%P; } printf("%lld",f[m&1][0]); return 0; }
BZ5093:
#include <cstring> #include <iostream> #include <cstdio> #include <algorithm> using namespace std; typedef long long ll; const ll P=998244353; const int maxn=(1<<19)+4; int k,len; ll n,ans; ll A[maxn],B[maxn],ine[maxn],jc[maxn],jcc[maxn]; inline ll pm(ll x,ll y) { ll z=1; while(y) { if(y&1) z=z*x%P; x=x*x%P,y>>=1; } return z; } inline ll c(int a,int b) { if(a<b) return 0; return jc[a]*jcc[a-b]%P*jcc[b]%P; } inline void NTT(ll *a,int f) { int i,j,k,h; ll t; for(i=k=0;i<len;i++) { if(i>k) swap(a[i],a[k]); for(j=(len>>1);(k^=j)<j;j>>=1); } for(h=2;h<=len;h<<=1) { ll wn; if(f==1) wn=pm(3,(P-1)/h); else wn=pm(3,P-1-(P-1)/h); for(i=0;i<len;i+=h) { ll w=1; for(j=i;j<i+h/2;j++) t=w*a[j+h/2]%P,a[j+h/2]=(a[j]-t)%P,a[j]=(a[j]+t)%P,w=w*wn%P; } } if(f==-1) { t=pm(len,P-2); for(i=0;i<len;i++) a[i]=a[i]*t%P; } } int main() { scanf("%lld%d",&n,&k),n--; int i; for(len=1;len<(k<<1);len<<=1); ine[0]=ine[1]=jc[0]=jc[1]=jcc[0]=jcc[1]=1; for(i=2;i<=k;i++) ine[i]=P-(P/i)*ine[P%i]%P,jc[i]=jc[i-1]*i%P,jcc[i]=jcc[i-1]*ine[i]%P; for(i=0;i<=k;i++) A[i]=((i&1)?-1:1)*jcc[i],B[i]=pm(i,k)*jcc[i]%P; NTT(A,1),NTT(B,1); for(i=0;i<len;i++) A[i]=A[i]*B[i]%P; NTT(A,-1); ll tmp=1; for(i=0;i<=n&&i<=k;i++) { ans=(ans+A[i]*jc[i]%P*tmp%P*pm(2,n-i))%P; tmp=tmp*(n-i)%P*ine[i+1]%P; } ans=ans*(n+1)%P*pm(2,n*(n-1)/2)%P; ans=(ans+P)%P; printf("%lld",ans); return 0; }