[BZOJ3684]大朋友和多叉树
试题描述
我们的大朋友很喜欢计算机科学,而且尤其喜欢多叉树。对于一棵带有正整数点权的有根多叉树,如果它满足这样的性质,我们的大朋友就会将其称作神犇的:点权为 \(1\) 的结点是叶子结点;对于任一点权大于 \(1\) 的结点 \(u\),\(u\) 的孩子数目 \(deg[u]\) 属于集合 \(D\),且 \(u\) 的点权等于这些孩子结点的点权之和。
给出一个整数 \(s\),你能求出根节点权值为 \(s\) 的神犇多叉树的个数吗?请参照样例以更好的理解什么样的两棵多叉树会被视为不同的。
我们只需要知道答案关于 \(950009857\)(\(453 \times 2^{21} + 1\),一个质数)取模后的值。
输入
第一行有 \(2\) 个整数 \(s,m\)。
第二行有 \(m\) 个互异的整数,\(d[1],d[2], \cdots ,d[m]\),为集合 \(D\) 中的元素。
输出
输出一行仅一个整数,表示答案模 \(950009857\) 的值。
输入示例
4 2
2 3
输出示例
10
样例解释
数据规模及约定
\(1 \le m < s \le 10^5, 2 \le d[i] \le s\),有 \(3\) 组小数据和 \(3\) 组大数据。
题解
我们还从暴力 dp 入手。令 \(f(i)\) 表示根节点权值为 \(i\) 的树有多少个,令 \(g(i, j)\) 表示由 \(j\) 棵树组成的根节点权值总和为 \(i\) 的森林有多少种,则显然由以下转移:
\[ f(i) = \sum_{j \in D} g(i, j) \f(1) = 1 \g(i, j) = \sum_{k = 1}^j g(i - k, j - 1) \cdot f(k) \]
现在如果用 \(F(x)\) 表示 \(f(i)\) 的生成函数,\(G_j(x)\) 表示 \(g(i, j)\) 的生成函数,不难发现 \(G_j(x) = F(x)^j\),那么就能得到这样一个方程:
\begin{equation}
F(x) = x + \sum_{j \in D} F(x)^j
\end{equation}
为什么要加一个 \(x\) 呢?从 dp 的边界条件看来 \(f(1) = 1\) 而 \(f(0) = 0\);此外根据题意 \(D\) 中元素都 \(\ge 2\),而显然 \(F(x)\) 的常数项为 \(0\),那么 \(\sum_{j \in D} F(x)^j\) 的最低次项就是 \(D\) 中最小元素了,显然比一次项要高,所以我们需要加一个 \(x\)。
但是这个方程我并不会解,怎么办?这就是拉格朗日反演定理的用处了。先看一下这个定理在什么条件下能用。
\(f(x)\) 与 \(g(x)\) 常数项为 \(0\),且 \(g(f(x)) = x\),那么若已知 \(g(x)\),我们可以快速求 \(f(x)\) 的某一项(不妨设要求的是 \(x^n\) 项),我们有:
\[ [x^n]f(x) = \frac{1}{n} [x^{-1}] \left( \frac{1}{g(x)} \right) ^n \]
事实上若满足常数项都是 \(0\) 的条件,\(f(g(x)) = g(f(x)) = x\) 是成立的,所以网上对这个定理的描述可能稍有不同,其实都是等价的。
可能对上面的公式你还不明白 \([x^{-1}]\) 是什么操作。注意到 \(g(x)\) 的常数项为 \(0\),所以它的逆元是不存在的,即没有任何整式能够表示 \(\frac{1}{g(x)}\),这时候就需要扩充一下这个域,引入分式域。在这里所有的多项式能够被表示成 \(\cdots + a_{-2} x^{-2} + a_{-1} x^{-1} + a_0 + a_1 x + a_2 x^2 + \cdots\),可以证明这种形式能够表示所有的 \(\frac{A(x)}{B(x)}\),其中 \(A(x)\) 和 \(B(x)\) 都是整式。说白了 \(\frac{1}{g(x)}\) 是一个分式,所以它有 \(x^{-1}\) 项。
但是你还想知道如何求一个分式的 \(x^{-1}\) 项。其实我们不需要真正求一个分式,我们可以把上面的 \(g(x)\) 的末尾的系数为 \(0\) 的项都去掉,然后正常地求逆元,然后再左移对应项数。即将上式变为:
\[ [x^n]f(x) = \frac{1}{n} [x^{dn-1}] \left( \frac{x^d}{g(x)} \right) ^n \]
其中 \(d\) 表示 \(g(x)\) 的系数中前缀 \(0\) 的个数。这个式子就是求一下 \(\mod x^{dn}\) 下 \(g‘(x)\) 的逆元就好了,\(g‘(x)\) 是 \(g(x)\) 左移直到常数项非零后的多项式。
回到这题,不难发现 \((1)\) 式移项即可得到:
\[ F(x) - \sum_{j \in D} F(x)^j = x \G(F(x)) = x \]
所以得到了 \(G(x) = x - \sum_{j \in D} x^j\) 这个多项式,问题就解决了。
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
#define rep(i, s, t) for(int i = (s), mi = (t); i <= mi; i++)
#define dwn(i, s, t) for(int i = (s), mi = (t); i >= mi; i--)
int read() {
int x = 0, f = 1; char c = getchar();
while(!isdigit(c)){ if(c == ‘-‘) f = -1; c = getchar(); }
while(isdigit(c)){ x = x * 10 + c - ‘0‘; c = getchar(); }
return x * f;
}
#define maxn 262144
#define MOD 950009857
#define Groot 7
#define LL long long
int Pow(int a, int b) {
int ans = 1, t = a;
while(b) {
if(b & 1) ans = (LL)ans * t % MOD;
t = (LL)t * t % MOD; b >>= 1;
}
return ans;
}
int inv[maxn];
void init() {
inv[1] = 1;
rep(i, 2, maxn - 1) inv[i] = (LL)(MOD - MOD / i) * inv[MOD%i] % MOD;
return ;
}
int brev[maxn];
void FFT(int *a, int len, int tp) {
int n = 1 << len;
rep(i, 0, n - 1) if(i < brev[i]) swap(a[i], a[brev[i]]);
rep(i, 1, len) {
int wn = Pow(Groot, MOD - 1 >> i);
if(tp < 0) wn = Pow(wn, MOD - 2);
for(int j = 0; j < n; j += 1 << i) {
int w = 1;
rep(k, 0, (1 << i >> 1) - 1) {
int la = a[j+k], ra = (LL)w * a[j+k+(1<<i>>1)] % MOD;
a[j+k] = (la + ra) % MOD;
a[j+k+(1<<i>>1)] = (la - ra + MOD) % MOD;
w = (LL)w * wn % MOD;
}
}
}
if(tp < 0) {
int invn = Pow(n, MOD - 2);
rep(i, 0, n - 1) a[i] = (LL)a[i] * invn % MOD;
}
return ;
}
void Mul(int *A, int *B, int n, int m, bool recover = 0) {
int N = 1, len = 0;
while(N <= n + m) N <<= 1, len++;
rep(i, 0, N - 1) brev[i] = (brev[i>>1] >> 1) | ((i & 1) << len >> 1);
rep(i, n + 1, N - 1) A[i] = 0;
rep(i, m + 1, N - 1) B[i] = 0;
FFT(A, len, 1); FFT(B, len, 1);
rep(i, 0, N - 1) A[i] = (LL)A[i] * B[i] % MOD;
FFT(A, len, -1); if(recover) FFT(B, len, -1);
return ;
}
int tmp[maxn];
void inverse(int *f, int *g, int n) {
if(n == 1) return (void)(f[0] = Pow(g[0], MOD - 2));
inverse(f, g, n + 1 >> 1);
rep(i, 0, n - 1) tmp[i] = g[i];
int N = 1, len = 0;
while(N < (n << 1)) N <<= 1, len++;
rep(i, 0, N - 1) brev[i] = (brev[i>>1] >> 1) | ((i & 1) << len >> 1);
rep(i, n, N - 1) tmp[i] = 0;
rep(i, n + 1 >> 1, N - 1) f[i] = 0;
FFT(f, len, 1); FFT(tmp, len, 1);
rep(i, 0, N - 1) f[i] = ((LL)f[i] * (2ll - (LL)tmp[i] * f[i] % MOD) % MOD + MOD) % MOD;
FFT(f, len, -1);
return ;
}
int deri[maxn], pinv[maxn];
void logarithm(int *f, int *g, int n) { // g[0] must be 1
rep(i, 1, n - 1) deri[i-1] = (LL)g[i] * i % MOD;
inverse(pinv, g, n);
Mul(deri, pinv, n - 2, n - 1);
rep(i, 0, n - 2) f[i+1] = (LL)deri[i] * inv[i+1] % MOD; f[0] = 0;
return ;
}
int lnf[maxn], tg[maxn];
void exponential(int *f, int *g, int n) { // g[0] must be 0
if(n == 1) return (void)(f[0] = 1);
exponential(f, g, n + 1 >> 1);
rep(i, n + 1 >> 1, n - 1) f[i] = 0;
logarithm(lnf, f, n);
rep(i, 0, n - 1) tg[i] = (g[i] - lnf[i] + MOD) % MOD; tg[0]++; if(tg[0] >= MOD) tg[0] -= MOD;
Mul(f, tg, (n + 1 >> 1) - 1, n - 1);
return ;
}
int t1[maxn];
void p_pow(int *a, int n, int b) { // a[0] must be 1
logarithm(t1, a, n);
rep(i, 0, n - 1) t1[i] = (LL)t1[i] * b % MOD;
exponential(a, t1, n);
return ;
}
int n, m, d[maxn], ans[maxn];
int main() {
init();
n = read(); m = read();
rep(i, 1, m) d[read()-1] = MOD - 1;
d[0] = 1;
p_pow(d, n, n);
inverse(ans, d, n);
printf("%lld\n", (LL)inv[n] * ans[n-1] % MOD);
return 0;
}