初涉k-d tree

时间:2023-03-08 22:09:51

听说k-d tree是一个骗分的好东西?(但是复杂度差评???

还听说绍一的kdt常数特别小?

KDT是什么

KDT的全称是k-degree tree,顾名思义,这是一种处理多维空间的数据结构。

例如,给定一张二维图,每次会插入一些点,并且查询一个矩形区域内的点数。

上面这个问题可以离线cdq分治,也可以离线离散化处理,这两个做法可以参见初涉二维数点问题。不过这就是2-d tree基础的应用,使得我们可以在线处理这个问题。

网上关于KDT的解释博客有很多,但我认为在了解了k-d tree的作用之后,直接上代码更利于理解一些。

预备知识:

1.平衡树

(2.线段树)

KDT的题目

【矩形求和】bzoj4066: 简单题

Description

你有一个N*N的棋盘,每个格子内有一个整数,初始时的时候全部为0,现在需要维护两种操作:

命令

参数限制

内容

1 x y A

1<=x,y<=N,A是正整数

将格子x,y里的数字加上A

2 x1 y1 x2 y2

1<=x1<= x2<=N

1<=y1<= y2<=N

输出x1 y1 x2 y2这个矩形内的数字和

3

终止程序

Input

输入文件第一行一个正整数N。
接下来每行一个操作。每条命令除第一个数字之外,
均要异或上一次输出的答案last_ans,初始时last_ans=0。

Output

对于每个2操作,输出一个对应的答案。

HINT

数据规模和约定
1<=N<=500000,操作数不超过200000个,内存限制20M,保证答案在int范围内并且解码之后数据仍合法。
样例解释见OJ2683

题目分析

直接挂代码吧。

 #include<bits/stdc++.h>
const int maxn = ; int n,root,D;
long long lastAns,tot,lim;
struct point
{
int d[],mn[],mx[],l,r;  //d[]表示当前点(下面有用[]重载过)
long long sum,v;     //mn[]表示这棵子树内的点中坐标最小值,mx[]同理
int &operator [] (int a)
{
return d[a];
}
friend bool operator == (point a, point b)
{
return b.d[]==a.d[]&&b.d[]==a.d[];
}
friend bool operator < (point a, point b)
{
return a[D] < b[D];
}
}now,a[maxn],t[maxn]; int read()
{
char ch = getchar();
long long num = ;
bool fl = ;
for (; !isdigit(ch); ch = getchar())
if (ch=='-') fl = ;
for (; isdigit(ch); ch = getchar())
num = (num<<)+(num<<)+ch-;
if (fl) num = -num;
return num;
}
bool inside(int a1, int b1, int c1, int d1, int a2, int b2, int c2, int d2)      //如果矩形1完全包含在矩形2里
{
return a1 >= a2 && c1 <= c2 && b1 >= b2 && d1 <= d2;
}
bool ouside(int a1, int b1, int c1, int d1, int a2, int b2, int c2, int d2)      //如果矩形1和矩形2丝毫不相交      
{
return c1 < a2 || a1 > c2 || d1 < b2 || d2 < b1;
}
void update(int x)          //类似于线段树的pushup,将子树x的信息更新
{
int l = a[x].l, r = a[x].r;
for (int i=; i<=; i++)
{
a[x].mn[i] = a[x].mx[i] = a[x][i];
if (l){
a[x].mn[i] = std::min(a[x].mn[i], a[l].mn[i]);
a[x].mx[i] = std::max(a[x].mx[i], a[l].mx[i]);
}
if (r){
a[x].mn[i] = std::min(a[x].mn[i], a[r].mn[i]);
a[x].mx[i] = std::max(a[x].mx[i], a[r].mx[i]);
}
}
a[x].sum = a[l].sum+a[r].sum+a[x].v;    //本题中有点权
}
void insert(int &k, int D)            //类似平衡树的插入,每一次换一维比较
{
if (!k){
k = ++tot;
a[k][] = a[k].mn[] = a[k].mx[] = now[];
a[k][] = a[k].mn[] = a[k].mx[] = now[];
}
if (a[k]==now){
a[k].v += now.v, a[k].sum += now.v;  //如果已经存在这个点
return;
}
if (a[k][D] > now[D])
insert(a[k].l, D^);          //D^1代表每一层插入换一维度比较(kdt核心)
else insert(a[k].r, D^);
update(k);
}
int rebuild(int l, int r, int D)        //kdt的重构(此处不同于替罪羊重构)
{
if (l > r) return ;
int mid = (l+r)>>;
::D = D;
std::nth_element(t+l, t+mid, t+r+);
a[mid] = t[mid];
a[mid].l = rebuild(l, mid-, D^);
a[mid].r = rebuild(mid+, r, D^);
update(mid);
return mid;
}
long long query(int x, int aa, int bb, int cc, int dd)
{
if (!x) return ;
long long tmp = ;
if (inside(a[x].mn[], a[x].mn[], a[x].mx[], a[x].mx[], aa, bb, cc, dd))
return a[x].sum;
if (ouside(a[x].mn[], a[x].mn[], a[x].mx[], a[x].mx[], aa, bb, cc, dd))
return ;
if (inside(a[x][], a[x][], a[x][], a[x][], aa, bb, cc, dd)) tmp = a[x].v;
tmp += query(a[x].l, aa, bb, cc, dd)+query(a[x].r, aa, bb, cc, dd);
return tmp;
}
int main()
{
n = read();
lim = ;      //设置阈值,插入次数超过阈值就重构
int tt,aa,bb,cc,dd,w,i;
memset(a, , sizeof a);
for (;;)
{
tt = read();
if (tt==) break;
aa = read()^lastAns, bb = read()^lastAns;
if (tt==){
cc = read()^lastAns, dd = read()^lastAns;
lastAns = query(root, aa, bb, cc, dd);
printf("%lld\n",lastAns);
}else{
w = read()^lastAns;
now[] = aa, now[] = bb, now.v = w, now.sum = w;
insert(root, );
if (tot == lim){
for (i=; i<=tot; i++) t[i] = a[i];
root = rebuild(, tot, );      //暴力重构
lim += ;
}
}
}
return ;
}

本题的关键已经注释在程序里了。

【单点最近点】bzoj2648: SJY摆棋子

Description

这天,SJY显得无聊。在家自己玩。在一个棋盘上,有N个黑色棋子。他每次要么放到棋盘上一个黑色棋子,要么放上一个白色棋子,如果是白色棋子,他会找出距离这个白色棋子最近的黑色棋子。此处的距离是 曼哈顿距离 即(|x1-x2|+|y1-y2|) 。现在给出N<=500000个初始棋子。和M<=500000个操作。对于每个白色棋子,输出距离这个白色棋子最近的黑色棋子的距离。同一个格子可能有多个棋子。

Input

第一行两个数 N M
以后M行,每行3个数 t x y
如果t=1 那么放下一个黑色棋子
如果t=2 那么放下一个白色棋子

Output

对于每个T=2 输出一个最小距离

题目分析

不管题目要求的是什么,其实我们对于KDT的维护都是大致相同的,变化的大多是query操作。

例如这题,结合我们已经维护好的节点信息:$mn[]$,$mx[]$等,应该如何query呢。

比方说我们现在在root节点上,那么我们可以得到的是root这个点与查询的点的距离,并且还知道整张图x,y坐标的最小及最大值。

这里我们可以用类似于启发式搜索的思想处理query。

我们处理出$f(x,y)=y到x矩形的距离$——其中“到x矩形的距离”指的是y点离最近属于矩形的点的距离。形象来说就是,一个网格图上这个矩形是一座城市,现在我在一个点上想要最快到达这个城市(到这个城市任何一个点都行)的距离。

 inline int get(node a, node b)
{
int ret = ;
for (int i=; i<=; i++) ret += std::max(, a.mn[i]-b[i]);
for (int i=; i<=; i++) ret += std::max(, b[i]-a.mx[i]);
return ret;
}

这就是这个距离函数。

有了这个距离函数,我们就可以得出当前节点的左右儿子所管辖的矩形距离查询点最少有多远。

那么为了更优,我们当然是要先走估价少的那一边。注意这里是启发式地query,而不是单纯的贪心。两者的区别是:启发式先走估计花费低的;贪心只走估计花费低的。

于是我们的query就解决了。

 /**************************************************************
    Problem: 2648
    User: AntiQuality
    Language: C++
    Result: Accepted
    Time:13684 ms
    Memory:57548 kb
****************************************************************/
 
#include<bits/stdc++.h>
const int maxn = ;
 
int n,m,D,root,tot,sum,ans;
struct node
{
    int d[],mn[],mx[],l,r;
    int &operator [](int a){return d[a];}
    bool operator == (node a)
    {
        return a.d[]==d[]&&d[]==a.d[];
    }
    bool operator < (node a) const
    {
        return d[D] < a.d[D];
    }
}a[maxn],t[*maxn],now;
 
int read()
{
    char ch = getchar();
    int num = ;
    bool fl = ;
    for (; !isdigit(ch); ch = getchar())
        if (ch=='-') fl = ;
    for (; isdigit(ch); ch = getchar())
        num = (num<<)+(num<<)+ch-;
    if (fl) num = -num;
    return num;
}
void rec(int x)
{
    for (int i=; i<=; i++)
        a[x].mn[i] = a[x].mx[i] = a[x][i];
}
inline void update(int x)
{
    int l = a[x].l, r = a[x].r;
    rec(x);
    for (int i=; i<=; i++)
    {
        if (l){
            a[x].mn[i] = std::min(a[x].mn[i], a[l].mn[i]);
            a[x].mx[i] = std::max(a[x].mx[i], a[l].mx[i]);
        }
        if (r){
            a[x].mn[i] = std::min(a[x].mn[i], a[r].mn[i]);
            a[x].mx[i] = std::max(a[x].mx[i], a[r].mx[i]);
        }
    }
}
int rebuild(int l, int r, int D)
{
    if (l > r) return ;
    int mid = (l+r)>>;
    ::D = D;
    std::nth_element(t+l, t+mid, t+r+);
    a[mid] = t[mid];
    a[mid].l = rebuild(l, mid-, D^);
    a[mid].r = rebuild(mid+, r, D^);
    update(mid);
    return mid;
}
void insert(int &x, int k)
{
    if (!x){
        x = ++tot;
        a[x][] = now[], a[x][] = now[];
        rec(x);
    }
    if (a[x]==now) return;
    if (now[k] < a[x][k])
        insert(a[x].l, k^);
    else insert(a[x].r, k^);
    update(x);
}
inline int dis(node a, node b)
{
    return abs(a[]-b[])+abs(a[]-b[]);
}
inline int get(node a, node b)
{
    int ret = ;
    for (int i=; i<=; i++) ret += std::max(, a.mn[i]-b[i]);
    for (int i=; i<=; i++) ret += std::max(, b[i]-a.mx[i]);
    return ret;
}
void query(int x)
{
    int l = a[x].l, r = a[x].r, rl = 2e9, rr = 2e9, lgh = dis(a[x], now);
    ans = std::min(ans, lgh);
    if (l) rl = get(a[l], now);
    if (r) rr = get(a[r], now);
    if (rl < rr){
        if (rl < ans) query(l);
        if (rr < ans) query(r);
    }else{
        if (rr < ans) query(r);
        if (rl < ans) query(l);
    }
}
int query()
{
    ans = 2e9;
    query(root);
    return ans;
}
int main()
{
    n = read(), m = read();
    for (int i=; i<=n; i++) t[i][] = read(), t[i][] = read();
    root = rebuild(, n, );
    tot = n;
    for (int i=; i<=m; i++)
    {
        int tt = read(), x = read(), y = read();
        now[] = x, now[] = y;
        if (tt==){
            insert(root, );
        }else
            printf("%d\n",query());
    }
    return ;
}

【单点k远点】2626: JZPFAR

Description

  平面上有n个点。现在有m次询问,每次给定一个点(px, py)和一个整数k,输出n个点中离(px, py)的距离第k大的点的标号。如果有两个(或多个)点距离(px, py)相同,那么认为标号较小的点距离较大。

Input

  第一行,一个整数n,表示点的个数。
  下面n行,每行两个整数x_i, y_i,表示n个点的坐标。点的标号按照输入顺序,分别为1..n。
  下面一行,一个整数m,表示询问个数。
  下面m行,每行三个整数px_i, py_i, k_i,表示一个询问。

Output

  m行,每行一个整数,表示相应的询问的答案。

数据规模和约定

  50%的数据中,n个点的坐标在某范围内随机分布。
  100%的数据中,n<=10^5, m<=10^4, 1<=k<=20,所有点(包括询问的点)的坐标满足绝对值<=10^9,n个点中任意两点坐标不同,m个询问的点的坐标在某范围内随机分布。


题目分析

这题是KDT查询单点第k远的应用。

其间有一个技巧:在小根堆里插入$k$个$-INF$,每次考虑当前值是否大于堆顶。如果当前值大于堆顶,那么弹出堆顶并且插入当前值。这里小根堆就相当于一个缓存区的作用,是挺巧妙的一种技巧。

那么有了上面这个trick我们就可以大胆query了。

 /**************************************************************
Problem: 2626
User: AntiQuality
Language: C++
Result: Accepted
Time:22804 ms
Memory:15364 kb
****************************************************************/ #include<bits/stdc++.h>
const int maxn = ; int D;
struct point
{
long long d[],mn[],mx[],l,r,id;
long long &operator [](int x){return d[x];}
bool operator < (point a) const
{
return d[D] < a.d[D];
}
}t[maxn],a[maxn],now;
struct node
{
long long val,id;
bool operator < (node a) const
{
return val > a.val||(val==a.val&&id < a.id);
}
node(long long a, long long b):val(a),id(b) {}
};
int n,m,k,root;
std::priority_queue<node> q; int read()
{
char ch = getchar();
int num = ;
bool fl = ;
for (; !isdigit(ch); ch = getchar())
if (ch=='-') fl = ;
for (; isdigit(ch); ch = getchar())
num = (num<<)+(num<<)+ch-;
if (fl) num = -num;
return num;
}
void clear(std::priority_queue<node> &q)
{
std::priority_queue<node> emt;
std::swap(q, emt);
}
void rec(int x)
{
for (int i=; i<=; i++) a[x].mn[i] = a[x].mx[i] = a[x][i];
}
void update(int x)
{
int l = a[x].l, r = a[x].r;
rec(x);
for (int i=; i<=; i++)
{
if (l){
a[x].mn[i] = std::min(a[x].mn[i], a[l].mn[i]);
a[x].mx[i] = std::max(a[x].mx[i], a[l].mx[i]);
}
if (r)
{
a[x].mn[i] = std::min(a[x].mn[i], a[r].mn[i]);
a[x].mx[i] = std::max(a[x].mx[i], a[r].mx[i]);
}
}
}
int build(int l, int r, int k)
{
if (l > r) return ;
int mid = (l+r)>>;
D = k;
std::nth_element(t+l, t+mid, t+r+);
a[mid] = t[mid];
a[mid].l = build(l, mid-, k^);
a[mid].r = build(mid+, r, k^);
update(mid);
return mid;
}
long long sqr(long long x){return x*x;}
long long dis(point a, point b)
{
return sqr(a[]-b[])+sqr(a[]-b[]);
}
long long get(point a)
{
long long ret = ;
for (int i=; i<=; i++)
ret += std::max(sqr(now[i]-a.mn[i]), sqr(now[i]-a.mx[i]));
return ret;
}
void query(int x)
{
if (!x) return;
long long ll = -1e17, lr = -1e17, pur = dis(now, a[x]);
int l = a[x].l, r = a[x].r;
if (pur > q.top().val||(pur==q.top().val&&a[x].id<q.top().id))
q.pop(), q.push(node(pur, a[x].id));
if (l) ll = get(a[l]);
if (r) lr = get(a[r]);
if (ll > lr){
if (ll >= q.top().val) query(l);
if (lr >= q.top().val) query(r);
}else{
if (lr >= q.top().val) query(r);
if (ll >= q.top().val) query(l);
}
}
int main()
{
n = read();
for (int i=; i<=n; i++) t[i][] = read(), t[i][] = read(), t[i].id = i;
root = build(, n, );
m = read();
while (m--)
{
clear(q);
now[] = read(), now[] = read(), k = read();
for (int i=; i<=k; i++) q.push(node(-, ));
query(root);
printf("%lld\n",q.top().id);
}
return ;
}

【全图k远点对】4520: [Cqoi2016]K远点对

Description

已知平面内 N 个点的坐标,求欧氏距离下的第 K 远点对。

Input

输入文件第一行为用空格隔开的两个整数 N, K。接下来 N 行,每行两个整数 X,Y,表示一个点
的坐标。1 < =  N < =  100000, 1 < =  K < =  100, K < =  N*(N−1)/2 , 0 < =  X, Y < 2^31。

Output

输出文件第一行为一个整数,表示第 K 远点对的距离的平方(一定是个整数)。


题目分析

全图k远点对?

乍一看好像很玄学?

其实是和单点k远点的思路一样的。我们先向小根堆添加$2k$个$-INF$,然后枚举查询这n个点。之后的事情就都一样了。

复杂度似乎是$O(n\sqrt{n})$?

 /**************************************************************
Problem: 4520
User: AntiQuality
Language: C++
Result: Accepted
Time:1772 ms
Memory:15364 kb
****************************************************************/ #include<bits/stdc++.h>
const int maxn = ; int D;
struct point
{
long long d[],mn[],mx[],l,r,id;
long long &operator [](int x){return d[x];}
bool operator < (point a) const
{
return d[D] < a.d[D];
}
}t[maxn],a[maxn],now;
struct node
{
long long val,id;
bool operator < (node a) const
{
return val > a.val||(val==a.val&&id < a.id);
}
node(long long a, long long b):val(a),id(b) {}
};
int n,m,k,root;
std::priority_queue<node> q; int read()
{
char ch = getchar();
int num = ;
bool fl = ;
for (; !isdigit(ch); ch = getchar())
if (ch=='-') fl = ;
for (; isdigit(ch); ch = getchar())
num = (num<<)+(num<<)+ch-;
if (fl) num = -num;
return num;
}
void clear(std::priority_queue<node> &q)
{
std::priority_queue<node> emt;
std::swap(q, emt);
}
void rec(int x)
{
for (int i=; i<=; i++) a[x].mn[i] = a[x].mx[i] = a[x][i];
}
void update(int x)
{
int l = a[x].l, r = a[x].r;
rec(x);
for (int i=; i<=; i++)
{
if (l){
a[x].mn[i] = std::min(a[x].mn[i], a[l].mn[i]);
a[x].mx[i] = std::max(a[x].mx[i], a[l].mx[i]);
}
if (r)
{
a[x].mn[i] = std::min(a[x].mn[i], a[r].mn[i]);
a[x].mx[i] = std::max(a[x].mx[i], a[r].mx[i]);
}
}
}
int build(int l, int r, int k)
{
if (l > r) return ;
int mid = (l+r)>>;
D = k;
std::nth_element(t+l, t+mid, t+r+);
a[mid] = t[mid];
a[mid].l = build(l, mid-, k^);
a[mid].r = build(mid+, r, k^);
update(mid);
return mid;
}
long long sqr(long long x)
{
return x*x;
}
long long dis(point a, point b)
{
return sqr(a[]-b[])+sqr(a[]-b[]);
}
long long get(point a)
{
long long ret = ;
for (int i=; i<=; i++)
ret += std::max(sqr(now[i]-a.mn[i]), sqr(now[i]-a.mx[i]));
return ret;
}
void query(int x)
{
if (!x) return;
long long ll = -1e17, lr = -1e17, pur = dis(now, a[x]);
int l = a[x].l, r = a[x].r;
if (pur > q.top().val)
q.pop(), q.push(node(pur, a[x].id));
if (l) ll = get(a[l]);
if (r) lr = get(a[r]);
if (ll > lr){
if (ll >= q.top().val) query(l);
if (lr >= q.top().val) query(r);
}else{
if (lr >= q.top().val) query(r);
if (ll >= q.top().val) query(l);
}
}
int main()
{
n = read(), k = read()<<;
for (int i=; i<=n; i++) t[i][] = read(), t[i][] = read(), t[i].id = i;
root = build(, n, );
for (int i=; i<=k; i++) q.push(node(-, ));
for (int i=; i<=n; i++)
{
now = a[i];
query(root);
}
printf("%lld\n",q.top().val);
return ;
}

【半平面内点权和】2850: 巧克力王国

Description

巧克力王国里的巧克力都是由牛奶和可可做成的。但是并不是每一块巧克力都受王国人民的欢迎,因为大家都不喜
欢过于甜的巧克力。对于每一块巧克力,我们设x和y为其牛奶和可可的含量。由于每个人对于甜的程度都有自己的
评判标准,所以每个人都有两个参数a和b,分别为他自己为牛奶和可可定义的权重,因此牛奶和可可含量分别为x
和y的巧克力对于他的甜味程度即为ax + by。而每个人又有一个甜味限度c,所有甜味程度大于等于c的巧克力他都
无法接受。每块巧克力都有一个美味值h。现在我们想知道对于每个人,他所能接受的巧克力的美味值之和为多少

Input

第一行两个正整数n和m,分别表示巧克力个数和询问个数。接下来n行,每行三个整数x,y,h,含义如题目所示。再
接下来m行,每行三个整数a,b,c,含义如题目所示。

Output

输出m行,其中第i行表示第i个人所能接受的巧克力的美味值之和。

HINT

1 <= n, m <= 50000,1 <= 10^9,-10^9 <= a, b, x, y <= 10^9。


题目分析

一开始把这题想复杂了。这题求的是半平面内点权和没错,不过查询时候并不需要多少复杂的思路。$ax + by<c$这个条件容易发现是单调的,于是我们就把它看作是一个单独的判断函数就行了。

这样query时候就可以像矩形求和一样,如果当前区间都满足条件则加上整颗子树的价值,否则递归判断下去。

 #include<bits/stdc++.h>
const int maxn = ; int D,n,m,a,b,c,root;
struct node
{
int d[],mn[],mx[],l,r,v;
long long sum;
int &operator [](int a){return d[a];}
bool operator < (node a) const {return d[D]<a.d[D];}
}f[maxn],t[maxn]; int read()
{
char ch = getchar();
int num = ;
bool fl = ;
for (; !isdigit(ch); ch = getchar())
if (ch=='-') fl = ;
for (; isdigit(ch); ch = getchar())
num = (num<<)+(num<<)+ch-;
if (fl) num = -num;
return num;
}
void update(int x)
{
int l = f[x].l, r = f[x].r;
for (int i=; i<=; i++)
{
f[x].mx[i] = f[x].mn[i] = f[x][i];
if (l){
f[x].mn[i] = std::min(f[x].mn[i], f[l].mn[i]);
f[x].mx[i] = std::max(f[x].mx[i], f[l].mx[i]);
}
if (r){
f[x].mn[i] = std::min(f[x].mn[i], f[r].mn[i]);
f[x].mx[i] = std::max(f[x].mx[i], f[r].mx[i]);
}
}
f[x].sum = f[l].sum+f[r].sum+f[x].v;
}
int build(int l, int r, int k)
{
if (l > r) return ;
int mid = (l+r)>>;
D = k;
std::nth_element(t+l, t+mid, t+r+);
f[mid] = t[mid];
f[mid].l = build(l, mid-, k^);
f[mid].r = build(mid+, r, k^);
update(mid);
return mid;
}
int legal(int x, int y)
{
return a*x+b*y < c;
}
int calc(int x)
{
if (!x) return ;
node a = f[x];
return legal(a.mn[], a.mn[])+legal(a.mn[], a.mx[])+legal(a.mx[], a.mx[])+legal(a.mx[], a.mn[]);
}
long long query(int x)
{
int l = f[x].l, r = f[x].r, ll = calc(l), lr = calc(r);
long long ret = ;
if (calc(x)==) return f[x].sum;
else if (legal(f[x][], f[x][])) ret += f[x].v;
if (ll) ret += query(l);
if (lr) ret += query(r);
return ret;
}
int main()
{
n = read(), m = read();
for (int i=; i<=n; i++)
{
t[i][] = read(), t[i][] = read();
t[i].v = t[i].sum = read();
}
root = build(, n, );
for (int i=; i<=m; i++)
{
a = read(), b = read(), c = read();
printf("%lld\n",query(root));
}
return ;
}

END