cf 700 B Connecting Universities

时间:2023-03-09 16:28:40
cf 700 B  Connecting Universities

题意:现在给以一棵$n$个结点的树,并给你$2k$个结点,现在要求你把这些节点互相配对,使得互相配对的节点之间的距离(路径上经过边的数目)之和最大。数据范围$1 \leq n \leq 200000, 2k \leq n$。

分析:贪心选择距离最大、次大...的结点对?貌似不对。暴力枚举所有可能?对但不可行。考虑节点对之间的距离实际上就是它们到LCA距离之和。因此单独考虑每个结点,它对答案的贡献实际上就是它到与其匹配节点的LCA的距离,这个距离必然不超过它到根的距离。如果我们有一种方法使得每个结点对答案的贡献都等于这个上界就好了,那么答案就是所有标记节点到根的距离之和。注意,要满足这个要求,与其匹配的结点应该和它不在根的同一个子树中,考虑根节点的各个子树中的标记节点,如果满足标记节点个数最多的子树不超过其余子树标记节点个数之和,那么是存在某种配对方法是得相互配对的节点分属不同子树的。即要求子树最大重量(含标记节点数目)不超过$k$即可(容易证明,讨论一下根是否是标记节点)。那么我们只需要对原图进行一次dfs找出这样的结点,并返回该节点作为根,再次dfs找出所有标记节点到根的距离之和即使答案。可以确定一定能够找到这样的根(提供一种证明思路,归纳法)。这样这道题可以在$O(n)$时间内解出。代码如下:

 #include <algorithm>
 #include <cstdio>
 #include <cstring>
 #include <string>
 #include <queue>
 #include <map>
 #include <set>
 #include <stack>
 #include <ctime>
 #include <cmath>
 #include <iostream>
 #include <assert.h>
 #pragma comment(linker, "/STACK:102400000,102400000")
 #define max(a, b) ((a) > (b) ? (a) : (b))
 #define min(a, b) ((a) < (b) ? (a) : (b))
 #define mp std :: make_pair
 #define st first
 #define nd second
 #define keyn (root->ch[1]->ch[0])
 #define lson (u << 1)
 #define rson (u << 1 | 1)
 #define pii std :: pair<int, int>
 #define pll pair<ll, ll>
 #define pb push_back
 #define type(x) __typeof(x.begin())
 #define foreach(i, j) for(type(j)i = j.begin(); i != j.end(); i++)
 #define FOR(i, s, t) for(int i = (s); i <= (t); i++)
 #define ROF(i, t, s) for(int i = (t); i >= (s); i--)
 #define dbg(x) std::cout << x << std::endl
 #define dbg2(x, y) std::cout << x << " " << y << std::endl
 #define clr(x, i) memset(x, (i), sizeof(x))
 #define maximize(x, y) x = max((x), (y))
 #define minimize(x, y) x = min((x), (y))
 using namespace std;
 typedef long long ll;
 const int int_inf = 0x3f3f3f3f;
 const ll ll_inf = 0x3f3f3f3f3f3f3f3f;
 ) - );
 const double double_inf = 1e30;
 ;
 typedef unsigned long long ul;
 typedef unsigned int ui;
 inline int readint(){
     int x;
     scanf("%d", &x);
     return x;
 }
 inline int readstr(char *s){
     scanf("%s", s);
     return strlen(s);
 }

 class cmpt{
 public:
     bool operator () (const int &x, const int &y) const{
         return x > y;
     }
 };

 int Rand(int x, int o){
     //if o set, return [1, x], else return [0, x - 1]
     ;
     int tem = (int)((double)rand() / RAND_MAX * x) % x;
      : tem;
 }
 ll ll_rand(ll x, int o){
     ;
     ll tem = (ll)((double)rand() / RAND_MAX * x) % x;
      : tem;
 }

 void data_gen(){
     srand(time());
     freopen("in.txt", "w", stdout);
     ;
     printf("%d\n", kases);
     while(kases--){
         ll sz = 1e18;
         printf());
     }
 }

 struct cmpx{
     bool operator () (int x, int y) { return x > y; }
 };
 ;
 int n, k;
 bool np[maxn];
 struct E{
     int to, nex;
 }e[maxn << ];
 int head[maxn], N;
 void addE(int x, int y){
     e[N].nex = head[x];
     e[N].to = y;
     head[x] = N++;
     e[N].nex = head[y];
     e[N].to = x;
     head[y] = N++;
 }
 int cnt[maxn];
 void dfs1(int u, int fa){
     cnt[u] = np[u];
     for(int i = head[u]; ~i; i = e[i].nex){
         int v = e[i].to;
         if(v == fa) continue;
         dfs1(v, u);
         cnt[u] += cnt[v];
     }
 }
 int rt;
 bool flag;
 void dfs2(int u, int fa){
     if(flag) return;
     ;
     ;
     for(int i = head[u]; ~i; i = e[i].nex){
         int v = e[i].to;
         if(v == fa) continue;
         maximize(maxi, cnt[v]);
         tot += cnt[v];
     }
     if(np[u]) ++tot;
     maximize(maxi,  * k - tot);
     if(maxi <= k){
         rt = u;
         flag = ;
         return;
     }
     for(int i = head[u]; ~i; i = e[i].nex) if(e[i].to != fa) dfs2(e[i].to, u);
 }
 ll dfs3(int u, int fa, int d){
     ll tem = ;
     if(np[u]) tem += d;
     for(int i = head[u]; ~i; i = e[i].nex){
         int v = e[i].to;
         if(v == fa) continue;
         tem += dfs3(v, u, d + );
     }
     return tem;
 }
 int main(){
     //data_gen(); return 0;
     //C(); return 0;
     ;
     if(debug) freopen("in.txt", "r", stdin);
     //freopen("out.txt", "w", stdout);
     while(~scanf("%d", &n)){
         k = readint();
         clr(np, );
         FOR(i, ,  * k) np[readint()] = ;
         N = , clr(head, -);
         FOR(i, , n){
             int x = readint(), y = readint();
             addE(x, y);
         }
         dfs1(, -);
         rt = ;
         flag = ;
         dfs2(, -);
         ll ans = dfs3(rt, -, );
         printf("%lld\n", ans);
     }
     ;
 }

code: