HDU 4638 树状数组 想法题

时间:2023-03-09 19:22:51
HDU 4638 树状数组 想法题

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4638

解题思路:

题意为询问一段区间里的数能组成多少段连续的数。先考虑从左往右一个数一个数添加,考虑当前添加了i - 1个数的答案是x,那么添加完i个数后的答案是多少?可以看出,是根据a[i]-1和a[i]+1是否已经添加而定的,如果a[i]-1或者a[i]+1已经添加一个,则段数不变,如果都没添加则段数加1,如果都添加了则段数减1。设v[i]为加入第i个数后的改变量,那么加到第x数时的段数就是sum{v[i]} (1<=i<=x}。仔细想想,若删除某个数,那么这个数两端的数的改变量也会跟着改变,这样一段区间的数构成的段数就还是他们的v值的和。将询问离线处理,按左端点排序后扫描一遍,左边删除,右边插入,查询就是求区间和。

以上摘自杭电的解题报告

标程:标程是从左边插入,从右边删除

 #include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
using namespace std;
const int N = ;
struct Q
{
int l,r,id;
}query[N];
bool cmp(Q x,Q y)
{
return x.r>y.r;
}
int a[N],c[N],d[N];
int lowbit(int x)
{
return x&(-x);
}
void up(int x,int v)
{
while(x<N)
{
c[x]+=v;
x+=lowbit(x);
}
}
int getsum(int x)
{
int r=;
while(x>)
{
r+=c[x];
x-=lowbit(x);
}
return r;
}
bool u[N];
int ret[N],ps[N];
int main()
{
int T,n,m,i,j;
scanf("%d",&T);
while(T--)
{
scanf("%d%d",&n,&m);
for(i=;i<=n;i++)
{
scanf("%d",&a[i]);
ps[a[i]]=i;
}
memset(c,,sizeof(c));
memset(d,,sizeof(d));
memset(u,,sizeof(u));
for(i=n;i>;i--)
{
if(u[a[i]-])d[i]++;
if(u[a[i]+])d[i]++;
if(d[i]==)
{
up(i,);
}
else if(d[i]==)
{
up(i,-);
}
u[a[i]]=true;
}
for(i=;i<m;i++)
{
scanf("%d%d",&query[i].l,&query[i].r);
query[i].id=i;
}
sort(query,query+m,cmp);
int j=n;
for(i=;i<m;i++)
{
while(j>query[i].r)
{
if(a[j]>&&ps[a[j]-]<j)
{
d[ps[a[j]-]]--;
up(ps[a[j]-],);
}
if(a[j]<n&&ps[a[j]+]<j)
{
d[ps[a[j]+]]--;
up(ps[a[j]+],);
}
j--;
}
ret[query[i].id]=getsum(query[i].r)-getsum(query[i].l-);
}
for(i=;i<m;i++)printf("%d\n",ret[i]);
}
return ;
}

我觉得看完解题思路后我还是不太明白,就模拟了一下样例。

3 1 2 5 4

得到的数组v 为 1 1 -1 1 0

这时查询任意1-k区间的段数=sum(vi)(1=<i<=k),都是对的···神奇···

在删除3后,3-1所在的位置为3,即v[3]=v[3]+1,3+1所在的位置为5。++v[5]。

v数组变为 1  1 0 1 0 ,查询2-i区间的段数=sum(vi)(2=<i<=k),同样是对的···更神奇···有兴趣的可以自己证明

我的代码:

 #include <cstdio>
#include <algorithm>
#define N 100005
using namespace std;
struct Node
{
int l,r,index;
} p[N];
int a[N],rank[N],ans[N],c[N];
int n,m; bool cmp(Node a,Node b)
{
return a.l < b.l;
}
int lowbit(int x)
{
return x&(-x);
}
void add(int x,int v)
{
while(x<N)
{
c[x] += v;
x += lowbit(x);
}
}
int sum(int x)
{
int s =;
while(x)
{
s += c[x];
x -= lowbit(x);
}
return s;
} void init()
{
scanf("%d%d",&n,&m);
rank[] =rank[n+] = n+;
for(int i=; i<=n; ++i)
{
scanf("%d",&a[i]);
rank[a[i]] = i;
}
for(int i=; i<m; ++i)
{
scanf("%d%d",&p[i].l,&p[i].r);
p[i].index = i;
}
}
int main()
{
// freopen("in.cpp","r",stdin);
int t;
scanf("%d",&t);
while(t--)
{
init();
memset(c, , sizeof(c) );
for(int i=; i<=n; ++i)
{
int d =;
if(rank[a[i]+] < i ) ++d;
if(rank[a[i]-] < i ) ++d;
if(d == )
add(i,);
else if(d == )
add(i,-);
}
sort(p,p+m,cmp);
rank[] = rank[n+] = ;
int j=;
for(int i=; i<m; ++i)
{
int t = p[i].l;
while(j < t)
{
if(rank[a[j]+] > j)
add(rank[a[j]+],);
if(rank[a[j]-] > j)
add(rank[a[j]-],);
++j;
}
ans[p[i].index] = sum(p[i].r)-sum(p[i].l-);
}
for(int i=; i<m; ++i)
printf("%d\n",ans[i]);
}
return ;
}