记本题数组长度为\(n\),权值大小为\(m\)。
首先,暴力显然是\(O(n^2)\)的。
先瞄一眼tag,然后发现这是FFT。
显然,问题的关键在于要满足i,j,k之间的位置关系。于是考虑分治FFT。但遗憾的是,我们的分治FFT是对权值进行多项式乘法的,分治并不能使得FFT的规模减小。因此,分治做法在复杂度上就是错误的。
然后考虑分块。以下记块大小为\(K\)。
考虑一下三种情况:
- i,j在同一块中,但k在另一块里。
- j,k在同一块中,但i在另一块里。
- i,j,k都在同一块中。
- i,j,k都不在同一块中。
对于前三种情况,维护从1到每个块末端的权值的前缀和,用暴力就能解决。时间复杂度均为\(O(\frac {n}{K} \times K^2) = O(nK)\)。
对于最后一种情况,我们枚举j在哪一块,然后用FFT生成所有满足i在左边的块里,k在右边的块里的\(a_i+a_k\)的个数,利用\(2a_j=a_i+a_k\)就能统计出答案。这个的时间复杂度是\(O(\frac {n}{K} \times mlogm) < O(\frac {n^2logn}{K})\)。
那么有\(\frac {n^2logn}{K} + nK >= 2n^{\frac{3}{2}}log^{\frac{1}{2}}n\)。即复杂度为\(O(n^{\frac{3}{2}}log^{\frac{1}{2}}n)\)。
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 200010, MAX = 200010;
typedef long double db;
const db pi = acos(-1);
struct cpl {
db x,y;
cpl(db x_=0,db y_=0): x(x_), y(y_) {};
cpl operator + (const cpl& a) const {
return cpl(x + a.x,y + a.y);
}
cpl operator - (const cpl& a) const {
return cpl(x - a.x,y - a.y);
}
cpl operator * (const cpl& a) const {
return cpl(x * a.x - y * a.y,x * a.y + y * a.x);
}
cpl operator * (const db& a) const {
return cpl(x * a,y * a);
}
};
cpl ta[MAX],tb[MAX];
int rev[MAX];
void prework(int n) {
rev[0] = 0;
for (int i = 1 ; i < n ; ++ i)
rev[i] = i&1 ? rev[i-1] | (n>>1) : rev[i>>1]>>1;
}
void fft(cpl* a,int n,int sgn) {
static cpl tmp[MAX];
for (int i = 0 ; i < n ; ++ i)
tmp[rev[i]] = a[i];
cpl wp,w,u,v;
for (int s = 2 ; s <= n ; s <<= 1) {
wp = cpl(cos(2 * pi / s),sin(2 * pi / s));
if (sgn) wp.y = -wp.y;
for (int k = 0 ; k < n ; k += s) {
w = cpl(1,0);
for (int j = 0 ; j < s/2 ; ++ j) {
u = tmp[k + j];
v = tmp[k + j + s/2] * w;
tmp[k + j] = u + v;
tmp[k + j + s/2] = u - v;
w = wp * w;
}
}
}
for (int i = 0 ; i < n ; ++ i)
a[i] = sgn ? tmp[i] * (1.0/n) : tmp[i];
}
const int SZ = 1500;
#define suit(x) ((x) >= 1 && (x) <= mx)
int bel[N],n,arr[N],tmp[N],cnt[N / SZ][MAX];
void solve() {
int mx = 0, ans = 0;
for (int i = 1 ; i <= n ; ++ i)
bel[i] = (i % SZ == 1 ? bel[i-1] + 1 : bel[i-1]);
for (int i = 1 ; i <= n ; ++ i) {
++ tmp[arr[i]];
mx = max(mx,arr[i]);
if (i % SZ == 0 || i == n) {
for (int j = 1 ; j <= mx ; ++ j)
cnt[bel[i]][j] = tmp[j];
}
}
int l = 1;
while (l < mx + mx + 1) l <<= 1;
prework(l);
for (int i = 2 ; i < bel[n] ; ++ i) {
for (int j = 0 ; j < l ; ++ j)
ta[j] = tb[j] = cpl();
for (int j = 1 ; j <= mx ; ++ j)
ta[j] = cpl(cnt[i-1][j],0);
for (int j = 1 ; j <= mx ; ++ j)
tb[j] = cpl(cnt[bel[n]][j] - cnt[i][j],0);
fft(ta,l,0);
fft(tb,l,0);
for (int j = 0 ; j < l ; ++ j)
ta[j] = ta[j] * tb[j];
fft(ta,l,1);
for (int j = 2 ; j <= mx * 2 ; j += 2)
tmp[j] = (int)(ta[j].x + 0.5);
for (int j = 1 ; j <= mx ; ++ j)
ans += tmp[j << 1] * (cnt[i][j] - cnt[i-1][j]);
}
for (int i = 1 ; i <= bel[n] ; ++ i) {
for (int j = (i-1) * SZ + 1 ; j <= i * SZ && j <= n ; ++ j)
for (int k = j + 1 ; k <= i * SZ && k <= n ; ++ k) {
if (suit(2 * arr[j] - arr[k]))
ans += cnt[i-1][2 * arr[j] - arr[k]];
if (suit(2 * arr[k] - arr[j]))
ans += cnt[bel[n]][2 * arr[k] - arr[j]] - cnt[i][2 * arr[k] - arr[j]];
}
}
memset(tmp,0,sizeof tmp);
for (int i = 1 ; i <= bel[n] ; ++ i) {
for (int j = (i-1) * SZ + 1 ; j <= i * SZ && j <= n ; ++ j) {
for (int k = j + 1 ; k <= i * SZ && k <= n ; ++ k) {
if ((!((arr[j] + arr[k]) & 1)) && suit((arr[j] + arr[k]) >> 1))
ans += tmp[(arr[j] + arr[k]) >> 1];
++ tmp[arr[k]];
}
for (int k = j + 1 ; k <= i * SZ && k <= n ; ++ k)
tmp[arr[k]] = 0;
}
}
printf("%lld\n",ans);
}
signed main() {
scanf("%lld",&n);
for (int i = 1 ; i <= n ; ++ i)
scanf("%lld",&arr[i]);
solve();
return 0;
}
小结:这种难以用分治减小规模的问题,不妨用分块来简化。