洛谷 P4148 简单题 KD-Tree 模板题

时间:2023-07-22 14:35:56

Code:

//洛谷 P4148 简单题 KD-Tree 模板题 

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <string>
using namespace std;
// Local-debugging helper: redirect stdin to the file "<a>.in".
// On failure freopen closes the original stdin and returns NULL, so report
// it instead of silently reading from a closed stream.
void setIO(std::string a)
{
    const std::string path = a + ".in";
    if (freopen(path.c_str(), "r", stdin) == NULL) perror(path.c_str());
}

// Capacity of the node pool. Slots are used 1-based (slot 0 means "no node"),
// so the pool must hold one more entry than the maximum number of points.
const int maxn = 500007;

// root: pool index of the KD-tree root (0 while the tree is empty).
// d:    split dimension (0 = x, 1 = y) read by cmp; set by build and insert.
// tot:  number of points inserted so far, i.e. the last used pool slot.
int root, d, tot;
// One KD-tree node. All nodes live in the global pool `node`; children are
// referenced by pool index, and index 0 means "no child".
struct Data
{
    int ch[2];    // left / right child indices (0 = none)
    int sum;      // total weight of all points in this node's subtree
    int w;        // weight of the point stored at this node
    int minv[2];  // lower corner (x, y) of the subtree's bounding box
    int maxv[2];  // upper corner (x, y) of the subtree's bounding box
    int p[2];     // the point itself: p[0] = x, p[1] = y
} node[maxn];
bool cmp(Data i,Data j)
{
return i.p[d] == j.p[d]? i.p[d^1] < j.p[d^1]: i.p[d]<j.p[d];
}
bool isout(int k,int x1,int y1,int x2,int y2)
{
if(node[k].maxv[0]<x1||node[k].minv[0]>x2||node[k].maxv[1]<y1||node[k].minv[1]>y2) return 1;
return 0;
}
bool isin(int k,int x1,int y1,int x2,int y2)
{
if(node[k].maxv[0]<=x2 && node[k].minv[0]>=x1 && node[k].maxv[1]<=y2 && node[k].minv[1]>=y1) return 1;
return 0;
}
void pushup(int x,int o)
{
node[x].minv[0]=min(node[x].minv[0],node[o].minv[0]);
node[x].maxv[0]=max(node[x].maxv[0],node[o].maxv[0]);
node[x].minv[1]=min(node[x].minv[1],node[o].minv[1]);
node[x].maxv[1]=max(node[x].maxv[1],node[o].maxv[1]);
node[x].sum+=node[o].sum;
}
int build(int l,int r,int o)
{
int mid=(l+r)>>1;
d=o, nth_element(node + l, node + mid, node + r + 1, cmp);
node[mid].minv[0] = node[mid].maxv[0] = node[mid].p[0];
node[mid].minv[1] = node[mid].maxv[1] = node[mid].p[1];
node[mid].sum = node[mid].w;
node[mid].ch[0]=node[mid].ch[1]=0;
if(l < mid) node[mid].ch[0] = build(l, mid-1, o^1), pushup(mid, node[mid].ch[0]);
if(r > mid) node[mid].ch[1] = build(mid + 1, r , o^1), pushup(mid, node[mid].ch[1]);
return mid;
}
int query(int k,int x1,int y1,int x2,int y2)
{
if(!k||isout(k,x1,y1,x2,y2)) return 0;
if(isin(k,x1,y1,x2,y2)) return node[k].sum;
int ans=0;
if(x1<=node[k].p[0]&&x2>=node[k].p[0]&&y1<=node[k].p[1]&&y2>=node[k].p[1]) ans+=node[k].w;
ans+=query(node[k].ch[0],x1,y1,x2,y2)+query(node[k].ch[1],x1,y1,x2,y2);
return ans;
}
// Attach pool node x as a new leaf without rebalancing. At each level the
// walk compares coordinate d (starting at x, alternating) and goes right only
// when x is strictly greater. Every ancestor on the path absorbs x's box and
// weight. The global d ends up toggled once per level descended.
void insert(int x)
{
    d = 0;
    if (root == 0)
    {
        root = x;
        return;
    }
    int cur = root;
    for (;;)
    {
        pushup(cur, x);
        const int side = node[x].p[d] > node[cur].p[d];
        d ^= 1;
        int &slot = node[cur].ch[side];
        if (slot == 0)
        {
            slot = x;
            return;
        }
        cur = slot;
    }
}
// Reads the grid size, then processes operations until "3" or end of input:
//   1 x y a       : add weight a at point (x, y)
//   2 x1 y1 x2 y2 : print the total weight inside the rectangle
// Every operand is XOR-decoded with the previous answer (lastans).
int main()
{
    //setIO("input");
    int n, opt, lastans = 0;
    // n is part of the input format; the KD-tree itself does not need it.
    if (scanf("%d", &n) != 1) return 0;

    // Loop only while a command was actually read. The previous `while(1)`
    // ignored scanf's result, so input without the terminating "3" made it
    // spin forever re-applying the stale opt value.
    while (scanf("%d", &opt) == 1)
    {
        if (opt == 3) break;
        if (opt == 1)
        {
            int x, y, a;
            if (scanf("%d%d%d", &x, &y, &a) != 3) break;
            x ^= lastans;
            y ^= lastans;
            a ^= lastans;

            // Initialise a fresh leaf in the next pool slot.
            ++tot;
            node[tot].p[0] = x;
            node[tot].p[1] = y;
            node[tot].maxv[0] = node[tot].minv[0] = x;
            node[tot].maxv[1] = node[tot].minv[1] = y;
            node[tot].w = node[tot].sum = a;
            insert(tot);

            // Plain insertion does not rebalance; rebuild the whole tree
            // periodically to keep query depth bounded.
            if (tot % 10000 == 0) root = build(1, tot, 0);
        }
        else if (opt == 2)
        {
            int x1, y1, x2, y2;
            if (scanf("%d%d%d%d", &x1, &y1, &x2, &y2) != 4) break;
            x1 ^= lastans;
            y1 ^= lastans;
            x2 ^= lastans;
            y2 ^= lastans;
            lastans = query(root, x1, y1, x2, y2);
            printf("%d\n", lastans);
        }
    }
    return 0;
}