初次接触 CDQ 分治，感觉真的挺厉害的。整体思路是分而治之：先递归解决左右两半，再用左半部分已经处理好的信息去统计它对右半部分答案的贡献。
大概流程如下（对于区间 [l, r]，取 mid = (l + r) / 2）：
1. 递归处理 [l, mid] 和 [mid + 1, r] 的答案；
2. 将左右两半分别按第二维排序；
3. 计算 [l, mid] 中每一个元素对 [mid + 1, r] 中元素答案的贡献，并累加；
4. 由此得到区间 [l, r] 的答案。
CDQ 分治我一共只做了两道题，就一起整理在这里了。两道题的做法大体相同：第一维通过预先排序处理，第二维由 CDQ 分治（划分区间后组内排序、归并）处理，第三维用树状数组维护。
1.三维偏序
#include <algorithm>
#include <cstdio>

// Problem 1: three-dimensional partial order.
// The 1st dimension (a) is handled by the initial sort, the 2nd (b) by CDQ
// divide and conquer, and the 3rd (c) by a Fenwick tree indexed by c.
//
// Input : n k, then n triples (a, b, c). The Fenwick tree covers [1, k], so
//         c values are assumed to lie in [1, k].
// Output: for d = 0 .. n-1, the number of points that have exactly d *other*
//         points (x, y, z) with x <= a, y <= b and z <= c.

constexpr int maxn = 3000000;  // capacity of every static array below

int n, k, tot, ans[maxn], c[maxn];  // c[] is the Fenwick tree over the c coordinate

// One distinct triple. After de-duplication `cnt` is how many input points
// share this triple, and `ans` accumulates (in cdq) the total multiplicity of
// the other distinct triples it dominates.
struct node
{
    int a, b, c, ans, cnt;
} num[maxn], a[maxn];

// Reads one optionally negative decimal integer from stdin, skipping any
// leading non-digit characters. Returns 0 if EOF is hit before a digit.
int read()
{
    int x = 0, sign = 1;
    int ch = std::getchar();  // int, not char: keeps EOF distinguishable
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * sign;
}

// Strict ordering by (a, b, c); used for the initial sort and de-duplication.
bool cmp(const node &x, const node &y)
{
    if (x.a != y.a) return x.a < y.a;
    if (x.b != y.b) return x.b < y.b;
    return x.c < y.c;
}

// Strict ordering by (b, c); used to sort each half inside cdq.
bool cmp2(const node &x, const node &y)
{
    if (x.b != y.b) return x.b < y.b;
    return x.c < y.c;
}

// Lowest set bit of x (Fenwick tree step).
inline int lowbit(int x) { return x & -x; }

// Fenwick tree: adds v at index x (indices 1..k).
void Update(int x, int v)
{
    for (int i = x; i <= k; i += lowbit(i))
        c[i] += v;
}

// Fenwick tree: prefix sum over indices 1..x.
int Query(int x)
{
    int res = 0;
    for (int i = x; i > 0; i -= lowbit(i))
        res += c[i];
    return res;
}

// CDQ step on num[l..r]. The caller sorts the whole range with cmp, so every
// element of num[l..mid] precedes every element of num[mid+1..r] in
// (a, b, c) order. Recursion only permutes elements within each half, so
// this split stays valid.
// For each right-half element j, adds to num[j].ans the multiplicity of the
// left-half triples with b <= num[j].b and c <= num[j].c.
// On return each half is sorted by (b, c), and the Fenwick tree is empty again.
void cdq(int l, int r)
{
    if (l >= r) return;
    int mid = (l + r) >> 1;
    cdq(l, mid);
    cdq(mid + 1, r);
    std::sort(num + l, num + mid + 1, cmp2);
    std::sort(num + mid + 1, num + r + 1, cmp2);

    // Two-pointer merge on b. Before num[j] is answered, every left element
    // with b <= num[j].b is in the tree; the tree then filters on c.
    int i = l;
    for (int j = mid + 1; j <= r; ++j)
    {
        while (i <= mid && num[i].b <= num[j].b)
        {
            Update(num[i].c, num[i].cnt);
            ++i;
        }
        num[j].ans += Query(num[j].c);
    }
    // Undo only what was inserted, leaving the tree clean for the caller.
    for (int p = l; p < i; ++p) Update(num[p].c, -num[p].cnt);
}

int main()
{
    n = read(), k = read();
    for (int i = 1; i <= n; ++i)
        a[i].a = read(), a[i].b = read(), a[i].c = read();
    std::sort(a + 1, a + 1 + n, cmp);

    // Collapse runs of identical triples into one node carrying a multiplicity.
    for (int i = 1; i <= n;)
    {
        int j = 1;
        while (i + j <= n && a[i].a == a[i + j].a && a[i].b == a[i + j].b && a[i].c == a[i + j].c)
            ++j;
        num[++tot] = a[i];
        num[tot].cnt = j;
        i += j;
    }

    cdq(1, tot);

    // Each of the cnt copies dominates the `ans` other triples plus its cnt-1 twins.
    for (int i = 1; i <= tot; ++i) ans[num[i].ans + num[i].cnt - 1] += num[i].cnt;
    for (int i = 0; i < n; ++i) std::printf("%d\n", ans[i]);
    return 0;
}
2.动态逆序对
#include <algorithm>
#include <cstdio>
#include <cstring>

// Problem 2: dynamic inversion count.
// Input : n m, a sequence a[1..n], then m values to delete one after another.
// Output: the inversion count just before each deletion, one per line.
//
// Offline approach: reverse time so that deletions become insertions.
// Elements that are never deleted get time 0; the first deletion gets time m
// and the last gets time 1. Every inversion is charged to whichever of its two
// elements is inserted later. The three dimensions are handled as follows:
// time by the sort, position by CDQ, and value by a Fenwick tree.
// NOTE(review): values index a Fenwick tree of size n and the value->position
// table b[], so a[] is assumed to be a permutation of 1..n.

constexpr int maxn = 2000000;  // capacity of every static array below
using ll = long long;

int n, m, timer, a[maxn], b[maxn], d[maxn], t[maxn];
ll ans[maxn], c[maxn];  // ans[time] = inversions charged to that time; c[] = Fenwick tree

// One element of the sequence.
struct node
{
    int t, num, pl;  // reversed insertion time, value, position
    ll ans;          // inversions charged to this element in the current pass
} w[maxn];

// Reads one optionally negative decimal integer from stdin, skipping any
// leading non-digit characters. Returns 0 if EOF is hit before a digit.
int read()
{
    int x = 0, sign = 1;
    int ch = std::getchar();  // int, not char: keeps EOF distinguishable
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    return x * sign;
}

// Order by (insertion time, position).
bool cmp(const node &x, const node &y)
{
    if (x.t != y.t) return x.t < y.t;
    return x.pl < y.pl;
}

// Order by position, ascending.
bool cmp2(const node &x, const node &y)
{
    return x.pl < y.pl;
}

// Order by position, descending.
bool cmp3(const node &x, const node &y)
{
    return x.pl > y.pl;
}

// Lowest set bit of x (Fenwick tree step).
inline int lowbit(int x) { return x & -x; }

// Fenwick tree: adds num at index x (indices 1..n).
void add(int x, int num)
{
    for (int i = x; i <= n; i += lowbit(i))
        c[i] += num;
}

// Fenwick tree: prefix sum over indices 1..x.
ll query(int x)
{
    ll res = 0;
    for (int i = x; i > 0; i -= lowbit(i))
        res += c[i];
    return res;
}

// Pass 1 on w[l..r] (range sorted by cmp by the caller): for each right-half
// element, count earlier-inserted left-half elements that sit at a smaller
// position and hold a larger value.
void CDQ(int l, int r)
{
    if (l >= r) return;
    int mid = (l + r) >> 1;
    CDQ(l, mid);
    CDQ(mid + 1, r);
    std::sort(w + l, w + mid + 1, cmp2);
    std::sort(w + mid + 1, w + r + 1, cmp2);

    int i = l;
    for (int j = mid + 1; j <= r; ++j)
    {
        while (i <= mid && w[i].pl < w[j].pl) add(w[i++].num, 1);
        w[j].ans += query(n) - query(w[j].num);  // inserted values > w[j].num
    }
    for (int p = l; p < i; ++p) add(w[p].num, -1);  // leave the tree empty
}

// Pass 2 on w[l..r] (range sorted by cmp by the caller): for each right-half
// element, count earlier-inserted left-half elements that sit at a larger
// position and hold a smaller value.
void CDQ2(int l, int r)
{
    if (l >= r) return;
    int mid = (l + r) >> 1;
    CDQ2(l, mid);
    CDQ2(mid + 1, r);
    std::sort(w + l, w + mid + 1, cmp3);
    std::sort(w + mid + 1, w + r + 1, cmp3);

    int i = l;
    for (int j = mid + 1; j <= r; ++j)
    {
        while (i <= mid && w[i].pl > w[j].pl) add(w[i++].num, 1);
        w[j].ans += query(w[j].num);  // inserted values <= w[j].num
    }
    for (int p = l; p < i; ++p) add(w[p].num, -1);  // leave the tree empty
}

int main()
{
    n = read(), m = read();
    for (int i = 1; i <= n; ++i)
    {
        a[i] = read();
        b[a[i]] = i;  // position of value a[i]
    }

    // Reverse time: the first deletion gets time m and the last gets 1.
    // Never-deleted elements keep t = 0.
    timer = m;
    for (int i = 1; i <= m; ++i)
    {
        d[i] = read();
        t[b[d[i]]] = timer--;
    }

    // Pass 1: larger values before me that were inserted earlier.
    for (int i = 1; i <= n; ++i) w[i].t = t[i], w[i].num = a[i], w[i].pl = i;
    std::sort(w + 1, w + 1 + n, cmp);
    CDQ(1, n);
    for (int i = 1; i <= n; ++i) ans[w[i].t] += w[i].ans;

    // Pass 2: smaller values after me that were inserted earlier.
    for (int i = 1; i <= n; ++i) w[i].t = t[i], w[i].num = a[i], w[i].pl = i, w[i].ans = 0;
    std::sort(w + 1, w + 1 + n, cmp);
    std::memset(c, 0, sizeof(c));  // defensive: CDQ already leaves the tree empty
    CDQ2(1, n);
    for (int i = 1; i <= n; ++i) ans[w[i].t] += w[i].ans;

    // After the prefix sum, ans[i] is the inversion count once times 0..i are present.
    // That is the state just before the deletion that was assigned time i.
    for (int i = 1; i <= m; ++i) ans[i] += ans[i - 1];
    for (int i = m; i >= 1; --i) std::printf("%lld\n", ans[i]);
    return 0;
}