bzoj4326 树链剖分 + 线段树 // 二分 lca + 树上差分

时间:2021-07-14 14:47:03

https://www.lydsy.com/JudgeOnline/problem.php?id=4326

题意:N个点的树上给M条树链,问去掉一条边的权值之后所有树链长度和的最大值最小是多少。

首先想到去掉的树边一定是最长链上的树边,所以产生的思路就是寻找出一条询问里的最长链之后依次枚举上面所有的边,询问去掉这条边之后其余所有边的最大值。

由于N和M都在30W,直接暴力肯定不行,考虑转换思维,变为维护不经过这条边上的所有链的最大值,在这个最大值和最长链 - 这条边权之中取较大的值就是去掉这条边之后整棵树最长的链。而维护去不经过这条边上的最大值可以考虑用线段树维护.所以整理得到如下的流程。

1.树剖维护出所有边形成的序列,维护一个前缀和,利用前缀和计算所有询问的链的长度并且找出最长链t

2.对于每个询问进行操作,计算出询问中的链在树剖序列上经过的区间,反向将没有经过的区间最大值更新为该询问树链的长度.

3.枚举最长链t经过的每一个树边,更新答案即可.

细节:

1.由于这是对边权而不是点权的维护.要注意将树边转化为除了1之外的点,方向为向叶子节点方向.在树剖序列里的初始权值为边权.

2.由于这波点转边的操作,在树剖跳结点的时候,最后top[u] == top[v]的时候,深度较小的那一个点是取不到的,显然1 - 2这条边中只包含了边2而没有边1,但是在x跳到fa[top[x]]的过程中则将整条链全部计算

#include <map>
#include <set>
#include <ctime>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std;
inline int read(){int now=;register char c=getchar();for(;!isdigit(c);c=getchar());
for(;isdigit(c);now=now*+c-'',c=getchar());return now;}
#define For(i, x, y) for(int i=x;i<=y;i++)
#define _For(i, x, y) for(int i=x;i>=y;i--)
#define Mem(f, x) memset(f,x,sizeof(f))
#define Sca(x) scanf("%d", &x)
#define Sca2(x,y) scanf("%d%d",&x,&y)
#define Sca3(x,y,z) scanf("%d%d%d",&x,&y,&z)
#define Scl(x) scanf("%lld",&x);
#define Pri(x) printf("%d\n", x)
#define Prl(x) printf("%lld\n",x);
#define CLR(u) for(int i=0;i<=N;i++)u[i].clear();
#define LL long long
#define ULL unsigned long long
#define mp make_pair
#define PII pair<int,int>
#define PIL pair<int,long long>
#define PLL pair<long long,long long>
#define pb push_back
#define fi first
#define se second
typedef vector<int> VI;
const double eps = 1e-;
const int maxn = 3e5 + ;
const int INF = 0x3f3f3f3f;
const int mod = 1e9 + ;
int N,M,K;
//链式前向星
struct Edge{
int to,dis,next;
}edge[maxn * ];
int head[maxn],tot;
void init(){
for(int i = ; i <= N ;i ++) head[i] = -;
tot = ;
}
void add(int u,int v,int w){
edge[tot].to = v;
edge[tot].next = head[u];
edge[tot].dis = w;
head[u] = tot++;
}
//树链剖分
int fa[maxn],dep[maxn],hson[maxn],size[maxn],ne[maxn];
int top[maxn],Index[maxn],To_num[maxn],nw[maxn];
void dfs1(int t,int la){
size[t] = ; int heavy = ;
for(int i = head[t]; ~i ; i = edge[i].next){
int v = edge[i].to;
if(v == la) continue;
dep[v] = dep[t] + ; fa[v] = t;
dfs1(v,t);
size[t] += size[v];
if(size[v] > heavy){
hson[t] = v;
ne[t] = edge[i].dis;
heavy = size[v];
}
}
}
int cnt;
void dfs2(int t,int la,int d){
nw[++cnt] = d; Index[t] = cnt;
top[t] = la; To_num[cnt] = t;
if(hson[t]) dfs2(hson[t],la,ne[t]);
for(int i = head[t]; ~i ; i = edge[i].next){
int v = edge[i].to;
if(v == hson[t] || v == fa[t]) continue;
dfs2(v,v,edge[i].dis);
}
}
//最大值线段树
struct Tree{
int l,r,MAX;
}tree[maxn << ];
void Build(int t,int l,int r){
tree[t].l = l; tree[t].r = r;
tree[t].MAX = ;
if(l == r) return;
int m = (l + r) >> ;
Build(t << ,l,m); Build(t << | ,m + ,r);
}
void Pushdown(int t){
tree[t << ].MAX = max(tree[t << ].MAX,tree[t].MAX);
tree[t << | ].MAX = max(tree[t << | ].MAX,tree[t].MAX);
}
void update(int t,int l,int r,int p){
if(p <= tree[t].MAX) return;
if(l <= tree[t].l && tree[t].r <= r){
tree[t].MAX = max(tree[t].MAX,p);
return;
}
Pushdown(t);
int m = (tree[t].l + tree[t].r) >> ;
if(r <= m) update(t << ,l,r,p);
else if(l > m) update(t << | ,l,r,p);
else{
update(t << ,l,m,p);
update(t << | ,m + ,r,p);
}
}
int query(int t,int p){
if(tree[t].l == tree[t].r) return tree[t].MAX;
Pushdown(t);
int m = (tree[t].l + tree[t].r) >> ;
if(p <= m) return query(t << ,p);
else return query(t << | ,p);
}
//前缀和
int pre[maxn];
void Build(){
for(int i = ; i <= N ; i ++){
pre[i] = pre[i - ] + nw[i];
}
}
struct Query{
int u,v,sum;
}Query[maxn];
int SUM(int u,int v){
int ans = ;
while(top[u] != top[v]){
if(dep[top[u]] < dep[top[v]]) swap(u,v);
ans += pre[Index[u]] - pre[Index[top[u]] - ];
u = fa[top[u]];
}
if(Index[u] > Index[v]) swap(u,v);
ans += pre[Index[v]] - pre[Index[u]];
return ans;
}
void update(int t){
int u = Query[t].u,v = Query[t].v;
int S = Query[t].sum;
vector<PII>Q;
Q.pb(mp(,));
while(top[u] != top[v]){
if(dep[top[u]] < dep[top[v]]) swap(u,v);
int l = Index[top[u]],r = Index[u];
if(l <= r) Q.pb(mp(l,r));
u = fa[top[u]];
}
if(dep[u] > dep[v]) swap(u,v);
if(u != ){
int l = Index[u] + ,r = Index[v];
if(l <= r) Q.pb(mp(l,r));
}
Q.pb(mp(N,N));
sort(Q.begin(),Q.end());
for(int i = ; i < Q.size() - ; i ++){
int l = Q[i].se + ,r = Q[i + ].fi - ;
if(l <= r){
update(,l,r,S);
}
}
}
void solve(int t){
int ans = Query[t].sum;
int u = Query[t].u,v = Query[t].v;
while(top[u] != top[v]){
if(dep[top[u]] < dep[top[v]]) swap(u,v);
for(int i = Index[top[u]]; i <= Index[u] ; i ++){
int mx = query(,i);
ans = min(ans,max(mx,Query[t].sum - nw[i]));
}
u = fa[top[u]];
}
if(dep[u] > dep[v]) swap(u,v);
int l = Index[u] + ,r = Index[v];
for(int i = l; i <= r ; i ++){
int mx = query(,i);
ans = min(ans,max(mx,Query[t].sum - nw[i]));
}
Pri(ans);
}
int main(){
Sca2(N,M); init();
for(int i = ; i <= N - ; i ++){
int u = read(),v = read(),w = read();
add(u,v,w); add(v,u,w);
}
dfs1(,-); cnt = -; hson[] = ;dfs2(,,);
Build(); Build(,,N);
int t = ,MAX = ;
for(int i = ; i <= M ; i ++){
Query[i].u = read(); Query[i].v = read();
Query[i].sum = SUM(Query[i].u,Query[i].v);
update(i);
if(MAX < Query[i].sum){
MAX = Query[i].sum;
t = i;
}
}
solve(t);
return ;
}

树链剖分

方法2:

当然,对于一类最大值最小的问题,肯定会首先想到去二分,也就是t时间内可以通过的答案,所有比t大的一定都可以通过。

因此我们选择二分最终答案,然后检验是否成立。

考虑先预处理出所有询问的链长度,二分答案之后即可知道在这条链上需要删除的边至少是多长,最终将所有(至少删掉的边)取一个最大值就是最终可以删掉的最小长度,对于总共cnt条需要删除链的询问,我们只要判断树上的点是否足够长并且被所有cnt条边覆盖即可。

判断是否被cnt条边覆盖的操作考虑用树上差分。

#include <map>
#include <set>
#include <ctime>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std;
inline int read(){int now=;register char c=getchar();for(;!isdigit(c);c=getchar());
for(;isdigit(c);now=now*+c-'',c=getchar());return now;}
#define For(i, x, y) for(int i=x;i<=y;i++)
#define _For(i, x, y) for(int i=x;i>=y;i--)
#define Mem(f, x) memset(f,x,sizeof(f))
#define Sca(x) scanf("%d", &x)
#define Sca2(x,y) scanf("%d%d",&x,&y)
#define Sca3(x,y,z) scanf("%d%d%d",&x,&y,&z)
#define Scl(x) scanf("%lld",&x);
#define Pri(x) printf("%d\n", x)
#define Prl(x) printf("%lld\n",x);
#define CLR(u) for(int i=0;i<=N;i++)u[i].clear();
#define LL long long
#define ULL unsigned long long
#define mp make_pair
#define PII pair<int,int>
#define PIL pair<int,long long>
#define PLL pair<long long,long long>
#define pb push_back
#define fi first
#define se second
typedef vector<int> VI;
const double eps = 1e-;
const int maxn = 3e5 + ;
const int INF = 0x3f3f3f3f;
const int mod = 1e9 + ;
int N,M,K;
struct Edge{
int to,dis,next;
}edge[maxn * ];
int head[maxn],tot;
void init(){
for(int i = ; i <= N ; i ++) head[i] = -;
tot = ;
}
void add(int u,int v,int w){
edge[tot].to = v;
edge[tot].dis = w;
edge[tot].next = head[u];
head[u] = tot++;
}
//lca
int val[maxn];
const int SP = ;
int fa[maxn][],dep[maxn],dis[maxn];
void dfs(int t,int la){
dep[t] = dep[la] + ;
fa[t][] = la;
for(int i = ; i < SP; i ++) fa[t][i] = fa[fa[t][i - ]][i - ];
for(int i = head[t]; ~i ; i = edge[i].next){
int v = edge[i].to;
if(v == la) continue;
val[v] = edge[i].dis;
dis[v] = dis[t] + edge[i].dis;
dfs(v,t);
}
}
int lca(int u,int v){
if(dep[u] < dep[v]) swap(u,v);
int t = dep[u] - dep[v];
for(int i = SP - ; i >= ; i --) if(t & ( << i)) u = fa[u][i];
for(int i = SP - ; i >= ; i --){
int uu = fa[u][i],vv = fa[v][i];
if(uu != vv){
u = uu;v = vv;
}
}
return u == v?u:fa[u][];
}
//Road
struct Road{
int u,v,sum;
int l;
}road[maxn]; int num[maxn],cnt,MIN;
bool flag;
void dfs2(int t,int la){
if(flag) return;
for(int i = head[t]; ~i ; i = edge[i].next){
int v = edge[i].to;
if(v == la) continue;
dfs2(v,t);
num[t] += num[v];
}
if(num[t] == cnt && val[t] >= MIN) flag = ;
}
bool check(int t){
MIN = ; cnt = ;
for(int i = ; i <= N ; i ++ ) num[i] = ;
for(int i = ; i <= M ; i ++){
if(road[i].sum <= t) continue;
MIN = max(MIN,road[i].sum - t);
cnt++;
num[road[i].u]++; num[road[i].v]++;
num[road[i].l] -= ;
}
flag = ;
dfs2(,);
return flag;
}
int solve(){
int l = ,r = 3e8 + ;
int ans = ;
while(l <= r){
int m = (l + r) >> ;
if(check(m)){
ans = m;
r = m - ;
}else{
l = m + ;
}
}
return ans;
}
int main(){
N = read(); M = read(); init();
for(int i = ; i <= N - ; i ++){
int u = read(),v = read(),w = read();
add(u,v,w); add(v,u,w);
}
dis[] = ;dfs(,);
for(int i = ; i <= M ; i ++){
road[i].u = read(),road[i].v = read();
road[i].l = lca(road[i].u,road[i].v);
road[i].sum = dis[road[i].u] + dis[road[i].v] - * dis[road[i].l];
//cout << "lca" << road[i].l << "sum " << road[i].sum << endl;
}
Pri(solve());
return ;
}