[luoguP3644] [APIO2015]八邻旁之桥(权值线段树)

时间:2021-12-06 06:39:54

首先如果起点终点都在同一侧可以直接措置惩罚惩罚,如果需要过桥答案再加1

对付k即是1的情况

桥的坐标为x的话,,a和b为起点和终点坐标

$ans=\sum_{1}^{n} abs(a_{i}-x)+abs(b_{i}-x)$

起点和终点显然可以合并

那么 $ans=\sum_{1}^{n} abs(a_{i}-x)$

x为中位数就是最优解

对付k即是2的情况

首先有个结论:$(a_{i}+b_{i})/2$ 离哪座桥近,就选择哪座桥

可以把坐标凭据上面的公式排序,然后枚举中间点,分成摆布两部分

每一部分都有一座桥,那么就需要一个可以维护中位数,求和,删除/增加一个数的数据布局

平衡树或者线段树都可以

#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> #define N 200100 #define LL long long #define root 1, 1, cnt #define ls now << 1, l, mid #define rs now << 1 | 1, mid + 1, r using namespace std; int k, n, m, cnt; LL ans, tmp, tot, L[2], R[2]; int a[N]; char s[2]; struct node { int x, y; }p[N]; struct tree { LL sum[N << 2], num[N << 2]; inline void update(int now, int l, int r, int x, int d) { num[now] += d, sum[now] += d * a[x]; if(l == r) return; int mid = (l + r) >> 1; if(x <= mid) update(ls, x, d); else update(rs, x, d); } inline int find(int now, int l, int r, int x) { if(l == r) { L[0] += sum[now], L[1] += num[now]; return a[l]; } int mid = (l + r) >> 1; if(num[now << 1] >= x) { R[0] += sum[now << 1 | 1], R[1] += num[now << 1 | 1]; return find(ls, x); } else { L[0] += sum[now << 1], L[1] += num[now << 1]; return find(rs, x - num[now << 1]); } } }t[2]; inline int read() { int x = 0, f = 1; char ch = getchar(); for(; !isdigit(ch); ch = getchar()) if(ch == ‘-‘) f = -1; for(; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + ch - ‘0‘; return x * f; } inline void solve1() { int i, x; for(i = 1; i <= n; i++) { scanf("%s", s), a[++m] = read(); scanf("%s", s + 1), a[++m] = read(); if(s[0] == s[1]) ans += abs(a[m] - a[m - 1]), m -= 2; else ans++; } sort(a + 1, a + m + 1); x = a[m >> 1]; for(i = 1; i <= m; i++) ans += abs(a[i] - x); } inline bool cmp(node x, node y) { return x.x + x.y < y.x + y.y; } inline void solve2() { int i, x; for(i = 1; i <= n; i++) { m++; scanf("%s", s), p[m].x = read(); scanf("%s", s + 1), p[m].y = read(); if(s[0] == s[1]) tot += abs(p[m].x - p[m].y), m--; else tot++; } if(!m) { ans = tot; return; } sort(p + 1, p + m + 1, cmp); for(i = 1; i <= m; i++) a[++cnt] = p[i].x, a[++cnt] = p[i].y; sort(a + 1, a + cnt + 1); cnt = unique(a + 1, a + cnt + 1) - a - 1; for(i = 1; i <= m; i++) p[i].x = lower_bound(a + 1, a + cnt + 1, p[i].x) - a, p[i].y = lower_bound(a + 1, a + cnt + 1, p[i].y) - a; for(i = 1; i <= m; i++) t[1].update(root, p[i].x, 1), t[1].update(root, p[i].y, 1); x = t[1].find(root, m); ans = x * L[1] - L[0] + R[0] - x * R[1] + tot; for(i = 1; i <= m; i++) { t[0].update(root, p[i].x, 1), t[0].update(root, p[i].y, 1); t[1].update(root, p[i].x, -1), t[1].update(root, p[i].y, -1); tmp = L[0] = L[1] = R[0] = R[1] = 0; x = t[0].find(root, i); tmp += x * L[1] - L[0] + R[0] - x * R[1]; L[0] = L[1] = R[0] = R[1] = 0; x = t[1].find(root, m - i); tmp += x * L[1] - L[0] + R[0] - x * R[1]; ans = min(ans, tmp + tot); } } int main() { k = read(); n = read(); if(k == 1) solve1(); if(k == 2) solve2(); printf("%lld\n", ans); return 0; }