splay板子

时间:2022-02-13 01:20:18

1, splay的一些基本操作.

  • 使用前要插入$-INF,+INF$保证每个点的前驱后继存在.
  • $get$函数在$x$存在时, 调用后, 根为$x$, 否则根为$x$的前驱或后继
const int N = 1e6+10;
int n, tot, rt, sz;
struct {
int cnt,sz,fa,ch[2],v;
} tr[N];
void pu(int x) {
tr[x].sz=tr[tr[x].ch[0]].sz+tr[tr[x].ch[1]].sz+tr[x].cnt;
}
void rot(int x) {
int y=tr[x].fa,z=tr[y].fa;
int f=tr[y].ch[1]==x;
tr[z].ch[tr[z].ch[1]==y]=x,tr[x].fa=z;
tr[y].ch[f]=tr[x].ch[f^1],tr[tr[x].ch[f^1]].fa=y;
tr[x].ch[f^1]=y,tr[y].fa=x,pu(y);
}
void splay(int x, int s=0) {
for (int y; y=tr[x].fa,y!=s; rot(x)) if (tr[y].fa!=s) {
rot((tr[y].ch[0]==x)==(tr[tr[y].fa].ch[0]==y)?y:x);
}
if (!s) rt=x;
}
void get(int x) {
int cur=rt;
while (x!=tr[cur].v&&tr[cur].ch[x>tr[cur].v]) cur=tr[cur].ch[x>tr[cur].v];
splay(cur);
}
void insert(int x) {
int cur=rt,p=0;
while (cur&&x!=tr[cur].v) p=cur,cur=tr[cur].ch[x>tr[cur].v];
if (cur) ++tr[cur].cnt;
else {
cur=++tot;
if (p) tr[p].ch[x>tr[p].v]=cur,tr[cur].fa=p;
tr[cur].v=x,tr[cur].sz=tr[cur].cnt=1;
}
splay(cur);
}
int pre(int x) {
get(x);
if (tr[rt].v<=x) return rt;
int cur=tr[rt].ch[0];
while (tr[cur].ch[1]) cur=tr[cur].ch[1];
return cur;
}
int nxt(int x) {
get(x);
if (tr[rt].v>=x) return rt;
int cur=tr[rt].ch[1];
while (tr[cur].ch[0]) cur=tr[cur].ch[0];
return cur;
}
void erase(int x) {
int s1=pre(x-1),s2=nxt(x+1);
splay(s1),splay(s2,s1);
int &cur=tr[s2].ch[0];
if (tr[cur].cnt>1) --tr[cur].cnt,splay(cur);
else cur=0;
}

2, splay插入区间,区间翻转等操作.

这时候splay维护的是每个下标对应的权值, 下标通过第k大来查询

  • 使用前要调用$build(a,0,rt,1,2);$
const int N = 1e6+10;
int n, rt, tot;
int a[N];
struct _ {
int sz,v,ch[2],fa,rev;
} tr[N];
void pu(int o) {
tr[o].sz=tr[tr[o].ch[0]].sz+tr[tr[o].ch[1]].sz+1;
}
void pd(int o) {
if (tr[o].rev) {
swap(tr[o].ch[0],tr[o].ch[1]);
tr[tr[o].ch[0]].rev^=1;
tr[tr[o].ch[1]].rev^=1;
tr[o].rev=0;
}
}
void rot(int x) {
int y=tr[x].fa,z=tr[y].fa;
int f=tr[y].ch[1]==x;
tr[z].ch[tr[z].ch[1]==y]=x,tr[x].fa=z;
tr[y].ch[f]=tr[x].ch[f^1],tr[tr[x].ch[f^1]].fa=y;
tr[x].ch[f^1]=y,tr[y].fa=x,pu(y);
}
void splay(int x, int s=0) {
for (int y; y=tr[x].fa,y!=s; rot(x)) if (tr[y].fa!=s) {
rot((tr[y].ch[0]==x)==(tr[tr[y].fa].ch[0]==y)?y:x);
}
if (!s) rt=x;
}
int find(int x, int k) {
pd(x); int s=tr[tr[x].ch[0]].sz;
if (k==s+1) return x;
if (k<=s) return find(tr[x].ch[0],k);
return find(tr[x].ch[1],k-s-1);
}
void build(int *a, int f, int &o, int l, int r) {
if (l>r) return;
o = ++tot;
tr[o].v = a[mid], tr[o].fa = f;
build(s,o,tr[o].ch[0],l,mid-1);
build(s,o,tr[o].ch[1],mid+1,r);
pu(o);
}
void ins(int x, int n) {
build(a,0,p,1,n);
int s1=find(rt,x-1), s2=find(rt,x);
splay(s1),splay(s2,s1);
tr[s2].ch[0]=p,tr[p].fa=s2;
pu(p),pu(s2);
}
void del(int x, int n) {
int s1=find(rt,x-1), s2=find(rt,x+n);
splay(s1),splay(s2,s1);
tr[s2].ch[0]=0;
pu(s1),pu(s2);
}
void reverse(int x, int n) {
int s1=find(rt,x-1), s2=find(rt,x+n);
splay(s1),splay(s2,s1);
tr[tr[s2].ch[0]].rev^=1;
}