POJ3417 LCA+树dp

时间:2023-03-09 20:15:53
POJ3417 LCA+树dp

http://poj.org/problem?id=3417

题意:先给出一棵无根树,然后下面再给出m条边,把这m条边连上,然后每次你能毁掉两条边,规定一条是树边,一条是新边,问有多少种方案能使树断裂。

我们考虑加上每一条新边的情况,当一条新边加上之后,原本的树就会成环,环上除了所有的树边要断的话必然要砍掉这条新边才可行。

每一条新边成的环就是u - lca(u,v) - v,对每一条边的覆盖次数++

考虑所有的树边,被覆盖 == 0的时候,意味着单独砍掉这条树边即可,其他随便选一个新边就是一种方案,贡献值 += M;

被覆盖 == 1的时候,意味着砍掉这条树边必须砍掉另一条与他匹配的新边,贡献值 ++

被覆盖 >= 2的时候,这条树边被砍掉是没有意义的,因为不可能同时砍掉两条以上的新边

下面的问题就变成了如何求每一条边的被覆盖次数,我们只要对dp[lca] -= 2,dp[u]++,dp[v]++从根节点向下推,到叶子节点之后回溯,更新dp值即可

这就变成了一个喜闻乐见的树dp、

#include <map>
#include <set>
#include <ctime>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std;
#define For(i, x, y) for(int i=x;i<=y;i++)
#define _For(i, x, y) for(int i=x;i>=y;i--)
#define Mem(f, x) memset(f,x,sizeof(f))
#define Sca(x) scanf("%d", &x)
#define Sca2(x,y) scanf("%d%d",&x,&y)
#define Scl(x) scanf("%lld",&x);
#define Pri(x) printf("%d\n", x)
#define Prl(x) printf("%lld\n",x);
#define CLR(u) for(int i=0;i<=N;i++)u[i].clear();
#define LL long long
#define ULL unsigned long long
#define mp make_pair
#define PII pair<int,int>
#define PIL pair<int,long long>
#define PLL pair<long long,long long>
#define pb push_back
#define fi first
#define se second
typedef vector<int> VI;
const double eps = 1e-;
const int maxn = 1e5 + ;
const int INF = 0x3f3f3f3f;
const int mod = 1e9 + ;
int N,M,tmp,K;
int head[maxn],tot,cnt;
bool vis[maxn];
int F[maxn * ],P[maxn],rmq[maxn * ];
struct Edge{
int to,next;
}edge[maxn * ];
int dp[maxn];
LL sum;
struct ST{
int dp[maxn * ][];
int mm[maxn * ];
void init(int n){
mm[] = -;
for(int i = ; i <= n ; i ++){
mm[i] = ((i & (i - )) == )?mm[i - ] + :mm[i - ];
dp[i][] = i;
}
for(int j = ; j <= mm[n]; j ++){
for(int i = ; i + ( << j) - <= n ; i ++){
dp[i][j] = rmq[dp[i][j - ]] < rmq[dp[i + ( << (j - ))][j - ]]?dp[i][j - ]:dp[i + ( << (j - ))][j - ];
}
}
}
int query(int a,int b){
if(a > b) swap(a,b);
int k = mm[b - a + ];
return rmq[dp[a][k]] <= rmq[dp[b - ( << k) + ][k]]?dp[a][k]:dp[b - ( << k) + ][k];
}
}st;
void init(){
Mem(head,-);
tot = ;
}
void add(int u,int v){
edge[tot].next = head[u];
edge[tot].to = v;
head[u] = tot++;
}
void dfs(int u,int pre,int dep){
F[++cnt] = u;
rmq[cnt] = dep;
P[u] = cnt;
for(int i = head[u]; ~i; i = edge[i].next){
int v = edge[i].to;
if(v == pre ) continue;
dfs(v,u,dep + );
F[++cnt] = u;
rmq[cnt] = dep;
}
}
void LCA_init(int root){
cnt = ;
dfs(root,root,);
st.init( * N - );
}
int lca(int u,int v){
return F[st.query(P[u],P[v])];
}
int dfs2(int x,int last){
for(int i = head[x]; ~i ; i = edge[i].next){
int to = edge[i].to;
if(to == last) continue;
dfs2(to,x);
dp[x] += dp[to];
if(dp[to] == ) sum++;
else if(!dp[to]) sum += M;
}
return dp[x];
}
int main()
{
Sca2(N,M);
init();
For(i,,N - ){
int u,v; Sca2(u,v);
add(u,v); add(v,u);
}
LCA_init();
For(i,,M){
int u,v; Sca2(u,v);
dp[u]++; dp[v]++; dp[lca(u,v)] -= ;
}
dfs2(,-);
Prl(sum);
#ifdef VSCode
system("pause");
#endif
return ;
}