题意:给定一个序列,求多少个三元组满足ai+ak=2*aj(i<j<k)。
题解:原来叉姐的讲义上有啊。。完全忘掉了。。
首先这个式子很明显是一个卷积。我们有了FFT的思路。但是肯定不能每一个数都去做一次。那怎么办呢?我们分块吧!(分块大法好)
对于每一个块我们统计出前面块的桶,同理我们也有后面块的桶,两个桶FFT一下我们就得到了以这个块内元素为j,i和k分别在前面的块与后面的块的方案了。然后我们还要解决两个在一个块,三个在一个块的问题。两个在一个块的我们直接去前后的桶里找,同一个块的直接n*n暴力。然后就做完啦!好妙啊!
这题被坑了好久。。因为空间莫名其妙的问题怎么都算不对(块开极端都可以,就是开中间不行),然后一个下午没有了。。
#include<bits/stdc++.h>
using namespace std;
#define LL long long
#define N 205005
#define INF 1e9
#define Bl 70
#define LIM 60000
const double PI=acos(-); inline LL read(){
LL x=,f=; char a=getchar();
while(a<'' || a>'') {if(a=='-') f=-; a=getchar();}
while(a>='' && a<='') x=x*+a-'',a=getchar();
return x*f;
} namespace FFT{
int rev[N]; struct vec{
double r,i;
vec operator * (const vec& w){return (vec){r*w.r-i*w.i,i*w.r+r*w.i};}
vec operator + (const vec& w){return (vec){r+w.r,i+w.i};}
vec operator - (const vec& w){return (vec){r-w.r,i-w.i};}
}A[N],B[N]; inline void fft(vec* x,int len,int f){
for(int i=;i<=len;i++) if(i<rev[i]) swap(x[i],x[rev[i]]);
for(int lnow=;lnow<=len;lnow<<=){
vec w,w0=(vec){cos(2.0*PI/lnow*f),sin(2.0*PI/lnow*f)},t1,t2;
for(int i=;i<len;i+=lnow){
w=(vec){,};
for(int j=i;j<i+lnow/;j++){
t1=x[j]; t2=w*x[j+lnow/];
x[j]=t1+t2; x[j+lnow/]=t1-t2;
w=w*w0;
}
}
}
} inline void work(int a[],int b[],int l1,int l2,LL s[]){
int len,t;
for(len=,t=;len<=(l1+l2+);len<<=,t++); t=<<(t-);
for(int i=;i<=len;i++) rev[i]=rev[i>>]>>|(i&?t:);
for(int i=;i<=len;i++) B[i]=A[i]=(vec){,};
for(int i=;i<=l1;i++) A[i].r=a[i];
for(int i=;i<=l2;i++) B[i].r=b[i];
fft(A,len,); fft(B,len,);
for(int i=;i<=len;i++) A[i]=A[i]*B[i];
fft(A,len,-);
for(int i=;i<=l1+l2;i++)
s[i]=(LL)(1.0*A[i].r/len+0.5);
} } int n,block_size,block_num;
int bel[N],l[Bl+],r[Bl+],a[N];
LL tot,ans[*LIM+];
int lsum[LIM+],rsum[LIM+],cnt[*LIM+]; inline void brutal_force(int x){
for(int i=l[x];i<=r[x];i++) rsum[a[i]]--;
memset(ans,,sizeof(ans));
FFT::work(lsum,rsum,,,ans);
for(int i=l[x];i<=r[x];i++){
tot+=ans[*a[i]];
for(int j=l[x];j<i;j++)
if(*a[i]-a[j]>) tot+=rsum[*a[i]-a[j]];
for(int j=i+;j<=r[x];j++)
if(*a[i]-a[j]>) tot+=lsum[*a[i]-a[j]];
}
for(int i=l[x];i<=r[x];i++) lsum[a[i]]++;
memset(cnt,,sizeof(cnt));
for(int i=l[x];i<=r[x];i++){
tot+=cnt[a[i]];
for(int j=l[x];j<i;j++)
if(*a[i]-a[j]>) cnt[*a[i]-a[j]]++;
}
} int main(){
n=read(); block_size=;
block_num=(n-)/block_size+;
for(int i=;i<=n;i++) a[i]=read(),bel[i]=(i-)/block_size+;
for(int i=;i<=block_num;i++) l[i]=(i-)*block_size+,r[i]=min(n,i*block_size);
for(int i=;i<=n;i++) rsum[a[i]]++;
for(int i=;i<=block_num;i++) brutal_force(i);
printf("%lld\n",tot);
return ;
}