参考:http://blog.csdn.net/wzq_qwq/article/details/46709471
首先推组合数,设sum为每个人礼物数的和,那么答案为
\[
( C_{n}^{sum}C_{sum}^{w[1]}c_{sum-w[1]}^{w[2]}...
\]
设w[0]=n-sum,然后化简成阶乘的形式:
\[
\frac{n!}{w[0]!w[1]!...w[n]!}
\]
注意到这里p不是质数,所以把p拆成质数的方相乘的形式,最后用中国剩余定理合并即可
然后现在的问题是怎么快速求出阶乘
假设当前的质数的方为p=3那么1x2x3x4x5x6x7x8x9x10x11=1x2x4x5x7x8x10x11x 3x(1x2x3),注意到后面又是一个阶乘,但是范围更小,所以可以递归来做,然后前面乘的3被模消去了
#include<iostream>
#include<cstdio>
using namespace std;
const int N=100005;
long long P,n,m,w[10],p[N],cnt[N],mod[N],tot,sum,a[N];
struct qwe
{
int a,b;
};
void exgcd(long long a,long long b,long long &x,long long &y,long long &d)
{
if(!b)
{
x=1;
y=0;
d=a;
return;
}
exgcd(b,a%b,y,x,d);
y=y-a/b*x;
}
long long china()
{
long long d,x=0,y;
for(int i=1;i<=tot;i++)
{
long long r=P/mod[i];
exgcd(mod[i],r,d,y,d);
x=(x+r*y*a[i])%P;
}
return (x+P)%P;
}
long long ksm(long long a,long long b,long long mod)
{
long long r=1ll;
while(b)
{
if(b&1)
r=r*a%mod;
a=a*a%mod;
b>>=1;
}
return r;
}
long long inv(long long a,long long b)
{
long long x,y,d;
exgcd(a,b,x,y,d);
return (x%b+b)%b;
}
qwe fac(long long k,long long n)
{
qwe r;
if(!n)
{
r.a=0,r.b=1;
return r;
}
long long x=n/p[k],y=n/mod[k],ans=1ll;
if(y)
{
for(int i=2;i<mod[k];i++)
if(i%p[k]!=0)
ans=ans*i%mod[k];
ans=ksm(ans,y,mod[k]);
}
for(int i=y*mod[k]+1;i<=n;i++)
if(i%p[k]!=0)
ans=ans*i%mod[k];
qwe tmp=fac(k,x);
r.a=x+tmp.a,r.b=ans*tmp.b%P;
return r;
}
long long clc(int k,long long n,long long m)
{
if(n<m)
return 0;
qwe a=fac(k,n),b=fac(k,m),c=fac(k,n-m);
return ksm(p[k],a.a-b.a-c.a,mod[k])*a.b%mod[k]*inv(b.b,mod[k])%mod[k]*inv(c.b,mod[k])%mod[k];
}
long long wk(long long n,long long m)
{
for(int i=1;i<=tot;i++)
a[i]=clc(i,n,m);
return china();
}
int main()
{
scanf("%lld%lld%lld",&P,&n,&m);
for(int i=1;i<=m;i++)
scanf("%lld",&w[i]),sum+=w[i];
int x=P;
for(int i=2;i*i<=x;i++)
if(x%i==0)
{
p[++tot]=i;
mod[tot]=1;
while(x%i==0)
{
x/=i;
cnt[tot]++;
mod[tot]*=i;
}
}
if(x>1)
{
p[++tot]=x;
mod[tot]=x;
cnt[tot]=1;
}
if(sum>n)
{
puts("Impossible");
return 0;
}
long long ans=wk(n,sum)%P;
for(int i=1;i<=m;i++)
{
ans=ans*wk(sum,w[i])%P;
sum-=w[i];
}
printf("%lld\n",ans);
return 0;
}