计数难题6:luoguP4935 口袋里的纸飞机

Posted guessycb

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了计数难题6:luoguP4935 口袋里的纸飞机相关的知识,希望对你有一定的参考价值。

计数难题6:luoguP4935 口袋里的纸飞机

标签(空格分隔): 计数难题题选

题目大意:

链接:戳我!
随机生成一个大小为(n)的数列({a_i}),每个数的范围都在([1,R])之间。
对于每种数列,可以生成一个(n*n)的网格,其中格子((i,j))的数为(a_i*a_j\% P)
对于一个数列,定义其价值为形成的网格中不同的数的个数。
现在你需要求出所有数列的价值之和,答案对(10^9 + 7)取模。
数据范围:(nleq 500 , Pleq 5000 , Rleq 10^9) ,保证(P)为大于(3)的质数。

题解

显然枚举每一个数(x),然后计算它可以在多少个排列中出现。
直接算在多少个排列中出现并不好算,我们算它的反面:在多少个排列中不出现。
注意到(P)是质数,即有原根。
所以(a*b \%P = x)这样的((a,b))一定是一些互不相交的二元组。
现在数就被分成了三类:

  • (a*b\% P=x),则((a,b))为一个限制二元组。
  • (a*a\% P = x),则(a)不能选。
  • 其它的数则可以任意选。

显然我们只需要处理这些限制二元组的选择,设(f_{i,j})表示前(i)个二元组选了(j)个数的方案数。
转移直接插入(t)个元素:
[f_{i,j} (v_1^t + v_2^t) inom{j + t}{t} o f_{i+1 , j+t} ]
这样子的暴力是(O(P^2 n^2))的。
注意到每一个二元组元素(a,b)的选择个数(v_1)(v_2)只有可能是(lfloorfrac{R}{n} floor)(lfloorfrac{R}{n} floor+1)
所以本质不同的二元组只有(3)种。
(f[op][i][j])表示使用了(i)个二元组(op),一共选了(j)个数的方案数。
暴力转移,最后用(f[0/1/2])合并答案,最后复杂度是(O(Pn^2))的。
看上去已经没有什么可以优化的了。
所以就暴力分块吧。
对于每一种二元组,我们预处理使用(1,2,...sqrt{P})个时的方案数(f)
然后利用(f_{sqrt{n}}),我们又可以处理出使用(sqrt{P},2sqrt{P}...(sqrt{P})^2)个时的方案数(g)
转移全部都是插入元素即可。
那么对于任意一个使用个数(t)
我们就可以通过一个(f)和一个(g)(O(n^2))的时间内通过卷积算出其所有方案数。
然后又一个结论:本质不同的(x)只有(sqrt{P})个。
所以记忆化后,对于每一种(x)暴力算出答案即可,注意特判(0)
复杂度(O(sqrt{P}n^2))

实现代码

#include<bits/stdc++.h>
#define IL inline
#define _
#define ll long long
using namespace std ;

IL int gi(){
    int data = 0 , m = 1; char ch = 0;
    while(ch!='-' && (ch<'0'||ch>'9')) ch = getchar();
    if(ch == '-'){m = 0 ; ch = getchar() ; }
    while(ch >= '0' && ch <= '9'){data = (data<<1) + (data<<3) + (ch^48) ; ch = getchar(); }
    return (m) ? data : -data ; 
}

#define mod 1000000007

IL int Pow(int ts , int js) {
    int al = 1 ;
    while(js) {
        if(js & 1) al = 1ll * al * ts % mod ;
        ts = 1ll * ts * ts % mod ;
        js >>= 1 ; 
    }
    return al ;
}
IL void add(int &x , int y){x += y ; if(x >= mod) x-= mod ;}

int Fac[5005],inv[5005],IFac[5005],n,m,R,P,Base ;
int val[5005],ban[5005],Bac,Bac2,Ans,ALL,pw1[5005],pw2[5005],cnt[5005][5] ; 
int f[3][505][5005] , g[3][505][5005] , dp[5][5005] , ret[5][5005] , Result[505][505] ;
int oo ; 

IL void Numb() {
    Fac[0] = Fac[1] = inv[0] = inv[1] = IFac[0] = IFac[1] = 1 ;
    for(int i = 2; i <= 5000; i ++) {
        inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod ;
        Fac[i] = 1ll * i * Fac[i - 1] % mod ;
        IFac[i] = 1ll * inv[i] * IFac[i - 1] % mod ; 
    }
    return ; 
}
IL int Comb(int N , int M) {
    if(M > N) return 0 ;
    return 1ll * Fac[N] * IFac[M] % mod * IFac[N - M] % mod ; 
}

IL void Solve(int id , int v1 , int v2) {
    pw1[0] = pw2[0] = 1 ;
    for(int i = 1; i <= n; i ++) pw1[i] = 1ll * pw1[i-1] * v1 % mod , pw2[i] = 1ll * pw2[i-1] * v2 % mod ;
    f[id][0][0] = 1 ;
    for(int i = 0; i < Bac; i ++)
        for(int j = 0; j <= n; j ++)
            if(f[id][i][j])
                for(int t = 0; t + j <= n; t ++) {
                    if(t){
                        add(f[id][i+1][t + j] , 1ll * f[id][i][j] * Comb(j + t , t) % mod * pw1[t] % mod) ;
                        add(f[id][i+1][t + j] , 1ll * f[id][i][j] * Comb(j + t , t) % mod * pw2[t] % mod) ; 
                    }
                    else add(f[id][i+1][j] , f[id][i][j]) ;
                }
    g[id][0][0] = 1 ;
    for(int i = 0; i < Bac2; i ++)
        for(int j = 0; j <= n; j ++)
            if(g[id][i][j])
                for(int t = 0; t + j <= n; t ++) 
                    add(g[id][i+1][t + j] , 1ll * g[id][i][j] * Comb(j + t , t) % mod * f[id][Bac][t] % mod) ;
    return ; 
}

IL void Calc(int id , int s) {
    int bs = 0 ;
    while(Bac * (bs + 1) <= s) ++ bs ;
    int rest = s - Bac * bs ;
    for(int j = 0; j <= n; j ++) dp[0][j] = g[id][bs][j] , dp[1][j] = 0 ;
    for(int j = 0; j <= n; j ++)
        for(int t = 0; t + j <= n; t ++)
            add(dp[1][j + t] , 1ll * dp[0][j] * f[id][rest][t] % mod * Comb(j + t , t) % mod) ;
    return ; 
}
IL void Query(int Id , int s0 , int s1 , int s2) {
    int isum[3] = {s0,s1,s2} ;
    for(int id = 0; id < 3; id ++) {
        Calc(id , isum[id]) ;
        for(int j = 0; j <= n; j ++) ret[id + 1][j] = dp[1][j] ;
    }
    for(int i = 0; i <= 3; i ++)
        for(int j = 0; j <= n; j ++) dp[i][j] = 0 ;
    for(int j = 0; j <= n; j ++) dp[0][j] = 0 ; dp[0][0] = 1 ;
    for(int i = 0; i < 3; i ++)
        for(int j = 0; j <= n; j ++)
            if(dp[i][j])
                for(int t = 0; j + t <= n; t ++)
                    add(dp[i+1][j + t] , 1ll * dp[i][j] * ret[i+1][t] % mod * Comb(j + t , t) % mod) ;
    for(int j = 0; j <= n; j ++) Result[Id][j] = dp[3][j] ;
    return ; 
}

struct Hash{
    int a0,a1,a2 ;
    bool operator < (const Hash &B) const {
        return (a2 ^ B.a2) ? a2 < B.a2 : ((a0 ^ B.a0) ? a0 < B.a0 : a1 < B.a1) ; 
    }
};map<Hash,int>ID ; 

int main() {
    n = gi() ; P = gi() ; R = gi() ;
    val[0] = R / P ;
    Numb() ;
    Base = R / P ;
    Bac = sqrt(P) ; Bac2 = (P + Bac - 1) / Bac ;
    for(int i = 1; i < P; i ++)
        if(R - Base * P >= i) val[i] = Base + 1 ; else val[i] = Base ; 
    for(int x1 = 1 ; x1 < P; x1 ++)
        for(int x2 = 1; x2 <= x1; x2 ++) {
            int v = 1ll * x1 * x2 % P ;
            if(x1 == x2) ban[v] += val[x1] ;
            else {
                ban[v] += val[x1] + val[x2] ;
                int s1 = val[x1] , s2 = val[x2] ;
                if(s1 > s2) swap(s1 , s2) ;
                if(s1 == Base && s2 == Base) cnt[v][0] ++ ;
                else if(s1 == Base && s2 == Base + 1) cnt[v][1] ++ ;
                else if(s1 == Base + 1 && s2 == Base + 1) cnt[v][2] ++ ; 
            }
        }
    Solve(0 , Base , Base) ;
    Solve(1 , Base , Base + 1) ;
    Solve(2 , Base + 1 , Base + 1) ;
    for(int v = 1; v < P; v ++) {
        Hash al ; al.a0 = cnt[v][0] ; al.a1 = cnt[v][1] ; al.a2 = cnt[v][2] ;
        if(ID.find(al) == ID.end()) ID[al] = ++ oo ;
    }
    for(map<Hash,int>::iterator it = ID.begin(); it != ID.end(); it ++) {
        Hash al = it->first ;
        int id = it->second ;
        Query(id , al.a0 , al.a1 , al.a2) ;
    }
    ALL = Pow(R , n) ;
    for(int x = 1; x < P; x ++) {
        Hash al ; al.a0 = cnt[x][0] ; al.a1 = cnt[x][1] ; al.a2 = cnt[x][2] ;
        int id = ID[al] ;
        int AL = 0 ;
        for(int j = 0; j <= n; j ++) add(AL , 1ll * Result[id][j] * Pow(R - ban[x] , n - j) % mod * Comb(n , j) % mod) ;
        add(Ans , (ALL - AL + mod) % mod) ;
    }   
    Ans = (Ans + (ALL - Pow((R-val[0] + mod) % mod , n) + mod) % mod) % mod ;
    cout << Ans << endl ;
    return 0 ; 
}

以上是关于计数难题6:luoguP4935 口袋里的纸飞机的主要内容,如果未能解决你的问题,请参考以下文章

luoguP1195 口袋的天空

luoguP1379 八数码难题[IDA*]

[Luogu] P1195 口袋的天空

python面试30-40题

我口袋里的IDE?

FLASH如何制作飞机飞