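// Contract each copied subtree into a single block node, and use a persistent (chairman) segment tree to find k-th smallest labels. No real trick involved — this is purely an implementation exercise.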
#include<bits/stdc++.h>
// WARNING: these two macros rewrite EVERY later identifier `u` / `v` into
// `first` / `second`. They let `x.u` / `x.v` read std::pair members (and
// `j->u` / `j->v` read map entries), but they also silently rename every
// variable, parameter and struct member called u or v in the rest of the file.
#define u first
#define v second
// Shorthand for std::map::lower_bound (used as nu.F(x)).
#define F lower_bound
// Segment-tree split points; both expand using whatever `i` and `j` are in
// scope. J is floor((i+j)/2); I always equals J+1, so `I-1` is the same as J.
#define I (i+j+2>>1)
#define J (i+j>>1)
using namespace std;
typedef long long ll;
// n1: number of template-tree nodes; n2: number of copy operations, incremented
// in main so that blocks are numbered 1..n2; m: number of queries;
// n4: running DFS timestamp.
int n1,n2,m,n4;
// nu: largest big-tree label of each block -> that block's index
// (searched with lower_bound to find the block containing a label).
map<ll,int>nu;
// Array capacity for nodes / blocks.
const int N=1e5+5;
// Adjacency-list edge of the template tree, allocated from a fixed pool.
struct edge{
 int v;     // far endpoint of the edge
 edge*s;    // next edge leaving the same node (null terminates the list)
};
edge z[N*2];   // edge pool: each undirected tree edge is stored in both directions
edge*a=z;      // next unused slot of the pool
edge*h[N];     // h[x]: head of node x's adjacency list (most recently added edge)
void ins(int u,int v){
edge s={v,h[u]};
*(h[u]=a++)=s;
}
// One int per node (or per block), length N.
typedef int arr[N];
// Template tree, filled by dfs: l = 1-based DFS entry time, r = subtree size.
// Blocks (block 1 is the original tree, each copy adds one):
// b = template node the block hangs under, po = template node that roots the block.
// id = inverse of l (DFS position -> template node).
arr l,r,b,po,id;
// d[0] / p[0][k]: depth / 2^k-th ancestor in the template tree;
// d[1] / p[1][k]: the same in the tree of blocks.
arr d[2];
arr p[2][17];
// n3: largest big-tree label assigned so far; u, v: input scratch (renamed to
// first/second by the macros); c[i]: distance from block i's root to the big tree's root.
ll n3,u,v,c[N];
void dfs(int u){
r[id[l[u]=++n4]=u]=1;
for(edge*i=h[u];i;i=i->s)
if(i->v^p[0][0][u]){
d[0][i->v]=d[0][p[0][0][i->v]=u]+1;
dfs(i->v);
r[u]+=r[i->v];
}
}
typedef struct node*ptr;
// Node of the persistent segment tree built over template DFS positions 1..n1.
struct node{
 ptr i;   // left child (lower half of the range)
 ptr j;   // right child (upper half of the range)
 int s;   // number of inserted values that fall into the LEFT child's range
};
// Node pool. Version k writes its new root-to-leaf path into e[k][0], e[k][1], ...
// (one fresh node per internal level; all other child pointers are shared with
// version k-1). Leaf nodes are never written or read, so a path's last child
// pointer may point just past its row.
node e[N][17];
void ins(int i,int j,int s,ptr u,ptr v){
while(i<j){
*v=*u;
if(s>J)u=u->j,v=v->j=v+1,i=J+1;
else
++v->s,u=u->i,v=v->i=v+1,j=I-1;
}
}
// Returns the k-th smallest (1-based) value inserted between versions u
// (exclusive) and v (inclusive), searching the value range [i, j].
// Each node's s counts its left subtree, so v->s - u->s is how many values
// of the difference lie in the left half.
int ask(int i,int j,int k,ptr u,ptr v){
 for(;i<j;){
  const int mid=(i+j)>>1;
  const int inLeft=v->s-u->s;
  if(k>inLeft){
   k-=inLeft;
   u=u->j,v=v->j;
   i=mid+1;
  }else{
   u=u->i,v=v->i;
   j=mid;
  }
 }
 return i;
}
int lca(int i,int s,int t){
if(d[i][s]<d[i][t])swap(s,t);
int k=d[i][s]-d[i][t];
for(int j=16;~j;--j)
if(k>>j&1)s=p[i][j][s];
if(s==t)return s;
for(int j=16;~j;--j)
if(p[i][j][s]^p[i][j][t])
s=p[i][j][s],t=p[i][j][t];
return p[i][0][s];
}
// (template node, template root of its block), as built by the label lookup.
typedef std::pair<int,int> vec;
// ((template node, block root), block index).
typedef std::pair<vec,int> tri;
// Locates big-tree label v.
// Returns ((template node that v copies, template root of v's block), v's block index).
// Assumes 1 <= v <= the largest label assigned so far; otherwise lower_bound
// returns end() and the dereference below is undefined.
tri ask(ll v){
 // nu is keyed by each block's largest label, so the first key >= v is v's block.
 // (Explicit iterator type instead of the GNU-only `typeof`.)
 map<ll,int>::iterator it=nu.lower_bound(v);
 const int blk=it->second;
 const int s=po[blk];
 // A block's labels are consecutive and end at it->first, so v is the k-th
 // (1-based) label of its block; it maps to the k-th smallest template label
 // among DFS positions [l[s], l[s]+r[s]-1], i.e. the subtree of s.
 const int k=static_cast<int>(v-it->first+r[s]);
 const int src=ask(1,n1,k,e[l[s]-1],e[l[s]+r[s]-1]);
 return tri(vec(src,s),blk);
}
// Returns the k-th ancestor of block s in the tree of blocks (k below 2^17).
int ask(int s,int k){
 int cur=s;
 for(int bit=16;bit>=0;--bit){
  if(!(k&(1<<bit)))continue;
  cur=p[1][bit][cur];
 }
 return cur;
}
int main(){
scanf("%d%d%d",&n1,&n2,&m),++n2;
for(int i=2;i<=n1;++i)
scanf("%lld%lld",&u,&v),ins(u,v),ins(v,u);
dfs(po[nu[n3=n1]=1]=1);
e[0][0]=(node){e[0],e[0]};
for(int i=1;i<=n1;++i)
ins(1,n1,id[i],e[i-1],e[i]);
for(int i=2;i<=n2;++i){
scanf("%lld%lld",&u,&v);
tri s=ask(v);
d[1][i]=d[1][p[1][0][i]=s.v]+1,c[i]=c[s.v]+d[0][b[i]=s.u.u]-d[0][s.u.v]+1,po[nu[n3+=r[u]]=i]=u;
}
for(int i=1;i<17;++i){
for(int j=1;j<=n1;++j)
p[0][i][j]=p[0][i-1][p[0][i-1][j]];
for(int j=2;j<=n2;++j)
p[1][i][j]=p[1][i-1][p[1][i-1][j]];
}
while(m--){
scanf("%lld%lld",&u,&v);
tri s1=ask(u),t1=ask(v);
int l1=lca(1,s1.v,t1.v);
int s2=s1.u.u,t2=t1.u.u;
ll l3=0;
if(s1.v^l1){
int s3=ask(s1.v,d[1][s1.v]-d[1][l1]-1);
l3+=d[0][s2]-d[0][s1.u.v]+c[s1.v]-c[s3]+1,s2=b[s3];
}
if(t1.v^l1){
int t3=ask(t1.v,d[1][t1.v]-d[1][l1]-1);
l3+=d[0][t2]-d[0][t1.u.v]+c[t1.v]-c[t3]+1,t2=b[t3];
}
int l2=lca(0,s2,t2);
l3+=d[0][s2]+d[0][t2]-d[0][l2]*2;
printf("%lld\n",l3);
}
}