NTT模板 神奇的迷宫 NTT加点分治
Posted goto_1600
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了NTT模板 神奇的迷宫 NTT加点分治相关的知识,希望对你有一定的参考价值。
链接
题意就是求给定一棵树
求
∑
0
n
−
1
W
l
∑
∑
[
d
i
s
t
(
i
,
j
)
=
=
L
]
∗
a
[
i
]
∗
a
[
j
]
\\sum_0^n-1 W_l \\sum\\sum [dist(i,j)==L]*a[i]*a[j]
∑0n−1Wl∑∑[dist(i,j)==L]∗a[i]∗a[j]
大小为1e5
思路:
树上路径考虑点分治,顺便复习了一下点分治,就是每次递归选一个重心,总共期望是
l
o
g
n
logn
logn层,每一层我们可以用
n
l
o
g
n
nlogn
nlogn的复杂度。我们对每一种长度的路径考虑,只要满足i+j==L那么不难想出卷积,每一层递归在新根的作用下,维护A和B数组的卷积,然后贡献给答案那么最后答案就是
∑
1
n
−
1
c
n
t
[
i
]
∗
w
[
i
]
∗
2
+
∑
1
n
w
[
0
]
∗
a
[
i
]
∗
a
[
i
]
\\sum_1^n-1 cnt[i]*w[i]*2 + \\sum_1^n w[0]*a[i]*a[i]
∑1n−1cnt[i]∗w[i]∗2+∑1nw[0]∗a[i]∗a[i],为啥要乘2,因为两个地点不同,有顺序,对于单点的情况,不用乘2。顺便偷一个NTT的板子玩玩
#include<bits/stdc++.h>
using namespace std;
#define int long long
typedef long long ll;
const int mod=998244353;
const int MOD=mod;
const int MAXN=1000010;
const int N=1000010;
const int maxn = 5e6+7;
const int G = 3;
bool st[N];
vector<int>v[N];
int A[maxn],B[maxn];
int a[N];
int pt=0;
int ans=0;
int qt=0;
int W[N];
int maxlen=0;
int C[N];
int cnt[N];
ll qpow(ll a,ll b)
ll ans = 1;
while(b>0)
if(b&1) ans = ans*a%mod;
b>>=1;
a = a*a%mod;
return ans%mod;
int len,r[MAXN];
ll x[MAXN],y[MAXN],w[MAXN];
inline ll q_pow(ll x,ll y)
ll res = 1;
while(y)
if(y & 1) res = res * x % MOD;
x = x * x % MOD;
y >>= 1;
return res;
void NTT(ll *a,ll f)
for(int i = 0;i < len;i ++)
if(i < r[i]) swap(a[i],a[r[i]]);
w[0] = 1;
for(int i = 2;i <= len;i *= 2)
ll wn;
if(f == 1) wn = q_pow(G,(ll)(MOD-1)/i);
else wn = q_pow(G,(ll)(MOD-1)-(MOD-1)/i);
for(int j = i/2;j >= 0;j -= 2) w[j] = w[j/2];
for(int j = 1;j < i/2;j += 2) w[j] = (w[j-1]*wn)%MOD;
for(int j = 0;j < len;j += i)
for(int k = 0;k < i/2;k ++)
ll u = a[j+k],v = (a[j+k+i/2] * w[k]) % MOD;
a[j+k] = (u + v) % MOD;
a[j+k+i/2] = (u - v + MOD) % MOD;
if(f == -1)
ll inv = q_pow(len,MOD-2);
for(int i = 0;i < len;i ++) a[i] = a[i] * inv % MOD;
void MUL(ll *a,ll *b,ll *c,ll n,ll m)
len = 1;
while(len <= (n + m)) len *= 2;
int k = trunc(log(len + 0.5) / log(2));
for(int i = 0;i < len;i ++)
r[i] = (r[i>>1]>>1) | ((i&1) << (k-1));
for(int i = 0;i < len;i ++)
if(i < n) x[i] = a[i];else x[i] = 0;
if(i < m) y[i] = b[i];else y[i] = 0;
NTT(x,1);
NTT(y,1);
for(int i = 0;i < len;i ++) c[i] = x[i] * y[i] % MOD;
NTT(c,-1);
int get_wc(int u,int fa,int tot,int &rt)
if(st[u]) return 0;
int maxv=0;
int sum=1;
for(auto j:v[u])
if(j==fa) continue;
int tt=get_wc(j,u,tot,rt);
sum+=tt;
maxv=max(maxv,tt);
maxv=max(maxv,tot-sum);
if(maxv<=tot/2)
rt=u;
return sum;
int get_sz(int u,int fa)
if(st[u]) return 0;
int res=1;
for(auto j:v[u])
if(j==fa) continue;
res+=get_sz(j,u);
return res;
int mxlen=0;
void get_dist(int u,int fa,int dist)
if(st[u]) return ;
mxlen=max(mxlen,dist);
B[dist]=(B[dist] + a[u])%mod;
for(auto j:v[u])
if(j!=fa)
get_dist(j,u,dist+1);
void dfs(int u,int fa)
if(st[u]) return ;
get_wc(u,fa,get_sz(u,fa),u);
st[u]=true;
maxlen=0;
A[0]=a[u];
for(auto j:v[u])
if(j==fa) continue;
mxlen=0;
get_dist(j,-1,1);
MUL(A,B,C,maxlen+1,mxlen+1);
for(int k=0;k<maxlen+2+mxlen;k++)
cnt[k]=(cnt[k]+C[k])%mod;
for(int k=0;k<=mxlen;k++)
A[k]=(A[k]+B[k])%mod;
B[k]=0;
maxlen=max(maxlen,mxlen);
for(int k=0;k<=maxlen;k++)
A[k]=0;
for(auto j:v[u]) dfs(j,u);
signed main()
int n;
cin>>n;
int sum=0;
for(int i=1;i<=n;i++) cin>>a[i],sum+=a[i],sum%=mod;
for(int i=1;i<=n;i++) a[i]=a[i] * qpow(sum,mod-2)%mod;
for(int i=0;i<=n-1;i++)
cin>>W[i];
for(int i=1;i<=n;i++)
ans=(ans+a[i]*a[i]%mod*W[0])%mod;
for(int i=0;i<n-1;i++)
int a,b;
cin>>a>>b;
v[a].push_back(b);
v[b].push_back(a);
dfs(1,-1);
for(int i=1;i<=n-1;i++)
ans=(ans+2 * cnt[i]*W[i])%mod;
cout<<ans<<endl;
return 0;
以上是关于NTT模板 神奇的迷宫 NTT加点分治的主要内容,如果未能解决你的问题,请参考以下文章