题意:
给出一个有n个数的数列,并定义mex(l, r)表示数列中第l个元素到第r个元素中第一个没有出现的最小非负整数。
求出这个数列中所有mex的值。
思路:
可以看出对于一个数列,mex(r, r~l)是一个递增序列
mex(0, 0~n-1)是很好求的,只需要遍历找出第一个没有出现的最小非负整数就好了。这里有一个小技巧:
tmp = ;
for (int i = ; i <= n; ++i) {
mp[arr[i]] = ;
while (mp.find(tmp) != mp.end()) tmp++;
mex[i] = tmp;
}
这样可以利用map中的红黑树很快找到第一个没有出现的最小非负整数。
然后在求mex(1~n-1, 0~n-1)的过程中,我们可以看出,每消除当前值arr[i],会影响到的是在下一个arr[i]出现前 往后的mex值中比arr[i]大的值,即如果当前这个值不存在了,那么在这个值下一次出现前,mex值比当前值大的mex值都应被替换成arr[i]。
所以我们可以再一次利用map的红黑树找到当前值下一次出现的位置,然后利用线段树成段更新往后的mex值和求出会影响到的mex值的个数。
for (int i = n; i >= ; --i) {
if (mp.find(arr[i]) == mp.end()) next[i] = n+;
else next[i] = mp[arr[i]];
mp[arr[i]] = i;
}
这里我们还需要利用线段树求出第一个比当前arr[i]大的mex值的位置,以便成段更新区间的mex值。
Tips:
※ 这里有一个小小优化的地方,就是当更新的时候,可以先查看mx[1]是否比当前值大,如果是,则表示往后的区间里有比当前值大的mex值,则需要线段树是需要更新的,否则不用更新。
※ 还有一个要注意的地方是:只有求出的左边界比右边界小的时候才能更新。
Code:
#include <stdio.h>
#include <cstring>
#include <map>
#include <algorithm>
using namespace std; const int MAXN = ;
long long sum[MAXN<<];
int mx[MAXN<<], arr[MAXN], next[MAXN], mex[MAXN];
int lazy[MAXN<<]; void Pushup(int rt)
{
sum[rt] = sum[rt<<]+sum[rt<<|];
mx[rt] = max(mx[rt<<], mx[rt<<|]);
} void Pushdown(int rt, int x)
{
if (lazy[rt] != -) {
lazy[rt<<] = lazy[rt<<|] = lazy[rt];
sum[rt<<] = (x-x/)*lazy[rt];
sum[rt<<|] = x/*lazy[rt];
mx[rt<<] = mx[rt<<|] = lazy[rt];
lazy[rt] = -;
}
} void Creat(int l, int r, int rt)
{
lazy[rt] = -;
if (l == r) {
sum[rt] = mx[rt] = mex[l];
return;
}
int mid = (l+r)/;
Creat(l, mid, rt<<);
Creat(mid+, r, rt<<|);
Pushup(rt);
} void Modify(int l, int r, int x, int L, int R, int rt)
{
if (l <= L && r >= R) {
lazy[rt] = x;
sum[rt] = x*(R-L+);
mx[rt] = x;
return;
}
Pushdown(rt, R-L+);
int mid = (L+R)/;
if (l <= mid) Modify(l, r, x, L, mid, rt<<);
if (r > mid) Modify(l, r, x, mid+, R, rt<<|);
Pushup(rt);
} int Get(int rt, int l, int r, int x)
{
if(l == r) return l;
Pushdown(rt, r-l+);
int mid = (l+r)/;
if (mx[rt<<] > x) return Get(rt<<, l, mid, x);
else return Get(rt<<|, mid+, r, x);
} int main()
{
//freopen("in.txt", "r", stdin);
int n, tmp;
long long ans_sum;
map<int, int> mp;
while (~scanf("%d", &n)) {
if (n == ) break;
ans_sum = tmp = ;
mp.clear();
memset(sum, , sizeof(sum));
memset(next, , sizeof(next)); for (int i = ; i <= n; ++i)
scanf("%d", &arr[i]);
for (int i = ; i <= n; ++i) {
mp[arr[i]] = ;
while (mp.find(tmp) != mp.end()) tmp++;
mex[i] = tmp;
} Creat(, n, );
mp.clear();
for (int i = n; i >= ; --i) {
if (mp.find(arr[i]) == mp.end()) next[i] = n+;
else next[i] = mp[arr[i]];
mp[arr[i]] = i;
} for (int i = ; i <= n; ++i) {
ans_sum += sum[];
if (mx[] > arr[i]) {
int l = Get(, , n, arr[i]);
int r = next[i];
// printf("%d %d %d\n", l, r, sum[1]);
if (l < r) Modify(l, r-, arr[i], , n, );
} Modify(i, i, , , n, );
}
printf("%I64d\n", ans_sum);
}
return ;
}