CodeChef PRIMEDST — Prime Distance On Tree (树分治 / centroid decomposition + FFT)

时间:2023-06-07 11:04:14

Prime Distance On Tree

Problem description.

You are given a tree. If we select 2 distinct nodes uniformly at random, what's the probability that the distance between these 2 nodes is a prime number?

Input

The first line contains a number N: the number of nodes in this tree.
Each of the following N-1 lines contains a pair a[i] b[i], meaning there is an edge of length 1 between nodes a[i] and b[i].

Output

Output a real number: the required probability.
Your answer is accepted if its absolute difference from the reference answer is at most 10^-6.

Constraints

2 ≤ N ≤ 50,000

The input is guaranteed to be a tree.

Example

Input:
5
1 2
2 3
3 4
4 5

Output:
0.5

Explanation

There are C(5, 2) = 10 possible pairs, and 5 of them are at a prime distance:

1-3, 2-4, 3-5: distance 2

1-4, 2-5: distance 3

Note that 1 is not a prime number.

题意:

    给你一棵树,n个点,n-1条边

    让你求任意选两个不同的点,其距离是素数的概率

题解:

    点分治

    对当前重心,求出连通块内每个点到重心的深度,并统计每个深度上的点数(只考虑经过重心的路径)

    让属于不同的子树的点,利用其深度进行任意组合(FFT加速)求出最后组合结果

    累加距离为素数的组合数即可,复杂度 O(n log^2 n)

#include<bits/stdc++.h>
using namespace std;
// MSVC-specific request for a larger stack (the DFS helpers below recurse once per
// tree level); other compilers ignore this pragma.
#pragma comment(linker, "/STACK:102400000,102400000")
// Segment-tree style index helpers; not referenced by the code in this file.
#define ls i<<1
#define rs ls | 1
#define mid ((ll+rr)>>1)
// Short aliases for common types.
using pii = pair<int, int>;
#define MP make_pair
using LL = long long;
using ULL = unsigned long long;
// Large sentinel and the constant pi (used by the FFT).
const long long INF = 1e18+1LL;
const double pi = acos(-1.0);
// Array bounds and misc constants (only N and inf are used by this program).
const int N = 3e5+10, M = 1e6+10, mod = 1e9+7, inf = 2e9;

// Minimal complex number with the three operations the FFT needs.
struct Complex {
    double r, i;
    Complex() {}
    Complex(double r, double i) : r(r), i(i) {}
    Complex operator+(const Complex& t) const { return Complex(r + t.r, i + t.i); }
    Complex operator-(const Complex& t) const { return Complex(r - t.r, i - t.i); }
    Complex operator*(const Complex& t) const {
        return Complex(r * t.r - i * t.i, r * t.i + i * t.r);
    }
};

// In-place iterative radix-2 Cooley-Tukey FFT.
//   y   : n coefficients, overwritten with the transform
//   n   : transform length, must be a power of two
//   rev : 1 for the forward transform, -1 for the inverse (which also divides by n)
void FFT(Complex y[], int n, int rev) {
    const double kPi = std::acos(-1.0);
    // Bit-reversal permutation: j is i with its log2(n) low bits reversed.
    for (int i = 1, j, t, k; i < n; ++i) {
        for (j = 0, t = i, k = n >> 1; k; k >>= 1, t >>= 1) j = (j << 1) | (t & 1);
        if (i < j) std::swap(y[i], y[j]);
    }
    // Butterfly passes: s is the current sub-transform length, ds = s / 2.
    for (int s = 2, ds = 1; s <= n; ds = s, s <<= 1) {
        Complex wn(std::cos(rev * 2 * kPi / s), std::sin(rev * 2 * kPi / s)), w(1, 0), t;
        for (int k = 0; k < ds; ++k, w = w * wn) {
            for (int i = k; i < n; i += s) {
                y[i + ds] = y[i] - (t = w * y[i + ds]);
                y[i] = y[i] + t;
            }
        }
    }
    // Inverse transform: normalise both parts by n.
    if (rev == -1) {
        for (int i = 0; i < n; ++i) {
            y[i].r /= n;
            y[i].i /= n;
        }
    }
}
Complex s[N], t[N];   // FFT work buffers
// vis[v]: v already removed as a centroid; f/siz: centroid search data;
// n: node count; allnode: size of the component being searched; root: its centroid.
int vis[N], f[N], siz[N], n, allnode, root;
int P[N];             // P[x] != 0  <=>  x is NOT prime
vector<int> G[N];     // adjacency lists (nodes numbered 1..n, per the sample input)

// Sieves composites in [2, 2n] into P, marks 0 and 1 as non-prime,
// and clears the removed-node flags. Must be called after n is read.
void init() {
    for (int i = 2; i <= 2 * n; ++i) {
        if (P[i]) continue;           // composite: its multiples are already marked
        for (int j = i + i; j <= 2 * n; j += i) P[j] = 1;
    }
    P[0] = 1;                         // distance 0 is not prime
    P[1] = 1;                         // 1 is not prime
    for (int i = 1; i <= n; ++i) vis[i] = 0;
}
void getroot(int u,int fa) {
f[u] = ;
siz[u] = ;
for(int i = ; i < G[u].size(); ++i) {
int to = G[u][i];
if(vis[to] || to == fa) continue;
getroot(to,u);
siz[u] += siz[to];
f[u] = max(f[u],siz[to]);
}
f[u] = max(f[u], allnode - siz[u]);
if(f[u] < f[root]) root = u;
} int len = ,cnt[N],dep[N],nowcnt[N],mxdep;
LL ans = ;
// Walks the subtree of u (parent f), skipping removed nodes: sets dep[] of every
// descendant to dep[parent] + 1, computes siz[] bottom-up, and raises mxdep to the
// deepest descendant depth. dep[u] must be set by the caller.
void getdeep(int u, int f) {
    siz[u] = 1;
    for (int to : G[u]) {
        if (vis[to] || to == f) continue;
        dep[to] = dep[u] + 1;
        getdeep(to, u);
        mxdep = std::max(mxdep, dep[to]);
        siz[u] += siz[to];
    }
}
// Adds p to nowcnt[dep[v]] for every node v in the subtree of u (parent f).
//   p ==  1 : record the subtree's depth histogram in nowcnt;
//   p == -1 : undo that record and merge the subtree into cnt instead.
void dfs(int u, int f, int p) {
    nowcnt[dep[u]] += p;
    if (p == -1) cnt[dep[u]] += 1;
    for (int to : G[u]) {
        if (vis[to] || to == f) continue;
        dfs(to, u, p);
    }
}
// Counts unordered pairs (a, b) whose path passes through centroid u and whose length
// is prime. a ranges over u and the child subtrees already merged into cnt; b ranges
// over the child subtree held in nowcnt. Each pair is therefore counted once. The two
// depth histograms are multiplied with FFT.
//
// Child subtrees are handled in increasing order of depth, and each FFT is sized to
// (merged depth + current depth). The work at u is then proportional to the size of
// u's component instead of (number of children) x (deepest subtree). Only the touched
// prefix of cnt is cleared, not all n entries.
// Also fills siz[] for every child subtree (read by work()).
LL cal(int u) {
    LL ret = 0;
    dep[u] = 0;

    // (max depth, child) for every child subtree still in the component.
    std::vector<std::pair<int, int> > subtrees;
    for (int to : G[u]) {
        if (vis[to]) continue;
        dep[to] = 1;
        mxdep = 1;
        getdeep(to, u);
        subtrees.push_back(std::make_pair(mxdep, to));
    }
    std::sort(subtrees.begin(), subtrees.end());

    // Invariant: cnt[] and nowcnt[] are all zero outside this function.
    cnt[0] = 1;        // the centroid itself, at depth 0
    int merged = 0;    // deepest depth currently present in cnt
    for (const auto& sub : subtrees) {
        const int d = sub.first;
        const int to = sub.second;
        const int top = merged + d;          // highest distance this product can hold
        len = 1;
        while (len <= top) len <<= 1;
        dfs(to, u, 1);
        for (int j = 0; j < len; ++j) {
            s[j] = Complex(cnt[j], 0);
            t[j] = Complex(nowcnt[j], 0);
        }
        FFT(s, len, 1);
        FFT(t, len, 1);
        for (int j = 0; j < len; ++j) s[j] = s[j] * t[j];
        FFT(s, len, -1);
        for (int j = 0; j <= top; ++j) {
            if (P[j]) continue;              // distance j is not prime
            ret += (LL)(s[j].r + 0.5);       // round the FFT result to the nearest count
        }
        dfs(to, u, -1);                      // clear nowcnt, merge this subtree into cnt
        merged = std::max(merged, d);
    }
    for (int j = 0; j <= merged; ++j) cnt[j] = 0;   // restore the invariant
    return ret;
}
// Centroid-decomposition driver: removes centroid u, adds the number of prime-length
// paths through it to ans, then recurses into the centroid of every remaining child
// component. Relies on cal(u) having filled siz[] for u's child subtrees.
void work(int u) {
    vis[u] = 1;
    ans += cal(u);
    for (int to : G[u]) {
        if (vis[to]) continue;
        allnode = siz[to];
        root = 0;            // sentinel: f[0] == inf, so any real node replaces it
        getroot(to, 0);
        work(root);
    }
}
// Reads the tree, counts prime-distance pairs via centroid decomposition,
// and prints count / C(n, 2).
int main() {
    if (scanf("%d", &n) != 1) return 0;
    init();
    for (int i = 1; i < n; ++i) {
        int x, y;
        scanf("%d%d", &x, &y);
        G[x].push_back(y);
        G[y].push_back(x);
    }
    ans = 0;
    f[0] = inf;      // node 0 is the "no centroid yet" sentinel for getroot
    root = 0;
    allnode = n;
    getroot(1, 0);
    work(root);
    const double pairs = static_cast<double>(n) * (n - 1) / 2;
    printf("%.6f\n", static_cast<double>(ans) / pairs);
    return 0;
}