[BZOJ3489]A simple rmq problem(kd-tree)

时间:2022-12-17 17:03:11

题目描述

传送门

题解

听说这道题是可以写主席树套树的(ATP%%%)
我的做法是三维kd-tree,分别是:这个点的位置,上一个相同数字的位置,下一个相同数字的位置
然后每一个有一个点权即为这一位上的数字,然后对于每一个子树维护位置的最大值和最小值,以及前面的最小值和后面的最大值,这样来判断、暴力即可
我刚开始的时候强行把点权塞到一维里去,然后吃惊地发现比3d要快一些,大概是因为某些时候按照点权这一维度划分能一次卡掉更多的点

卡常技巧:能不维护的就不维护

代码

3d

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
using namespace std;
#define N 100005

int n,m,cmpd,root,l,r,ans;
int a[N],head[N],nxt[N],tail[N],pre[N];
struct data
{
    int l,r;
    int d[3];
    int val,Max,mn[3],mx[3];
}tr[N];

int cmp(data a,data b)
{
    return a.d[cmpd]<b.d[cmpd];
}
void update(int now)
{
    int l=tr[now].l,r=tr[now].r;
    if (l) tr[now].Max=max(tr[now].Max,tr[l].Max);
    if (r) tr[now].Max=max(tr[now].Max,tr[r].Max);
    for (int i=0;i<3;++i)
    {
        if (l) tr[now].mx[i]=max(tr[now].mx[i],tr[l].mx[i]),tr[now].mn[i]=min(tr[now].mn[i],tr[l].mn[i]);
        if (r) tr[now].mx[i]=max(tr[now].mx[i],tr[r].mx[i]),tr[now].mn[i]=min(tr[now].mn[i],tr[r].mn[i]);
    }
}
int build(int l,int r,int d)
{
    if (d==3) d=0;
    int mid=(l+r)>>1;
    cmpd=d;
    nth_element(tr+l,tr+mid,tr+r+1,cmp);
    for (int i=0;i<3;++i)
        tr[mid].mx[i]=tr[mid].mn[i]=tr[mid].d[i];
    if (l<mid) tr[mid].l=build(l,mid-1,d+1);
    if (mid<r) tr[mid].r=build(mid+1,r,d+1);
    update(mid);
    return mid;
}
int calc(int now)
{
    if (tr[now].mx[0]<l||tr[now].mn[0]>r||tr[now].mn[1]>=l||tr[now].mx[2]<=r) return -1;
    return tr[now].Max;
}
void query(int now)
{
    if (tr[now].val>ans&&tr[now].d[0]>=l&&tr[now].d[0]<=r&&tr[now].d[1]<l&&tr[now].d[2]>r)
        ans=tr[now].val;
    int dl,dr;
    if (tr[now].l) dl=calc(tr[now].l);
    else dl=-1;
    if (tr[now].r) dr=calc(tr[now].r);
    else dr=-1;
    if (dl>dr)
    {
        if (dl>ans) query(tr[now].l);
        if (dr>ans) query(tr[now].r);
    }
    else
    {
        if (dr>ans) query(tr[now].r);
        if (dl>ans) query(tr[now].l);
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;++i) scanf("%d",&a[i]);
    for (int i=1;i<=n;++i) head[i]=n+1;
    for (int i=1;i<=n;++i) pre[i]=tail[a[i]],tail[a[i]]=i;
    for (int i=n;i>=1;--i) nxt[i]=head[a[i]],head[a[i]]=i;
    for (int i=1;i<=n;++i)
    {
        tr[i].d[0]=i,tr[i].d[1]=pre[i],tr[i].d[2]=nxt[i];
        tr[i].val=tr[i].Max=a[i];
    }
    root=build(1,n,0);
    for (int i=1;i<=m;++i)
    {
        scanf("%d%d",&l,&r);
        l=(l+ans)%n+1;r=(r+ans)%n+1;
        if (l>r) swap(l,r);
        ans=0;query(root);
        printf("%d\n",ans);
    }
}

4d

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
using namespace std;
#define N 100005

int n,m,cmpd,root,l,r,ans;
int a[N],head[N],nxt[N],tail[N],pre[N];
struct data
{
    int l,r;
    int d[4];
    int val,mn,mx,pre,nxt;
}tr[N];

int cmp(data a,data b)
{
    return a.d[cmpd]<b.d[cmpd];
}
void update(int now)
{
    int l=tr[now].l,r=tr[now].r;
    if (l)
    {
        tr[now].val=max(tr[now].val,tr[l].val);
        tr[now].mx=max(tr[now].mx,tr[l].mx);
        tr[now].mn=min(tr[now].mn,tr[l].mn);
        tr[now].pre=min(tr[now].pre,tr[l].pre);
        tr[now].nxt=max(tr[now].nxt,tr[l].nxt);
    }
    if (r)
    {
        tr[now].val=max(tr[now].val,tr[r].val);
        tr[now].mx=max(tr[now].mx,tr[r].mx);
        tr[now].mn=min(tr[now].mn,tr[r].mn);
        tr[now].pre=min(tr[now].pre,tr[r].pre);
        tr[now].nxt=max(tr[now].nxt,tr[r].nxt);
    }
}
int build(int l,int r,int d)
{
    if (d==4) d=0;
    int mid=(l+r)>>1;
    cmpd=d;
    nth_element(tr+l,tr+mid,tr+r+1,cmp);
    tr[mid].val=tr[mid].d[0];
    tr[mid].mx=tr[mid].mn=tr[mid].d[1];
    tr[mid].pre=tr[mid].d[2];
    tr[mid].nxt=tr[mid].d[3];
    if (l<mid) tr[mid].l=build(l,mid-1,d+1);
    if (mid<r) tr[mid].r=build(mid+1,r,d+1);
    update(mid);
    return mid;
}
int calc(int now)
{
    if (tr[now].mx<l||tr[now].mn>r||tr[now].pre>=l||tr[now].nxt<=r) return -1;
    return tr[now].val;
}
void query(int now)
{
    if (tr[now].d[0]>ans&&tr[now].d[1]>=l&&tr[now].d[1]<=r&&tr[now].d[2]<l&&tr[now].d[3]>r)
        ans=tr[now].d[0];
    int dl,dr;
    if (tr[now].l) dl=calc(tr[now].l);
    else dl=-1;
    if (tr[now].r) dr=calc(tr[now].r);
    else dr=-1;
    if (dl>dr)
    {
        if (dl>ans) query(tr[now].l);
        if (dr>ans) query(tr[now].r);
    }
    else
    {
        if (dr>ans) query(tr[now].r);
        if (dl>ans) query(tr[now].l);
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;++i) scanf("%d",&a[i]);
    for (int i=1;i<=n;++i) head[i]=n+1;
    for (int i=1;i<=n;++i) pre[i]=tail[a[i]],tail[a[i]]=i;
    for (int i=n;i>=1;--i) nxt[i]=head[a[i]],head[a[i]]=i;
    for (int i=1;i<=n;++i)
        tr[i].d[0]=a[i],tr[i].d[1]=i,tr[i].d[2]=pre[i],tr[i].d[3]=nxt[i];
    root=build(1,n,0);ans=0;
    for (int i=1;i<=m;++i)
    {
        scanf("%d%d",&l,&r);
        l=(l+ans)%n+1;r=(r+ans)%n+1;
        if (l>r) swap(l,r);
        ans=0;query(root);
        printf("%d\n",ans);
    }
}