洛谷 P4427 求和

时间:2022-05-22 16:42:36

传送门啦

思路:

开始不肿么容易想到用倍增,但是想到需要求 $ Lca $ ,倍增这种常数小而且快的方法就很方便了。求 $ Lca $ 就是一个最普通的板子。那现在考虑怎么求题目中的结果。

树上差分可能听起来很高大上,但是前缀和并不陌生,树上差分就理解成树上前缀和就好了:

$ sum[u] + sum[v] - sum[lca(u , v)] ; $

树上差分之前要先预处理出 $ dis $ 数组, $ dis[i][j] $ 表示从 $ i $ 出发到根节点(本题中的1号节点)的 $ j $ 次方。

	for(re long long j = 1 ; j <= 50 ; ++ j)
dis[x][j] = dis[fa][j] + quick_power(dep[x] , j) ;

这就是预处理的代码了, $ dep $ 表示深度 , $ quick - power $ 为快速幂。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#define re register
using namespace std ;
const long long maxn = 300005 ;
const long long mod = 998244353 ; inline long long read () {
long long f = 1 , x = 0 ;
char ch = getchar () ;
while(ch > '9' || ch < '0') {if(ch == '-') f = -1 ; ch = getchar () ;}
while(ch >= '0' && ch <= '9') {x = (x << 1) + (x << 3) + ch - '0' ; ch = getchar () ;}
return x * f ;
} inline void print (long long x){
if(x < 0) {putchar('-') ; x = -x ;}
if(x > 9) print(x / 10) ;
putchar(x % 10 + '0') ;
} long long n , x , y , m , a , b , c ;
long long head[maxn] , tot ;
long long ans ; struct Edge {
long long from , to , next ;
}edge[maxn << 1] ; inline void add (long long u , long long v) {
edge[++tot].from = u ;
edge[tot].to = v ;
edge[tot].next = head[u] ;
head[u] = tot ;
} long long quick_power (long long a , long long b) {
long long res = a , ans = 1 ;
while(b) {
if(b & 1) ans = ans * res % mod ;
res = res * res % mod ;
b >>= 1 ;
}
return ans % mod ;
} long long dep[maxn] , f[maxn][21] , dis[maxn][51]; inline void dfs (long long x , long long fa) {
dep[x] = dep[fa] + 1 ;
f[x][0] = fa ;
for(re long long j = 1 ; j <= 50 ; ++ j)
dis[x][j] = dis[fa][j] + quick_power(dep[x] , j) ;
for(re long long i = 1 ; (1 << i) <= dep[x] ; ++ i) {
f[x][i] = f[f[x][i - 1]][i - 1] ;
}
for(re long long i = head[x] ; i ; i = edge[i].next) {
long long v = edge[i].to ;
if(v != fa) dfs(v , x) ;
}
} inline long long lca (long long a , long long b) {
if(dep[a] < dep[b]) swap(a , b) ;
for(re long long i = 20 ; i >= 0 ; -- i) {
if((1 << i) <= (dep[a] - dep[b]) ) {
a = f[a][i] ;
}
}
if(a == b) return a ;
for(re long long i = 20 ; i >= 0 ; -- i) {
if((1 << i) <= dep[a] && (f[a][i] != f[b][i])) {
a = f[a][i] ;
b = f[b][i] ;
}
}
return f[a][0] ;
} int main () {
n = read () ;
for(re long long i = 1 ; i <= n - 1 ; ++ i) {
x = read () ; y = read () ;
add(x , y) ;
add(y , x) ;
}
dep[1] = -1 ;
dfs(1 , 1) ;
m = read () ;
for(re long long i = 1 ; i <= m ; ++ i) {
a = read () ; b = read () ; c = read () ;
long long root = lca(a , b) ;
ans = (dis[a][c] - dis[root][c] + dis[b][c] - dis[f[root][0]][c]) % mod ;
//printf("%d %d %d %d\n" , dis[a][c] , dis[b][c] , dis[root][c] , quick_power(dep[root] , c) % mod ) ;
print(ans) ;
printf("\n") ;
}
return 0 ;
}