[BZOJ4520][Cqoi2016]K远点对(kd-tree+堆)

时间:2021-03-12 00:22:49

题目描述

传送门

题解

枚举每一个点,求前k大用一个小根堆维护一下就行了
kd-tree查询的时候一对点会算两次,所以求前2k大
刚开始手残T死了mdzz…

代码

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<queue>
#include<vector>
using namespace std;
#define LL long long
#define N 100005

int n,k,root,cnt,cmpd;
LL x,y,ans;
struct data
{
    int l,r;
    LL d[2],mn[2],mx[2];
}tr[N];
priority_queue <LL,vector<LL>,greater<LL> > q;

void update(int x)
{
    int l=tr[x].l,r=tr[x].r;
    if (l)
    {
        tr[x].mx[0]=max(tr[x].mx[0],tr[l].mx[0]);
        tr[x].mn[0]=min(tr[x].mn[0],tr[l].mn[0]);
        tr[x].mx[1]=max(tr[x].mx[1],tr[l].mx[1]);
        tr[x].mn[1]=min(tr[x].mn[1],tr[l].mn[1]);
    }
    if (r)
    {
        tr[x].mx[0]=max(tr[x].mx[0],tr[r].mx[0]);
        tr[x].mn[0]=min(tr[x].mn[0],tr[r].mn[0]);
        tr[x].mx[1]=max(tr[x].mx[1],tr[r].mx[1]);
        tr[x].mn[1]=min(tr[x].mn[1],tr[r].mn[1]);
    }
}
int cmp(data a,data b)
{
    return a.d[cmpd]<b.d[cmpd]||a.d[cmpd]==b.d[cmpd]&&a.d[cmpd^1]<b.d[cmpd^1];
}
int build(int l,int r,int d)
{
    int mid=(l+r)>>1;
    cmpd=d;
    nth_element(tr+l,tr+mid,tr+r+1,cmp);
    tr[mid].mx[0]=tr[mid].mn[0]=tr[mid].d[0];
    tr[mid].mx[1]=tr[mid].mn[1]=tr[mid].d[1];
    if (l<mid) tr[mid].l=build(l,mid-1,d^1);
    if (mid<r) tr[mid].r=build(mid+1,r,d^1);
    update(mid);
    return mid;
}
LL qr(LL x)
{
    return x*x;
}
LL dist(int now)
{
    LL dis=0LL;
    dis+=max(qr(x-tr[now].mn[0]),qr(tr[now].mx[0]-x));
    dis+=max(qr(y-tr[now].mn[1]),qr(tr[now].mx[1]-y));
    return dis;
}
void query(int now)
{
    LL dl=-1LL,dr=-1LL,d0;
    d0=qr(tr[now].d[0]-x)+qr(tr[now].d[1]-y);
    if (cnt<k) ++cnt,q.push(d0);
    else if (d0>q.top()) q.pop(),q.push(d0);
    if (tr[now].l) dl=dist(tr[now].l);
    if (tr[now].r) dr=dist(tr[now].r);
    if (dl>dr)
    {
        if (tr[now].l&&(cnt<k||dl>q.top())) query(tr[now].l);
        if (tr[now].r&&(cnt<k||dr>q.top())) query(tr[now].r);
    }
    else
    {
        if (tr[now].r&&(cnt<k||dr>q.top())) query(tr[now].r);
        if (tr[now].l&&(cnt<k||dl>q.top())) query(tr[now].l);
    }
}
int main()
{
    scanf("%d%d",&n,&k);k*=2;
    for (int i=1;i<=n;++i) scanf("%lld%lld",&tr[i].d[0],&tr[i].d[1]);
    root=build(1,n,0);
    for (int i=1;i<=n;++i)
    {
        x=tr[i].d[0],y=tr[i].d[1];
        query(root);
    }
    ans=q.top();
    printf("%lld\n",ans);
}