[bzoj2124]等差子序列——线段树+字符串哈希

时间:2023-03-09 20:06:02
[bzoj2124]等差子序列——线段树+字符串哈希

题目大意

给一个1到N的排列\(A_i\),询问是否存在\(p_i\),\(i>=3\),使得\(A_{p_1}, A_{p_2}, ... ,A_{p_len}\)是一个等差序列。

题解

显然,我们只需要找到\(P_1, P_2, P_3\),使得其为等差数列即可。

考察等差数列的定义,不难得出:

\[2*P_2 = P_1 + P_3
\]

考察每一个\(P_2\),如果有\(P_2 - d\)已经出现,\(P_2 + d\)没有出现,那么一定可以组成等差序列。

我们考虑使用线段树维护每一个数字是否出现。如果一个数满足要求,以这个数开头的正序串和逆序串一定是不同的。我们维护每一个区间的哈希值,比较即可。

由于维护时满足区间查询,单点修改,所以我们使用线段树。

这个题WA了八次,最后找root要了数据才发现自己哈希的时候使用的pow数组用了int,结果溢出了orz

代码

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#define ll long long
#define mod 1000000007
#define p 3
using namespace std;
const int maxn = 10005;
struct seg {
int l, r;
ll hash1, hash2;
} t[maxn * 4];
int n, T;
ll pow[maxn];
void update(int k, ll m) {
ll tmp = m >> 1;
t[k].hash1 = ((t[k << 1].hash1 * pow[tmp]) % mod + t[k << 1 | 1].hash1) % mod;
t[k].hash2 =
((t[k << 1 | 1].hash2 * pow[m - tmp]) % mod + t[k << 1].hash2) % mod;
}
void add(int k, int pos) {
int l = t[k].l, r = t[k].r, mid = (l + r) >> 1;
if (l == r) {
t[k].hash1 = t[k].hash2 = 1;
return;
}
if (pos <= mid)
add(k << 1, pos);
else
add(k << 1 | 1, pos);
update(k, r - l + 1);
}
ll query(int k, int x, int y, int opt) {
if (x > y)
return 0;
int l = t[k].l, r = t[k].r, mid = (l + r) >> 1;
if (x <= l && r <= y) {
if (opt == 1)
return t[k].hash1;
else
return t[k].hash2;
}
if (y <= mid)
return query(k << 1, x, y, opt);
else if (x > mid)
return query(k << 1 | 1,x , y, opt);
else {
if (opt == 1) {
return ((query(k << 1, x, mid, 1) * pow[y - mid]) % mod+
query(k << 1 | 1, mid + 1, y, 1) % mod) %
mod;
} else
return ((query(k << 1, x, mid, 2) +
query(k << 1 | 1, mid + 1, y, 2) * pow[mid - x + 1])%mod) %
mod;
}
}
void build(int k, int l, int r) {
t[k].l = l, t[k].r = r;
if (l == r) {
t[k].hash1 = t[k].hash2 = 0;
return;
}
int mid = (l + r) >> 1;
build(k << 1, l, mid);
build(k << 1 | 1, mid + 1, r);
update(k, r - l + 1);
}
int main() {
//freopen("sequence.in", "r", stdin);
//freopen("sequence.out", "w", stdout);
scanf("%d", &T);
//pow[0] = 1;
pow[1] = p;
for (int i = 2; i <= 10001; i++)
pow[i] = (pow[i - 1] * p) % mod;
while (T--) {
scanf("%d", &n);
build(1, 1, n);
//memset(t, 0, sizeof(t));
int flag = 0;
int a[maxn];
memset(a, 0, sizeof(a));
for (int i = 1; i <= n; i++)
scanf("%d", &a[i]);
for (int i = 1; i <= n; i++) {
int x = a[i];
ll len = min(x - 1, n - x);
ll tmp1 = query(1, x - len, x - 1, 1);
ll tmp2 = query(1, x + 1, x + len, 2);
if (tmp1 != tmp2) {
flag = 1;
break;
}
add(1, x);
}
if (flag)
printf("Y\n");
else
printf("N\n");
}
}