主席树(可持久化线段树)

时间:2024-01-23 21:52:03

主席树

前言

主席树也是一种数据结构,是线段树的进阶,作用是可以保留历史版本,支持回退,就是回到之前某个未修改的状态。就像我们在写博客时如果误删了重要的内容,就可以用 Ctrl + z 撤回一次操作,也可以用Ctrl + Shift +z 还原我们撤回的操作,这就需要存下每次编辑的操作。

基本原理

可持久化线段树,顾名思义,是可持久的线段树(好像是废话)。那么问题就在于怎么去可持久化,支持访问历史版本。
最暴力的,直接复制整棵树,再对新的这棵树进行修改。但这样的话,节点数为 \(n\),操作数为 \(m\),时间复杂就是 \(O(nm)\)img
好了,直接 TLE。显然复制这棵树会有非常多的浪费,思考一下,我们只需要复制修改后发生改变的点就好了。

img
img

我们要复制的只是蓝色的点其他的保持不变,要在蓝色的点边上复制一个一模一样的点然后修改。如图:
img
红色的点即为修改后的点,继承原节点的所有信息,包括未修改的儿子。这样我们只需要存下每次修改后的根,就能访问该版本的树。修改的时间复杂度为 \(O(\log{n})\)

操作

前置操作

定义节点,存 5 个信息:左右儿子,左右端点,权值。

#define lc a[u].l
#define rc a[u].r
#define L a[u].ls
#define R a[u].rs
struct tree
{
    int l,r;
    int ls,rs;
    int v;
}a[N];

新建节点

int tot;
int New(int u)
{
    a[++tot]=a[u];//结构体直接复制所有信息
    return tot;
}

建树

和线段树类似,不过节点编号依次增加,但左右儿子不一定是乘二和乘二加一了。

int build(int u,int l,int r)
{
    u=++tot;
    L=l,R=r;//更新区间端点
    if(l==r) 
    {
        a[u].v=w[l];
        return u;
    }
    int mid=(l+r)>>1;
    lc=build(lc,l,mid);
    rc=build(rc,mid+1,r);
    return u;
}

修改

先新建节点,然后在新建的树中修改,要记得更新修改了的儿子。

int modify(int u,int x,int v)
{
    u=New(u);
    if(L==R) a[u].v=v;
    else 
    {
        int mid=(L+R)>>1;
        if(x<=mid) lc=modify(lc,x,v);
        else rc=modify(rc,x,v);
    }
    return u;
}

查询

和线段树一毛一样的。

int query(int u,int x)
{
    if(L==R) return a[u].v;
    else 
    {
        int mid=(L+R)>>1;
        if(x<=mid) return query(lc,x);
        else return query(rc,x);
    }
}

存根

每次修改后最终返回的是新的根节点,所以直接用数组存就好了。

P3919 【模板】可持久化线段树 1(可持久化数组)

P3919 【模板】可持久化线段树 1(可持久化数组)

code

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define lc a[u].l
#define rc a[u].r
#define L a[u].ls
#define R a[u].rs
const int M=3e7+5,N=1e6+5;

struct tree
{
    int l,r;
    int ls,rs;
    int v;
}a[M];

int n,tot,w[N],rt,root[N],m;
int New(int u)
{
    a[++tot]=a[u];//结构体直接复制所有信息
    return tot;
}
int build(int u,int l,int r)
{
    u=++tot;
    L=l,R=r;//更新区间端点
    if(l==r) 
    {
        a[u].v=w[l];
        return u;
    }
    int mid=(l+r)>>1;
    lc=build(lc,l,mid);
    rc=build(rc,mid+1,r);
    return u;
}
int modify(int u,int x,int v)
{
    u=New(u);
    if(L==R) a[u].v=v;
    else 
    {
        int mid=(L+R)>>1;
        if(x<=mid) lc=modify(lc,x,v);
        else rc=modify(rc,x,v);
    }
    return u;
}
int query(int u,int x)
{
    if(L==R) return a[u].v;
    else 
    {
        int mid=(L+R)>>1;
        if(x<=mid) return query(lc,x);
        else return query(rc,x);
    }
}
int main ()
{
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    cin>>n>>m;
    for(int i=1;i<=n;i++) cin>>w[i];
    root[0]=build(0,1,n);
    rt=0;
    for(int i=1,op,x,y;i<=m;i++)
    {
        cin>>rt>>op>>x;
        if(op==1) cin>>y,root[i]=modify(root[rt],x,y);
        else cout<<query(root[rt],x)<<"\n",root[i]=root[rt];
    }
    return 0;
}

要注意节点实际个数。

P3834 【模板】可持久化线段树 2

P3834 【模板】可持久化线段树 2

分析

首先,因为数据跨度太大,所以先离散化。嗯,不会?没事,看 离散化

然后,基于数值建立线段树,即权值线段树,维护一段值域内的数的个数。权值线段树求整个区间的第 \(k\) 小的数,就是在权值线段树上二分,从根节点开始,如果 \(k\) 比右儿子大,就说明第 \(k\) 小的数在右儿子所表示的值域中。接着,\(k\) 要减去左儿子,再进入右儿子中继续;如果 \(k\) 比左儿子小,就直接进入左儿子。

那么好,子区间 \([l,r]\) 的第 \(k\) 小的数怎么求呢?(区间都是离散化好了的)
首先思考求 \([1,r]\) 的第 \(k\) 小的数,用主席树,依次加入 \([1,r]\) 的值的节点,生成 \(r\) 个版本,从 \(root[r]\) 开始找就可以了。
那么右边有限制,左边也有限制,该怎么办呢?先明确一点,主席树,每次修改后,对于每个根节点对应的每一棵树,结构都是相同的,就是拓扑排序相同。
也就是说,所有节点对应的区间都是一一对应的。在根节点为 \(root[R]\) 的树中值域为 \([l,r]\) 的个数为 \(cnt1\),在根节点为 \(root[L-1]\) 的树中值域为 \([l,r]\) 的个数为 \(cnt2\),类似于前缀和的思想,\([L,R]\)\([l,r]\) 的个数为 \(cnt1-cnt2\)
这样就可以了,只需要在递归的时候统计个数,找到对应的那个数输出即可。

code

#include <bits/stdc++.h>
using namespace std;
#define lc a[u].l
#define rc a[u].r
#define L a[u].ls
#define R a[u].rs
const int N=2e5+5,M=3e7+5;

struct tree
{
    int l,r;
    int ls,rs;
    int cnt;
}a[M];

int n,m,tot,root[N],rt,idx;
int b[N],c[N],C[N];

int New(int u)
{
    a[++idx]=a[u];
    return idx;
}
void pushup(int u)
{
    a[u].cnt=a[lc].cnt+a[rc].cnt;
}
int build(int l,int r)
{
    int u=++idx;
    L=l,R=r;
    if(l==r) return u;
    int mid=(l+r)>>1;
    lc=build(l,mid);
    rc=build(mid+1,r);
    return u;
}
int find(int x)
{
    return lower_bound(c+1,c+tot+1,x)-c;
}
int modify(int u,int x)
{
    u=New(u);
    if(L==R) 
    {
        a[u].cnt++;
        return u;
    }
    int mid=(L+R)>>1;
    if(x<=mid) lc=modify(lc,x);
    else rc=modify(rc,x);
    pushup(u);
    return u;
}
int query(int u,int v,int x)
{
    if(L==R) return R;
    int cnt=a[a[v].l].cnt-a[a[u].l].cnt;
    if(x<=cnt) return query(a[u].l,a[v].l,x);
    else return query(a[u].r,a[v].r,x-cnt);
}
int main ()
{
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    cin>>n>>m;
    for(int i=1;i<=n;i++) cin>>b[i],c[i]=b[i];
    sort(c+1,c+n+1);
    tot=n;//不用去重
    root[0]=build(1,tot);
    for(int i=1;i<=tot;i++) root[i]=modify(root[i-1],find(b[i]));
    while(m--)
    {
        int l,r,k;
        cin>>l>>r>>k;
        cout<<c[query(root[l-1],root[r],k)]<<"\n";
    }
}

本题离散化不需要去重,因为按值域加点每次只会增加一 QAQ。