NOI.AC#2266-Bacteria【根号分治,倍增】

时间:2023-03-09 01:56:57
NOI.AC#2266-Bacteria【根号分治,倍增】

正题

题目链接:http://noi.ac/problem/2266


题目大意

给出\(n\)个点的一棵树,有一些边上有中转站(边长度为\(2\),中间有一个中转站),否则就是边长为\(1\)。

\(m\)次询问一个东西从\(x\)出发走到\(y\),每隔\(k\)步中转站会关闭一次(\(k\)的倍数步走完后不能在中转站上)。求在关闭多少次以内可以到达

\(1\leq n,m\leq 10^5\)


解题思路

发现最多只需要走\(2n\)步,然后每隔\(k\)步关闭一次,所以可以考虑根号分治。

先处理好总的倍增数组,后面求\(LCA\)和跳链要用。

对于\(k=1\)的询问,就看一下中间有没有中转站,如果有就是\(-1\)否则就是距离

对于\(k\leq \sqrt n\)的询问,我们对于每个\(k\)都进行一次预处理,处理每个周期每个点往上走能走到哪里。然后再处理一个倍增数组,然后询问的时候就在上面跳就好了

对于\(k>\sqrt n\)的询问直接每次暴力跳\(k\)步如果是中转站就跳\(k-1\)步,然后一直跳到\(LCA\)处

时间复杂度\(O(n\sqrt n\log n)\),调一下块的大小就能过了


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
const int N=2e5+10,T=17;
struct edge{
int to,next;
}a[N<<1];
struct node{
int x,y,k,id;
}q[N];
int n,m,Q,tot,num,ans[N],ls[N],dep[N],sd[N];
int g[N][100],f[N][T+1],h[N][T+1];
void addl(int x,int y){
a[++tot].to=y;
a[tot].next=ls[x];
ls[x]=tot;return;
}
bool cmp(node x,node y)
{return x.k<y.k;}
void dfs(int x,int fa){
g[x][0]=x;sd[x]=sd[fa]+(x>n);
f[x][0]=fa;dep[x]=dep[fa]+1;
for(int i=1;i<=Q;i++)g[x][i]=g[fa][i-1];
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;
if(y==fa)continue;
dfs(y,x);
}
return;
}
int LCA(int x,int y){
if(dep[x]>dep[y])swap(x,y);
for(int i=T;i>=0;i--)
if(dep[f[y][i]]>=dep[x])y=f[y][i];
if(x==y)return x;
for(int i=T;i>=0;i--)
if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
return f[x][0];
}
void calc(int x,int fa,int k){
if(g[x][k]>n)h[x][0]=g[x][k-1];
else h[x][0]=g[x][k];
for(int i=ls[x];i;i=a[i].next){
int y=a[i].to;
if(y==fa)continue;
calc(y,x,k);
}
return;
}
int query(int x,int y,int k){
int p=LCA(x,y),ans=0;
for(int i=T;i>=0;i--){
if(dep[h[x][i]]>dep[p])x=h[x][i],ans+=(1<<i);
if(dep[h[y][i]]>dep[p])y=h[y][i],ans+=(1<<i);
}
if(x!=y){
int dis=dep[x]+dep[y]-2*dep[p];
if(dis>=0&&dis<=k)ans++;
else if(dis>k) ans+=2;
}
return ans;
}
int getf(int x,int k){
for(int i=0;i<=T;i++)
if((k>>i)&1)x=f[x][i];
return x;
}
int solve(int x,int y,int k){
int p=LCA(x,y),ans=0;
while(dep[x]>dep[p]){
int z=getf(x,k-1),t;
if(f[z][0]>n)t=z;
else t=f[z][0];
if(dep[t]>dep[p])x=t,ans++;
else break;
}
while(dep[y]>dep[p]){
int z=getf(y,k-1),t;
if(f[z][0]>n)t=z;
else t=f[z][0];
if(dep[t]>dep[p])y=t,ans++;
else break;
}
if(x!=y){
int dis=dep[x]+dep[y]-2*dep[p];
if(dis>=0&&dis<=k)ans++;
else if(dis>k) ans+=2;
}
return ans;
}
int main()
{
scanf("%d",&n);num=n;
for(int i=1;i<n;i++){
int x,y,w;
scanf("%d%d%d",&x,&y,&w);
if(w==1)addl(x,y),addl(y,x);
else{
++num;
addl(x,num);addl(num,y);
addl(y,num);addl(num,x);
}
}
Q=sqrt(n);
if(Q>=70)Q=70;
scanf("%d",&m);
for(int i=1;i<=m;i++){
scanf("%d%d%d",&q[i].x,&q[i].y,&q[i].k);
q[i].id=i;
}
sort(q+1,q+1+m,cmp);
dfs(1,0);
for(int j=1;j<=T;j++)
for(int i=1;i<=num;i++)
f[i][j]=f[f[i][j-1]][j-1];
int l=1,r=1;
for(;r<=m&&q[r].k<=Q;r++,l=r){
while(r<m&&q[r].k==q[r+1].k)r++;
if(q[r].k==1){
for(int i=l;i<=r;i++){
int x=q[i].x,y=q[i].y,lca=LCA(x,y);
if(sd[x]+sd[y]-2*sd[lca])ans[q[i].id]=-1;
else ans[q[i].id]=dep[x]+dep[y]-2*dep[lca];
}
continue;
}
calc(1,1,q[r].k);
for(int j=1;j<=T;j++)
for(int i=1;i<=num;i++)
h[i][j]=h[h[i][j-1]][j-1];
for(int i=l;i<=r;i++)
ans[q[i].id]=query(q[i].x,q[i].y,q[i].k);
}
for(int i=r;i<=m;i++)
ans[q[i].id]=solve(q[i].x,q[i].y,q[i].k);
for(int i=1;i<=m;i++)
printf("%d\n",ans[i]);
return 0;
}