BZOJ 3295:[Cqoi2011]动态逆序对(三维偏序 CDQ分治+树状数组)

时间:2023-03-10 05:25:23
BZOJ 3295:[Cqoi2011]动态逆序对(三维偏序 CDQ分治+树状数组)

http://www.lydsy.com/JudgeOnline/problem.php?id=3295

题意:简单明了。

思路:终于好像有点明白CDQ分治处理三维偏序了。把删除操作看作是插入操作,那么可以按照插入的时间顺序看作是一维x,插入的数在原本序列的下标是一维y,插入的数本身是一维z。那么问题可以转化成每插入一个数(xx,yy,zz),求有多少个数(x,y,z)使得 x < xx,y < yy,z > zz 。一开始先对 x 进行排序,然后进行CDQ分治。这样可以干掉一维,保证随着时间递增。在分治的时候,通过标记判断那一个点属于左半区间还是右半区间,然后对 y 进行排序。如果在左半区间,那么它的 x 必定是小于 右半区间的,它所修改的结果会影响右半区间的查询,因此要去更新左半区间的元素。因为 y 是升序的,那么正着查询大于该点的 z 值的个数,就是查询可以满足 y < yy, z > zz 的条件的个数了。反着查询小于该点的 z 值的个数,即满足 y > yy, z < zz 的条件的个数。这样就可以找全插入一个数对整个数组产生的逆序对的个数了。

 #include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
#define N 100010
struct node {
int x, y, z, f;
node () {}
node (int x, int y, int z) : x(x), y(y), z(z) {}
} p[N], s[N]; int bit[N], gap, num[N], Hash[N];
LL ans[N]; bool cmpx(const node &a, const node &b) {
return a.x < b.x;
}
bool cmpy(const node &a, const node &b) {
if(a.y == b.y) return a.z < b.z;
return a.y < b.y;
} int lowbit(int x) { return x & (-x); } LL query(int x) {
LL ans = ;
while(x) { ans += bit[x]; x -= lowbit(x); }
return ans;
} void update(int x, int w) {
while(x <= gap) { bit[x] += w; x += lowbit(x); }
} void CDQ(int l, int r) {
if(l == r) return ;
int m = (l + r) >> , cnt = ;
CDQ(l, m); CDQ(m + , r);
for(int i = l; i <= m; i++) s[++cnt] = p[i], s[cnt].f = ; // 在左半部分
for(int i = m + ; i <= r; i++) s[++cnt] = p[i], s[cnt].f = ; // 在右半部分
sort(s + , s + + cnt, cmpy); // 根据y排序
for(int i = ; i <= cnt; i++) { // 正着扫
if(!s[i].f) update(s[i].z, ); // 左半部分对右半部分的查询有影响因此更新
else ans[s[i].x] += query(gap) - query(s[i].z); // 在[m,r]区间查询大于它的z的数量
}
for(int i = ; i <= cnt; i++) if(!s[i].f) update(s[i].z, -);
for(int i = cnt; i >= ; i--) { // 逆着扫
if(!s[i].f) update(s[i].z, );
else ans[s[i].x] += query(s[i].z); // 在[m,r]区间查询小于它的z的数量
}
for(int i = ; i <= cnt; i++) if(!s[i].f) update(s[i].z, -);
} int main() {
int n, m;
while(~scanf("%d%d", &n, &m)) {
int a, cnt = ;
gap = ;
for(int i = ; i <= n; i++) {
scanf("%d", &num[i]);
Hash[num[i]] = i;
p[i] = node(, i, num[i]);
if(num[i] > gap) gap = num[i];
}
for(int i = ; i <= m; i++) {
scanf("%d", &a);
p[Hash[a]].x = n - i + ;
}
for(int i = ; i <= n; i++)
if(p[i].x == ) p[i].x = ++cnt;
sort(p + , p + + n, cmpx);
memset(bit, , sizeof(bit));
CDQ(, n);
LL res = ;
for(int i = ; i <= n; i++) res += ans[i];
for(int i = n; i > n - m; i--) {
printf("%lld\n", res);
res -= ans[i];
}
}
return ;
}