题目大意
给你三个$1$到$n$的排列$ai, bi, ci$
称三元组$(x, y, z)$是合法的,当且仅当存在一个下标集合$S \in [n]$满足
$(x, y, z) = (max_{i \in S}a_i, max_{i \in S}b_i, max_{i \in S}c_i)$
询问合法三元组的数量。
题解
神仙题.jpg
考虑只保留一个合法下标集合中最大值所处的下标,将这样的集合称作最简下标集合。
可以发现,一组下标集合对应了一个合法三元组。
于是,现在问题转换成了统计最简下标集合的个数。
可以发现,一个最简下标集合$s$的大小$|s| \leq 3$,因此可以对于每种大小分别统计。
对于$|s| == 1$的情况,显然答案就是$n$。
对于$|s| == 2$的情况,考虑容斥,统计不合法的大小为$2$的下标集合方案数,再用总数减去不合法的数量。
可以发现,一个下标集合不合法当且仅当其中一个下标在$a,b,c$三个值上都要大于另一个。
于是这就是个三维偏序了,排序+cdq+树状数组即可。
对于$|s| == 3$的情况,同样考虑容斥。
这时可以发现,不合法的状态有如下四种:
令$x,y,z$为一个下标集合中的三个下标,$A,B,C$为最大值所处的位置:
X A B C Y Z
X A B Y C Z
X A C Y B Z
X B C Y A Z
为了方便描述,将$1$定义为情况$A$,$2,3,4$统称为情况$B$。
对于统计情况$A$,可以发现这还是个三维偏序......
对于情况$B$,可以发现直接统计不太方便。
考虑在$a,b,c$三个数组中枚举其中两个,统计有多少下标三元组满足其中一个下标在被枚举的两个数组中的值均大于另外两个下标。
将枚举得到的答案定义为$C$,可以发现$A$在$C$中被计算了$3$次,而$B$恰好被计算了一次。
于是就有$3*A+B=C$。
于是就可以解出$B$的值了。
然后用所有三元组方案减去$A$和$B$即可得到$|s| == 3$时的答案。
于是三种答案累加即可得到最终答案~
代码:
#include<bits/stdc++.h>
using namespace std;
inline int read()
{
int x=0;char ch=getchar();
while(ch<'0' || '9'<ch)ch=getchar();
while('0'<=ch && ch<='9')x=x*10+(ch^48),ch=getchar();
return x;
}
typedef long long ll;
const int N=100009;
struct node
{
int a,b,c,id;
}p[N];
ll n;
int a[N],b[N],c[N];
ll ans[N];
inline bool cmpa(node a,node b){return a.a<b.a;}
inline bool cmpb(node a,node b){return a.b<b.b;}
namespace bit
{
ll bi[N];
inline void modify(int x,ll v)
{
for(int i=x;i<=n;i+=i&-i)
bi[i]+=v;
}
inline ll query(int x)
{
ll ret=0;
for(int i=x;i;i-=i&-i)
ret+=bi[i];
return ret;
}
}
inline void cdq(int l,int r)
{
if(l==r)return;
int mid=l+r>>1,pl=l;
cdq(l,mid);cdq(mid+1,r);
for(int pr=mid+1;pr<=r;pr++)
{
while(pl<=mid && p[pl].b<p[pr].b)
bit::modify(p[pl++].c,1);
ans[p[pr].id]+=bit::query(p[pr].c-1);
}
for(int i=l;i<pl;i++)
bit::modify(p[i].c,-1);
sort(p+l,p+r+1,cmpb);
}
namespace p2
{
inline ll run()
{
ll ret=n*(n-1)/2ll;
for(int i=1;i<=n;i++)
ret-=ans[i];
return ret;
}
}
namespace p3
{
inline ll run1()
{
ll anss=0;
for(int i=1;i<=n;i++)
anss+=ans[i]*(ans[i]-1)/2ll;
return anss;
}
inline ll run2_p(int *x,int *y)
{
for(int i=1;i<=n;i++)
p[i].a=x[i],p[i].b=y[i];
sort(p+1,p+n+1,cmpa);
ll anss=0,tmp;
for(int i=1;i<=n;i++)
{
tmp=bit::query(p[i].b);
anss+=(tmp-1)*tmp/2ll;
bit::modify(p[i].b,1);
}
memset(bit::bi,0,sizeof(bit::bi));
return anss;
}
ll run2()
{
ll ans=0;
ans+=run2_p(a,b);
ans+=run2_p(a,c);
ans+=run2_p(b,c);
return ans;
}
inline ll calc()
{
ll sa=run1();
ll sx=run2();
ll sb=sx-3ll*sa;
return n*(n-1ll)*(n-2ll)/6ll-sa-sb;
}
}
int main()
{
freopen("subset.in","r",stdin);
freopen("subset.out","w",stdout);
n=read();
for(int i=1;i<=n;i++)a[i]=read();
for(int i=1;i<=n;i++)b[i]=read();
for(int i=1;i<=n;i++)c[i]=read();
for(int i=1;i<=n;i++)
p[i]=(node){a[i],b[i],c[i],i};
sort(p+1,p+n+1,cmpa);
cdq(1,n);
printf("%lld\n",n+p2::run()+p3::calc());
return 0;
}