[日常训练]三视图(组合计数+容斥)
Posted cyf32768
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[日常训练]三视图(组合计数+容斥)相关的知识,希望对你有一定的参考价值。
Description
给定两个长度为 (n) 的数组 (a,b)。
要求给一个 (n×n) 的矩阵的每个位置填上一个非负整数,使得第 (i) 行的最大值为 (a_i),第 (j) 列的最大值为 (b_j)。
求方案数对 (998244353) 取模的结果。
(1leq nleq 10^5),(1leq a,bleq 10^9)。
Solution
显然可以把 (a,b) 分别降序排序,不影响结果。
记 (c_{i,j}=min(a_i,b_j))。考虑将 (c_{i,j}) 相同的位置放在一起处理。
显然 (c_{i,j}) 相同的位置会形成一个 (lceil) 反 (L) 形 ( floor)。
具体地,将 (a,b) 降序排序后,若 (a_{x_1}) ~ (a_{x_2}) 和 (b_{y_1}) ~ (b_{y_2}) 都是 (s),则所有满足以下条件的 (i,j) 都有 (c_{i,j}=s):
- (1leq ileq x_2) 且 (1leq jleq y_2)
- (x_1leq i) 或 (y_1leq j)。
那么现在这个 (lceil) 反 (L) 形 ( floor) 的填数要满足以下条件:
- 每个格子的数都在 ([0,s]) 中。
- 第 (x_1) ~ (x_2) 行的最大值是 (s)。
- 第 (y_1) ~ (y_2) 列的最大值是 (s)。
考虑容斥,设 (f(i)) 表示第 (x_1) ~ (x_2) 行中,至少有 (i) 行的最大值不是 (s) 的方案(需要保证第 (y_1) ~ (y_2) 列的最大值是 (s))。
令 (l_x=x_2-x_1+1,l_y=y_2-y_1+1)。(即行数和列数)
那么 (ans=sum_{i=0}^{lx}(-1)^if(i))。
考虑怎么算 (f(i)):
先填强制最大值不是 (s) 的那 (i) 行:首先从 (l_x) 行中选出 (i) 行,然后这 (i) 行最大值不是 (s),所以这 (i) 行的所有格子都只能填 ([0,s-1])。
这一部分的方案数:(c(l_x,i)×s^{i×y_2})。
接着考虑 (lceil) 保证第 (y_1) ~ (y_2) 列的最大值是 (s) ( floor):第 (y_1) ~ (y_2) 列中,每列已经填了 (i) 个,还剩 (x_2-i) 个没填。显然每列的这 (x_2-i) 个格子中,必定至少有一个 (s)。
这一部分的方案数:(((s+1)^{x_2-i}-s^{x_2-i})^{l_y})。
还剩下 ((l_x-i)(y_1-1)) 个格子,随便填。
这一部分的方案数:((s+1)^{(l_x-i)(y_1-1)})。
所以:
[f(i)=c(l_x,i)×s^{i×y_2}×((s+1)^{x_2-i}-s^{x_2-i})^{l_y}×(s+1)^{(l_x-i)(y_1-1)}]
时间复杂度 (O(n (log n+log 1e9)))。
Code
#include <bits/stdc++.h>
using namespace std;
#define ll long long
template <class t>
inline void read(t & res)
{
char ch;
while (ch = getchar(), !isdigit(ch));
res = ch ^ 48;
while (ch = getchar(), isdigit(ch))
res = res * 10 + (ch ^ 48);
}
const int e = 2e5 + 5, mod = 1e9 + 7;
int n, a[e], b[e], d[e], m, ans = 1, fac[e], inv[e];
inline int ksm(int x, int y)
{
int res = 1;
while (y)
{
if (y & 1) res = (ll)res * x % mod;
y >>= 1;
x = (ll)x * x % mod;
}
return res;
}
inline int c(int x, int y)
{
if (x < y) return 0;
return (ll)fac[x] * inv[y] % mod * inv[x - y] % mod;
}
inline int plu(int x, int y)
{
(x += y) >= mod && (x -= mod);
return x;
}
inline int sub(int x, int y)
{
(x -= y) < 0 && (x += mod);
return x;
}
inline int solve(int la, int lb, int ra, int rb, int s)
{
int res = 0, i, a = ra - la, b = rb - lb;
for (i = 0; i <= a; i++)
{
int tmp = (ll)c(a, i) * ksm(s, (ll)i * rb % (mod - 1)) % mod *
ksm(sub(ksm(s + 1, ra - i), ksm(s, ra - i)), b) % mod *
ksm(s + 1, (ll)(a - i) * lb % (mod - 1)) % mod;
if (i & 1) res = sub(res, tmp);
else res = plu(res, tmp);
}
return res;
}
int main()
{
read(n);
int i;
for (i = 1; i <= n; i++) read(a[i]), d[++m] = a[i];
for (i = 1; i <= n; i++) read(b[i]), d[++m] = b[i];
sort(d + 1, d + m + 1);
m = unique(d + 1, d + m + 1) - d - 1;
reverse(d + 1, d + m + 1);
sort(a + 1, a + n + 1); reverse(a + 1, a + n + 1);
sort(b + 1, b + n + 1); reverse(b + 1, b + n + 1);
fac[0] = 1;
for (i = 1; i <= n; i++) fac[i] = (ll)fac[i - 1] * i % mod;
inv[n] = ksm(fac[n], mod - 2);
for (i = n - 1; i >= 0; i--) inv[i] = (ll)inv[i + 1] * (i + 1) % mod;
int la = 0, lb = 0;
for (i = 1; i <= m; i++)
{
int ra = la, rb = lb;
while (ra < n && a[ra + 1] == d[i]) ra++;
while (rb < n && b[rb + 1] == d[i]) rb++;
ans = (ll)ans * solve(la, lb, ra, rb, d[i]) % mod;
la = ra; lb = rb;
}
cout << ans << endl;
fclose(stdin);
fclose(stdout);
return 0;
}
以上是关于[日常训练]三视图(组合计数+容斥)的主要内容,如果未能解决你的问题,请参考以下文章