优美的爆搜?KDtree学习

时间:2023-03-09 16:42:25
优美的爆搜?KDtree学习

如果给你平面内一些点,让你求距离某一个指定点最近的点,应该怎么办呢?

O(n)遍历!

但是,在遍历的过程中,我们发现有一些点是永远无法更新答案的。

如果我们把这些点按照一定顺序整理起来,省略对不必要点的遍历,是不是可以降低时间复杂度呢?

这样的话,我们要用到的工具就是KDtree。

KDtree本质上是一颗BST(二叉搜索树),只不过每一层按照不同的维度分割,也就是说,一层划分x,一层划分y,交替进行。大概就是这样:

优美的爆搜?KDtree学习

如果我们把他画在二维平面上的话,会发现KDtree实际上把一个矩形分割成了多个小矩形:

优美的爆搜?KDtree学习

(我是来盗图的QAQ)

更新答案时,采用邻域搜索的方式。我们发现我们查询的点落在了某个小矩形内,我们用这个小矩形内的点去更新答案。然后进行回溯,看一下周围的矩形有没有可能存在更优答案,如果不可能的话,就不用搜索它了。

这样下来的复杂度最优是O(logn),随机数据介于O(logn)~O(sqrt(n))之间。如果是特意构造的数据,可以卡到O(n)(比如精度要求实数,给你一个圆,让你查询距离圆心最近的点)。

然而一般情况下KDtree比较好写,在考场上可以比较经济地拿到大部分分数,还是值得学习的。

关于代码实现,自己YY即可。反正我是脑补出来的。

例题:

BZOJ2716/2648:

KDtree查曼哈顿距离的板子,由于数据比较水所以直接插入可过,不需要考虑平衡性的问题。

代码自己YY。我的仅供参考(虽然我自行胡编的代码没什么参考价值QAQ)。

代码:

 #include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
using namespace std;
const int maxn=1.5e6+1e2;
const int inf=0x3f3f3f3f; int cmp,ans;
struct Point {
int d[];
friend bool operator < (const Point &a,const Point &b) {
return a.d[cmp] < b.d[cmp];
}
int dis(const Point &o) const {
int ret = ;
for(int i=;i<;i++)
ret += abs( d[i] - o.d[i] );
return ret;
}
}ps[maxn]; int lson[maxn],rson[maxn],mi[maxn][],mx[maxn][];
Point dv[maxn];
int cnt; inline void update(int pos) {
if( lson[pos] ) {
for(int i=;i<;i++)
mi[pos][i] = min( mi[pos][i] , mi[lson[pos]][i] ),
mx[pos][i] = max( mx[pos][i] , mx[lson[pos]][i] );
}
if( rson[pos] ) {
for(int i=;i<;i++)
mi[pos][i] = min( mi[pos][i] , mi[rson[pos]][i] ),
mx[pos][i] = max( mx[pos][i] , mx[rson[pos]][i] );
}
}
inline void fill(int pos,const Point &p) {
dv[pos] = p;
for(int i=;i<;i++)
mx[pos][i] = mi[pos][i] = p.d[i];
}
inline void build(int pos,int pl,int pr,int dir) {
const int pmid = ( pl + pr ) >> ;
cmp = dir;
nth_element(ps+pl,ps+pmid,ps+pr+);
fill(pos,ps[pmid]);
if( pl < pmid ) build(lson[pos]=++cnt,pl,pmid-,dir^);
if( pr > pmid ) build(rson[pos]=++cnt,pmid+,pr,dir^);
update(pos);
}
inline void insert(int pos,Point np,int dir) {
cmp = dir;
if( np < dv[pos] ) {
if( lson[pos] ) insert(lson[pos],np,dir^);
else {
lson[pos] = ++cnt;
fill(lson[pos],np);
}
} else {
if( rson[pos] ) insert(rson[pos],np,dir^);
else {
rson[pos] = ++cnt;
fill(rson[pos],np);
}
}
update(pos);
}
inline int dis(int pos,const Point &p) {
int ret = ;
for(int i=;i<;i++)
ret += max( p.d[i] - mx[pos][i] , ) + max( mi[pos][i] - p.d[i] , );
return ret;
}
inline void query(int pos,const Point &p) {
ans = min( ans , p.dis(dv[pos]) );
int dl = lson[pos] ? dis(lson[pos],p) : inf;
int dr = rson[pos] ? dis(rson[pos],p) : inf;
if( dl < dr ) {
if( dl < ans ) query(lson[pos],p);
if( dr < ans ) query(rson[pos],p);
} else {
if( dr < ans ) query(rson[pos],p);
if( dl < ans ) query(lson[pos],p);
}
} int main() {
static int n,m;
static Point p; scanf("%d%d",&n,&m);
for(int i=;i<=n;i++)
scanf("%d%d",ps[i].d,ps[i].d+); build(cnt=,,n,); for(int i=,t;i<=m;i++) {
scanf("%d%d%d",&t,p.d,p.d+);
if(t == ) {
insert(,p,);
} else {
ans = inf;
query(,p);
printf("%d\n",ans);
}
}
return ;
}

BZOJ4066:

查询二维区间和。

本来这个题能够有多种做法,结果强制在线卡了cdq分治,20mb内存卡了树套树(什么你说你敢写?去写吧再见),其他一些奇奇怪怪的算法(分块线段树,分块splay之类的)并不是很容易实现。于是就只好KDtree了。

我们用KDtree上每一个节点去维护当前四边形的sum值,同时维护size。如果size过于不平衡了就进行重构(替罪羊树原理),同时手写内存池回收节点。

然而这样做并不能AC(怕不是我写搓了),需要在替罪羊重构的基础上判定一个修改次数,如果两次重构之间修改太少则不进行重构(否则不停地重构依旧TLE),另外你需要一个文件快读来保证AC。

另外如果这样写了仍不能AC,请注意调参,我用的是替罪羊的alpha设0.8,修改次数的lambda设1000,这样能够卡着50s的时限AC。

另外,如果实在TLE,那就弃了吧......

实在不明白为什么网上别人暴力重构和不重构能轻松AC。

代码:

 #pragma GCC optimize(3)
#include<cstdio>
#include<algorithm>
#include<cctype>
using namespace std;
const int maxn=2e5+1e1;
const double alpha = 0.8; int cmp;
struct Point {
int d[],val;
friend bool operator < (const Point &a,const Point &b) {
return a.d[cmp] < b.d[cmp];
}
friend bool operator == (const Point &a,const Point &b) {
return a.d[] == b.d[] && a.d[] == b.d[];
}
Point operator += (const Point &x) {
val += x.val;
return *this;
}
inline void reset() {
d[] = d[] = val = ;
}
}ps[maxn],nv[maxn]; int lson[maxn],rson[maxn],mi[maxn][],mx[maxn][],sum[maxn],siz[maxn];
int reb,delta,root; namespace RamPool {
int pool[maxn],top;
inline void DelNode(int x) {
lson[x] = rson[x] = ;
pool[++top] = x;
}
inline int NewNode() {
return pool[top--];
}
} using RamPool::DelNode;using RamPool::NewNode; inline void update(int pos) {
sum[pos] = nv[pos].val;
if( lson[pos] ) {
for(int i=;i<;i++)
mi[pos][i] = min( mi[pos][i] , mi[lson[pos]][i] ),
mx[pos][i] = max( mx[pos][i] , mx[lson[pos]][i] );
sum[pos] += sum[lson[pos]];
}
if( rson[pos] ) {
for(int i=;i<;i++)
mi[pos][i] = min( mi[pos][i] , mi[rson[pos]][i] ),
mx[pos][i] = max( mx[pos][i] , mx[rson[pos]][i] );
sum[pos] += sum[rson[pos]];
}
}
inline void fill(int pos,const Point &p) {
nv[pos] = p , siz[pos] = ;
for(int i=;i<;i++)
mi[pos][i] = mx[pos][i] = p.d[i];
sum[pos] = p.val;
}
inline void build(int pos,int ll,int rr,int dir) {
cmp = dir;
const int mid = ( ll + rr ) >> ;
nth_element(ps+ll,ps+mid,ps+rr+);
fill(pos,ps[mid]); siz[pos] = rr - ll + ;
if( ll < mid ) build(lson[pos]=NewNode(),ll,mid-,dir^);
if( rr > mid ) build(rson[pos]=NewNode(),mid+,rr,dir^);
update(pos);
}
inline void recycle(int pos,int& pcnt) {
if( lson[pos] ) recycle(lson[pos],pcnt);
if( rson[pos] ) recycle(rson[pos],pcnt);
ps[++pcnt] = nv[pos];
DelNode(pos);
}
inline int rebuild(int pos,int dir) {
reb = ;
int pcnt = ;
recycle(pos,pcnt);
int ret = NewNode();
build(ret,,pcnt,dir);
return ret;
} inline int insert(int pos,int dir,const Point &p) {
cmp = dir;
if( !nv[pos].val ) {
fill(pos,p);
return pos;
}
if( p == nv[pos] ) {
nv[pos] += p;
sum[pos] += p.val;
return pos;
}
++siz[pos];
if( p < nv[pos] ) {
if( !lson[pos] ) lson[pos] = NewNode();
lson[pos] = insert(lson[pos],dir^,p);
if( !reb && delta > && siz[lson[pos]] > (double) siz[pos] * alpha )
return rebuild(pos,dir);
} else {
if( !rson[pos] ) rson[pos] = NewNode();
rson[pos] = insert(rson[pos],dir^,p);
if( !reb && delta > && siz[rson[pos]] > (double) siz[pos] * alpha )
return rebuild(pos,dir);
}
update(pos);
return pos;
} inline bool inside(const int &Insx,const int &Insy,const int &Intx,const int &Inty,const int &Osx,const int &Osy,const int &Otx,const int &Oty) {
return Osx <= Insx &&Intx <= Otx && Osy <= Insy &&Inty <= Oty;
}
inline bool outside(const int &Insx,const int &Insy,const int &Intx,const int &Inty,const int &Osx,const int &Osy,const int &Otx,const int &Oty) {
return Insx > Otx || Intx < Osx || Insy > Oty || Inty < Osy;
}
inline bool inside(const Point &p,const int &Osx,const int &Osy,const int &Otx,const int &Oty) {
return inside(p.d[],p.d[],p.d[],p.d[],Osx,Osy,Otx,Oty);
}
inline int query(int pos,const int &sx,const int &sy,const int &tx,int const &ty) {
if( outside(mi[pos][],mi[pos][],mx[pos][],mx[pos][],sx,sy,tx,ty) ) return ;
if( inside(mi[pos][],mi[pos][],mx[pos][],mx[pos][],sx,sy,tx,ty) ) return sum[pos];
int ret = ;
if( inside(nv[pos],sx,sy,tx,ty) ) ret += nv[pos].val;
if( lson[pos] ) ret += query(lson[pos],sx,sy,tx,ty);
if( rson[pos] ) ret += query(rson[pos],sx,sy,tx,ty);
return ret;
} inline void init() {
for(int i=;i<maxn;i++)
DelNode(i);
root = NewNode();
} inline char nextchar() {
static char buf[<<],*st=buf+(<<),*ed=buf+(<<);
if( st == ed ) ed = buf + fread(st=buf,,<<,stdin);
return st == ed ? - : *st++;
}
inline int getint() {
int ret = ,ch;
while( !isdigit(ch=nextchar()) );
do ret=ret*+ch-''; while( isdigit(ch=nextchar()) );
return ret;
} int main() {
static int ope,lastans,sx,sy,tx,ty,xx,yy,num;
init(); getint();
while( ( ope = getint() ) != ) {
if( ope == ) {
xx = getint() , yy = getint() , num = getint();
reb = , ++delta;
xx ^= lastans , yy ^= lastans , num ^= lastans;
root = insert(root,,(Point){xx,yy,num});
} else {
sx = getint() , sy = getint() , tx = getint() , ty = getint();
sx ^= lastans , sy ^= lastans , tx ^= lastans , ty ^= lastans;
printf("%d\n", lastans = query(root,sx,sy,tx,ty) );
}
} return ;
}

Upd20180104:

其实那个TLE是我替罪羊重构写的姿势不对了,正确的姿势是找到最浅的不平衡点进行重构,不需要判断size的。虽然这样还是比暴力重构慢,但好在能稳稳地AC了QAQ。

代码:

 #pragma GCC optimize(3)
#include<cstdio>
#include<algorithm>
#include<cctype>
using namespace std;
const int maxn=2.5e5+1e2;
const double alpha=0.8; int cmp;
struct Point {
int d[],val;
Point(){}
Point(int xx,int yy,int vv) {d[] = xx , d[] = yy , val = vv;}
friend bool operator < (const Point &a,const Point &b) {
return a.d[cmp] < b.d[cmp];
}
friend bool operator == (const Point &a,const Point &b) {
return a.d[] == b.d[] && a.d[] == b.d[];
}
Point operator += (const Point &r) {
val += r.val;
return *this;
}
}ps[maxn],dv[maxn];
struct QNode {
int mi[],mx[];
QNode(int Mix,int Miy,int Mxx,int Mxy) {
mi[] = Mix , mi[] = Miy , mx[] = Mxx , mx[] = Mxy;
}
};
int lson[maxn],rson[maxn],siz[maxn],mi[maxn][],mx[maxn][],sum[maxn];
int root,rebp,rebfa,rebdir; namespace RamPool {
int pool[maxn],top;
inline void DelNode(int x) {
siz[x] = lson[x] = rson[x] = ;
pool[++top] = x;
}
inline int NewNode() {
return pool[top--];
}
}
using RamPool::DelNode; using RamPool::NewNode; inline void fill(int pos,const Point &p) {
dv[pos] = p , sum[pos] = p.val , siz[pos] = ;
for(int i=;i<;i++)
mi[pos][i] = mx[pos][i] = p.d[i];
}
inline void coreupdate(const int &fa,const int &son) {
sum[fa] += sum[son] , siz[fa] += siz[son];
for(int i=;i<;i++)
mi[fa][i] = min( mi[fa][i] , mi[son][i] ) ,
mx[fa][i] = max( mx[fa][i] , mx[son][i] );
}
inline void update(int pos) {
sum[pos] = dv[pos].val , siz[pos] = ;
if( lson[pos] )
coreupdate(pos,lson[pos]);
if( rson[pos] )
coreupdate(pos,rson[pos]);
}
inline void build(int pos,int ll,int rr,int dir) {
cmp = dir;
const int mid = ( ll + rr ) >> ;
nth_element(ps+ll,ps+mid,ps+rr+);
fill(pos,ps[mid]);
if( ll < mid ) build(lson[pos]=NewNode(),ll,mid-,dir^);
if( rr > mid ) build(rson[pos]=NewNode(),mid+,rr,dir^);
update(pos);
}
inline void recycle(int pos,int& pcnt) {
if( lson[pos] ) recycle(lson[pos],pcnt);
if( rson[pos] ) recycle(rson[pos],pcnt);
ps[++pcnt] = dv[pos];
DelNode(pos);
}
inline int rebuild(int pos,int dir) {
int pcnt = ;
recycle(pos,pcnt);
int ret = NewNode();
build(ret,,pcnt,dir);
return ret;
}
inline void insert(int pos,int dir,const Point &p) {
cmp = dir;
if( !dv[pos].val ) {
fill(pos,p);
return;
}
if( dv[pos] == p ) {
dv[pos] += p , sum[pos] += p.val;
return;
}
if( p < dv[pos] ) {
if( !lson[pos] ) lson[pos] = NewNode();
insert(lson[pos],dir^,p);
update(pos);
if( siz[lson[pos]] > siz[pos] * alpha ) rebp = pos , rebdir = dir , rebfa = ;
else if( lson[pos] == rebp ) rebfa = pos;
} else {
if( !rson[pos] ) rson[pos] = NewNode();
insert(rson[pos],dir^,p);
update(pos);
if( siz[rson[pos]] > siz[pos] * alpha ) rebp = pos , rebdir = dir , rebfa = ;
else if( rson[pos] == rebp ) rebfa = pos;
}
} inline bool inside(const int* mi,const int* mx,const QNode &q) {
return q.mi[] <= mi[] && mx[] <= q.mx[] && q.mi[] <= mi[] && mx[] <= q.mx[];
}
inline bool inside(const Point &p,const QNode &q) {
return inside(p.d,p.d,q);
}
inline bool inside(const int &pos,const QNode &q) {
return inside(mi[pos],mx[pos],q);
}
inline bool outside(const int* mi,const int* mx,const QNode &q) {
return mx[] < q.mi[] || q.mx[] < mi[] || mx[] < q.mi[] || q.mx[] < mi[];
}
inline bool outside(const int &pos,const QNode &q) {
return outside(mi[pos],mx[pos],q);
}
inline int query(int pos,const QNode &q) {
if( outside(pos,q) ) return ;
if( inside(pos,q) ) return sum[pos];
int ret = ;
if( inside(dv[pos],q) ) ret = dv[pos].val;
if( lson[pos] ) ret += query(lson[pos],q);
if( rson[pos] ) ret += query(rson[pos],q);
return ret;
} inline void init() {
for(int i=maxn-;i;i--)
DelNode(i);
root = NewNode();
}
inline void rebuild() {
if( !rebfa ) root = rebuild(root,);
else if( rebp == lson[rebfa] ) lson[rebfa] = rebuild(rebp,rebdir);
else rson[rebfa] = rebuild(rebp,rebdir);
} inline char nextchar() {
static char buf[<<],*st=buf+(<<),*ed=buf+(<<);
if( st == ed ) ed = buf + fread(st=buf,,<<,stdin);
return st != ed ? *st++ : -;
}
inline int getint() {
int ret = , ch;
while( !isdigit(ch=nextchar()) );
do ret=ret*+ch-''; while( isdigit(ch=nextchar()) );
return ret;
} int main() {
static int ope,xx,yy,add,sx,sy,tx,ty,lastans;
init();
getint();
while( ( ope = getint() ) != ) {
if( ope == ) {
xx = getint()^lastans , yy = getint()^lastans , add = getint()^lastans;
rebp = rebfa = rebdir = ;
insert(root,,Point(xx,yy,add));
if( rebp ) rebuild();
} else if( ope == ) {
sx = getint()^lastans , sy = getint()^lastans , tx = getint()^lastans , ty = getint()^lastans;
printf("%d\n", lastans = query( root , QNode(sx,sy,tx,ty) ) );
}
} return ;
}

另外KDtree还有一道水题:

BZOJ2850:

让你求平面内ax+by<=c的点的权值和。没有插入只有查询……KDtree随便做一下就好了,看一看当前块是不是全部包含在可行范围内,如果全部包含则返回sum,如果全部不包含则返回0,否则递归查询子树。

这样的复杂度大概是log级的,考虑每次分成4块,最多往下递归3块,这样是log级的。

代码:

 #include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define lli long long int
#define debug cout
using namespace std;
const int maxn=1e6+1e2; int cmp;
struct Point {
lli d[],h;
friend bool operator < (const Point &a,const Point &b) {
return a.d[cmp] < b.d[cmp];
}
inline lli f(int a,int b) const {
return a * d[] + b * d[];
}
}ps[maxn],dv[maxn]; int lson[maxn],rson[maxn],cnt;
lli mx[maxn][],mi[maxn][],sum[maxn];
lli c; inline lli f(lli x,lli y,int a,int b) {
return a * x + b * y;
} inline void update(int pos) {
if( lson[pos] ) {
for(int i=;i<;i++)
mi[pos][i] = min( mi[pos][i] , mi[lson[pos]][i] ),
mx[pos][i] = max( mx[pos][i] , mx[lson[pos]][i] );
sum[pos] += sum[lson[pos]];
}
if( rson[pos] ) {
for(int i=;i<;i++)
mi[pos][i] = min( mi[pos][i] , mi[rson[pos]][i] ),
mx[pos][i] = max( mx[pos][i] , mx[rson[pos]][i] );
sum[pos] += sum[rson[pos]];
}
}
inline void fill(int pos,const Point &p) {
dv[pos] = p;
for(int i=;i<;i++)
mx[pos][i] = mi[pos][i] = p.d[i];
sum[pos] = p.h;
}
inline void build(int pos,int ll,int rr,int dir) {
cmp = dir;
const int mid = ( ll + rr ) >> ;
nth_element(ps+ll,ps+mid,ps+rr+);
fill(pos,ps[mid]);
if( ll < mid ) build(lson[pos]=++cnt,ll,mid-,dir^);
if( rr > mid ) build(rson[pos]=++cnt,mid+,rr,dir^);
update(pos);
}
inline int judge(int pos,int a,int b) {
return ( f(mx[pos][],mx[pos][],a,b) < c ) + ( f(mx[pos][],mi[pos][],a,b) < c ) +
( f(mi[pos][],mx[pos][],a,b) < c ) + ( f(mi[pos][],mi[pos][],a,b) < c ) ;
}
inline lli query(int pos,int a,int b) {
lli ret = ;
if( dv[pos].f(a,b) < c )
ret += dv[pos].h;
if( lson[pos] ) {
int jl = judge(lson[pos],a,b);
if( jl == ) ret += sum[lson[pos]];
else if( jl ) ret += query(lson[pos],a,b);
}
if( rson[pos] ) {
int jr = judge(rson[pos],a,b);
if( jr == ) ret += sum[rson[pos]];
else if( jr ) ret += query(rson[pos],a,b);
}
return ret;
} int main() {
static int n,m;
scanf("%d%d",&n,&m);
for(int i=;i<=n;i++)
scanf("%lld%lld%lld",ps[i].d,ps[i].d+,&ps[i].h); build(cnt=,,n,); for(int i=,a,b;i<=m;i++) {
scanf("%d%d%lld",&a,&b,&c);
printf("%lld\n",query(,a,b));
}
return ;
}