https://www.codechef.com/status/COUNTARI
题意:
给出n个数,求满足i<j<k且a[j]-a[i]==a[j]-a[k] 的三元组(i,j,k)的个数
n^2 做法:
枚举j和k,当j右移时,令sum[num[右移之前j的值]]++
每次统计sum[num[j]*2-num[k]]即可
如果没有i<j<k,直接上FFT
但是有了这个限制,可以枚举j,再FFT,复杂度为n*n*log(30000)
考虑一次FFT只算1个j有点儿浪费
能不能算好几个j?
分块!
设每一块的大小为S
答案分三种:
一、3个数都在一个块
用平方复杂度的做法,枚举同一块内的j和k,总时间复杂度为O(n/S*S*S)=O(n*S)
二、2个数在两个块
如果在同一块的数是j和k,从第一块开始枚举j和k,记录前面块的sum,累加sum[num[j]*2-num[k]]
如果在同一块的数是i和j,从最后一块开始枚举i和j,记录后面块的sum,累加sum[num[j]*2-num[i]]
总时间复杂度为O(n/S*S*S)=O(n*S)
三、3个数在三个块
枚举中间的的那一块,sumL记录这个块左边所有数,sumR记录这个块右边所有数
用FFT对sumL和sumR做一次卷积,得到sum
枚举中间那一块的每个数j,累加sum[num[j]*2]
FFT的一个小细节:
不能出现次数为0的项,所以所有数向左移一位,所以最后得到的sum向左移了两位,实际累加sum[num[j]*2-2]
#include<cmath> #include<cstdio> #include<cstring> #include<iostream> #include<algorithm> using namespace std; #define N 100001 #define M 30001 #define S 500 const int K=(1<<16)+2; typedef long long LL; const double pi=acos(-1); int n,mx,a[N]; int l[M],r[M]; int len=1,rev[K]; struct Complex { double x,y; Complex(double x_=0,double y_=0):x(x_),y(y_){} Complex operator + (Complex P) { return Complex(x+P.x,y+P.y); } Complex operator - (Complex P) { return Complex(x-P.x,y-P.y); } Complex operator * (Complex P) { return Complex(x*P.x-y*P.y,x*P.y+y*P.x); } }; typedef Complex E; E A[K],B[K]; LL ans; void read(int &x) { x=0; char c=getchar(); while(!isdigit(c)) c=getchar(); while(isdigit(c)) { x=x*10+c-‘0‘; c=getchar(); } } void fft(E *a,int ty) { for(int i=0;i<len;++i) if(i<rev[i]) swap(a[i],a[rev[i]]); for(int i=1;i<len;i<<=1) { E wn(cos(pi/i),ty*sin(pi/i)); for(int p=i<<1,j=0;j<len;j+=p) { E w(1,0); for(int k=0;k<i;++k,w=w*wn) { E x=a[j+k],y=a[j+k+i]*w; a[j+k]=x+y; a[j+k+i]=x-y; } } } if(ty==-1) { for(int i=0;i<len;++i) a[i].x=a[i].x/len+0.5; } } void three() { int num=mx*2-2,bit=0; while(len<=num) len<<=1,bit++; for(int i=0;i<len;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1); for(int i=1;i<=n;++i) r[a[i]]++; int ed; for(int t=1;t<=n;t+=S) { ed=min(n,t+S-1); for(int i=t;i<=ed;++i) r[a[i]]--; for(int i=0;i<mx;++i) A[i].x=l[i+1],A[i].y=0; for(int i=mx;i<len;++i) A[i].x=A[i].y=0; fft(A,1); for(int i=0;i<mx;++i) B[i].x=r[i+1],B[i].y=0; for(int i=mx;i<len;++i) B[i].x=B[i].y=0; fft(B,1); for(int i=0;i<len;++i) A[i]=A[i]*B[i]; fft(A,-1); for(int i=t;i<=ed;++i) ans+=A[(a[i]<<1)-2].x; for(int i=t;i<=ed;++i) l[a[i]]++; } memset(l,0,sizeof(l)); } void two() { int ed; for(int t=1;t<=n;t+=S) { ed=min(n,t+S-1); for(int j=t;j<ed;++j) for(int k=j+1;k<=ed;++k) if(a[j]<<1>a[k] && (a[j]<<1)-a[k]<=mx) ans+=l[(a[j]<<1)-a[k]]; for(int i=t;i<=ed;++i) l[a[i]]++; } memset(l,0,sizeof(l)); int t=0,st; while(t<n) t+=S; t-=S; for(int i=t+1;i<=n;++i) r[a[i]]++; for(;t>0;t-=S) { st=t-S+1; for(int i=st;i<t;++i) for(int j=i+1;j<=t;++j) if(a[j]<<1>a[i] && (a[j]<<1)-a[i]<=mx) ans+=r[(a[j]<<1)-a[i]]; for(int i=st;i<=t;++i) r[a[i]]++; } } void one() { int ed; for(int t=1;t<=n;t+=S) { ed=min(t+S-1,n); for(int j=t;j<=ed;++j) { for(int k=j+1;k<=ed;++k) if(a[j]<<1>a[k] && (a[j]<<1)-a[k]<=mx) ans+=l[(a[j]<<1)-a[k]]; l[a[j]]++; } for(int j=t;j<=ed;++j) l[a[j]]--; } } int main() { read(n); for(int i=1;i<=n;++i) read(a[i]),mx=max(mx,a[i]); three(); two(); one(); cout<<ans; }