http://www.lydsy.com/JudgeOnline/problem.php?id=5024
首先吐槽一下题面是有错误的
那一个"或" 应该改成","
这道题目条件是非常绕的
先看一个简化的问题
我们对于\(u\), \(v\)
如果\(u-v\)是良的 那么染成白色
否则染成黑色
问题转化成一个经典问题
给定一个完全图 求同色三角形的个数
做法就是考虑非同色三角形的个数
找角的个数\(jiao\) 答案就是\(C_n ^ 3 - jiao / 2\)
考虑如何在\(O(distance(u, v))\) 判断是否\(u - v\)是否是良的
观察计算\(P\)的方式
由于是末尾是\(011\)
可以考虑在模\(8\)的意义下来计算
于是$P = 1 \times s_1 + 5 \times s_2 + 1 \times s_3 ... $
此时考虑\(1 = 001_2\), \(5 = 101_2\)
末尾的\(01\)是相同的
于是我们只需要考虑倒数第三位是否是\(1\)或者是\(0\)
倒数第三位的贡献是由谁提供的呢?
分成两部分
- 所有数的和 的倒数第三位
- 偶数位的数 的倒数第一位
由于第一部分贡献已知
只需要看第二部分是否满足
问题变成
考虑一个\(01\)串 可以从头尾取
是否能满足偶数位为\(0\)为\(1\)
如果全\(0\)全\(1\)那么答案固定
否则\(01\)都是满足的
证明的话考虑两头类似于括号序列的消去
到现在为止
我们可以在\(O(distance(u, v))\) 判断是否\(u - v\)是否是良的
于是就可以树形dp了啊
记录状态\(f[sum][dep][k]\) 表示子树的答案
\(sum\)表示路径的和 模 \(8\)
\(dep\)表示最低位的和 模 \(4\)
\(k\) 表示 \(all_0 / all_1 / others\)
然后我们就求出了一个点为根的答案
然后换根 dp 一下就好了
复杂度\(O(n)\)
以下代码bzoj不能过 爆栈了.. xjoi可以过..
不会写手工栈.jpg
#pragma GCC optimize(2)
#pragma comment(linker, "/STACK:2048000000,2048000000")
#include<bits/stdc++.h>
#define int long long
#define fo(i, n) for(int i = 1; i <= (n); i ++)
#define out(x) cerr << #x << " = " << x << "\n"
#define type(x) __typeof((x).begin())
#define foreach(it, x) for(type(x) it = (x).begin(); it != (x).end(); ++ it)
using namespace std;
// by piano
template<typename tp> inline void read(tp &x) {
x = 0; char c = getchar(); bool f = 0;
for(; c < '0' || c > '9'; f |= (c == '-'), c = getchar());
for(; c >= '0' && c <= '9'; x = (x << 3) + (x << 1) + c - '0', c = getchar());
if(f) x = -x;
}
template<typename tp> inline void arr(tp *a, int n) {
for(int i = 1; i <= n; i ++)
cout << a[i] << " ";
puts("");
}
const int N = 3e5 + 233;
struct E {
int nxt, to;
}e[N << 1];
int head[N], e_cnt = 0;
inline void add(int u, int v) {
e[++ e_cnt] = (E) {head[u], v}; head[u] = e_cnt;
}
struct Node {
// sum % 8, dep % 4, all_0 / all_1 / others
int f[8][4][3];
inline void clear(void) {
memset(f, 0, sizeof f);
}
inline void init(int val, int fff) {
int sum = val % 8, dep = val & 1;
int k = dep;
f[sum][dep][k] += fff;
}
inline void ovo(void) {
for(int sum = 0; sum < 8; sum ++)
for(int dep = 0; dep < 4; dep ++)
for(int k = 0; k < 3; k ++)
if(f[sum][dep][k])
printf("f[%lld][%lld][%lld] = %lld\n", sum, dep, k, f[sum][dep][k]);
}
}p[N], tmp;
int n, fat, val[N], ans[N];
inline void U(Node &a, Node b, int fff) {
for(int sum = 0; sum < 8; sum ++)
for(int dep = 0; dep < 4; dep ++)
for(int k = 0; k < 3; k ++)
a.f[sum][dep][k] += b.f[sum][dep][k] * fff;
}
inline void Get(Node u, int val) {
tmp.clear();
for(int sum = 0; sum < 8; sum ++) {
for(int dep = 0; dep < 4; dep ++) {
int ns = (val % 8 + sum) % 8;
int nd = ((val & 1) + dep) % 4;
int ok = val & 1;
if(ns >= 8 || nd >= 4) while(1);
for(int k = 0; k < 3; k ++) {
if(ok == k)
tmp.f[ns][nd][k] += u.f[sum][dep][k];
else
tmp.f[ns][nd][2] += u.f[sum][dep][k];
}
}
}
}
inline int Getans(const Node &u) {
int ans = 0;
for(int dep = 0; dep < 4; dep ++)
ans += u.f[3][dep][0] + u.f[3][dep][2];
for(int dep = 0; dep < 4; dep ++)
if((dep / 2) % 2 == 0)
ans += u.f[3][dep][1];
for(int dep = 0; dep < 4; dep ++) {
if((dep / 2) % 2 == 1)
ans += u.f[7][dep][1];
ans += u.f[7][dep][2];
}
return ans * (n - 1 - ans);
}
inline void dfs(int u, int fat) {
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v != fat) dfs(v, u), Get(p[v], val[u]), U(p[u], tmp, 1);
}
p[u].init(val[u], 1);
}
inline void frt(int u, int fat) {
p[u].init(val[u], -1);
ans[u] = Getans(p[u]);
p[u].init(val[u], 1);
for(int i = head[u]; i; i = e[i].nxt) {
int v = e[i].to;
if(v != fat) {
Get(p[v], val[u]); U(p[u], tmp, -1);
Get(p[u], val[v]); U(p[v], tmp, 1);
frt(v, u);
Get(p[u], val[v]); U(p[v], tmp, -1);
Get(p[v], val[u]); U(p[u], tmp, 1);
}
}
}
main(void) {
read(n);
for(int i = 1; i <= n; i ++) {
read(fat); read(val[i]); val[i] &= 7;
add(i, fat); add(fat, i);
}
dfs(1, 0); frt(1, 0);
int res = 0;
fo(i, n) res += ans[i];
res /= 2;
cout << n * (n - 1) * (n - 2) / 6 - res << "\n";
}