Tsinsen A1486. 树(王康宁)

时间:2023-03-10 07:05:45
Tsinsen A1486. 树(王康宁)

Description

一棵树,问至少有 \(k\) 个黑点的路径最大异或和.

Sol

点分治.

用点分治找重心控制树高就不说了,主要是对答案的统计的地方.

将所有路径按点的个数排序.

可以发现当左端点递增的时候右端点单调递减,时刻满足Trie树里的所有元素都是合法的即可,不断把右端点丢进去,用左端点统计答案.

主要跨越根的时候根的贡献计算了两次,需要删掉一次.

对于需要满足不是一颗子树,可以将Trie树上的节点打一个标记,表示这个节点及其子节点都是在某子树下的路径,子树个数大于1的时候这个标记就没用了.

我在维护标记的时候标记位置打错了...居然有95...调了好长时间QAQ...

Code

#include <bits/stdc++.h>
using namespace std; #define debug(a) cout<<#a<<"="<<a<<" "
const int N = 1e5+50;
const int M = 31; int n,k,kk,rt,ans=-1;
int pow2[M];
int bl[N],v[N],sz[N],t[N];
vector< int > g[N];
int usd[N]; struct pr { int x,y,z; };
bool operator < (const pr &a,const pr &b) { return a.x<b.x; }
vector< pr > S; struct Trie {
int cnt,rt;
int ch[N*M][2],s[N*M],bl[N*M]; int GetNode() { cnt++;ch[cnt][0]=ch[cnt][1]=s[cnt]=0;return cnt; }
void init() {
cnt=0,rt=GetNode();
}
void insert(int x,int fr) {
int o=rt,r;
for(int i=M-1;~i;i--) {
if(x&pow2[i]) r=1;else r=0;
if(!ch[o][r]) ch[o][r]=GetNode(),bl[ch[o][r]]=fr;
else bl[ch[o][r]]=bl[ch[o][r]]==fr ? fr : 0;
o=ch[o][r],s[o]++;
}
}
int getv(int x,int fr) {
int o=rt,r,res=0;
if(!ch[rt][0] && !ch[rt][1]) return -1;
for(int i=M-1;~i;i--) {
if(x&pow2[i]) r=1;else r=0;
if(s[ch[o][r^1]] && bl[ch[o][r^1]]!=fr) res|=pow2[i],r^=1;
if(bl[ch[o][r]]==fr) return -1;
o=ch[o][r];
}return res;
}
}py;
inline int in(int x=0,char ch=getchar()) { while(ch>'9' || ch<'0') ch=getchar();
while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();return x; } void GetRoot(int u,int fa,int nn) {
t[u]=0,sz[u]=1;
for(vector< int >::iterator i=g[u].begin();i!=g[u].end();i++)
if((*i)!=fa && !usd[(*i)]) GetRoot(*i,u,nn),t[u]=max(t[u],sz[*i]),sz[u]+=sz[*i];
t[u]=max(t[u],nn-sz[u]);
if(t[u]<t[rt]) rt=u;
}
void GetS(int u,int fa,int c,int vv,int ff) {
S.push_back((pr){ c,vv,ff });
if(c>=k) ans=max(ans,vv);
for(vector< int >::iterator i=g[u].begin();i!=g[u].end();i++)
if((*i)!=fa && !usd[*i]) GetS(*i,u,c+bl[*i],vv^v[*i],ff);
}
void GetAns(int u,int fa,int nn) {
usd[u]=1,py.init(),S.clear();if(bl[u]>=k) ans=max(ans,v[u]);
for(vector< int >::iterator i=g[u].begin();i!=g[u].end();i++)
if((*i)!=fa && !usd[(*i)]) GetS((*i),u,bl[u]+bl[(*i)],v[(*i)],*i);
sort(S.begin(),S.end()); // cout<<u<<" : "<<nn<<endl;
// for(int i=0;i<(int)S.size();i++) cout<<S[i].x<<" "<<S[i].y<<" "<<S[i].z<<endl; int lim=S.size(),l=0,r=lim-1;
for(;l<lim;l++) {
while(l<r && S[l].x+S[r].x>=k+bl[u]) py.insert(S[r].y,S[r].z),r--;
ans=max(ans,py.getv(S[l].y^v[u],S[l].z));
// debug(l),debug(r),debug(ans)<<endl;
}
// debug(ans)<<endl;
// cout<<"-------------------------"<<endl; int ss;
for(vector< int >::iterator i=g[u].begin();i!=g[u].end();i++)
if((*i)!=fa && !usd[(*i)]) rt=0,ss=sz[(*i)]>sz[u] ? nn-sz[u] : sz[*i],GetRoot((*i),u,ss),GetAns(rt,rt,ss);
}
int main() {
n=in(),k=in();
for(int i=1;i<=n;i++) bl[i]=in();
for(int i=1;i<=n;i++) v[i]=in();
for(int i=1,u,v;i<n;i++) u=in(),v=in(),g[u].push_back(v),g[v].push_back(u); pow2[0]=1;for(int i=1;i<M;i++) pow2[i]=pow2[i-1]<<1;
// for(int i=0;i<M;i++) cout<<pow2[i]<<endl;
rt=0,t[rt]=n+1,GetRoot(1,1,n),GetAns(rt,rt,n); cout<<ans<<endl;
return 0;
}