【BZOJ4382】[POI2015]Podział naszyjnika 堆+并查集+树状数组

时间:2023-03-09 06:51:18
【BZOJ4382】[POI2015]Podział naszyjnika 堆+并查集+树状数组

【BZOJ4382】[POI2015]Podział naszyjnika

Description

长度为n的一串项链,每颗珠子是k种颜色之一。 第i颗与第i-1,i+1颗珠子相邻,第n颗与第1颗也相邻。
切两刀,把项链断成两条链。要求每种颜色的珠子只能出现在其中一条链中。
求方案数量(保证至少存在一种),以及切成的两段长度之差绝对值的最小值。

Input

第一行n,k(2<=k<=n<=1000000)。颜色从1到k标号。
接下来n个数,按顺序表示每颗珠子的颜色。(保证k种颜色各出现至少一次)。

Output

一行两个整数:方案数量,和长度差的最小值

Sample Input

9 5
2 5 3 2 2 4 1 1 3

Sample Output

4 3

HINT

四种方法中较短的一条分别是(5),(4),(1,1),(4,1,1)。相差最小值6-3=3。

题解:hash那么巧妙的做法我怎么想得到啊~我只会无脑的数据结构。

防止重复,我们不倍长原序列,然后枚举一条切割线r,只考虑另一条切割线l在这条左边的情况。那么对于每种颜色,它只能有一下两种存在方式。

1.【BZOJ4382】[POI2015]Podział naszyjnika 堆+并查集+树状数组2.【BZOJ4382】[POI2015]Podział naszyjnika 堆+并查集+树状数组

对于第一种情况,我们可以对每个点维护上一个与它颜色相同的位置pre,然后只需要满足pre<=l即可。可以用堆维护pre的最大值。

对于第二种情况,我们已经枚举到了这个颜色最右面的点,现在只需要将这个颜色最左端和最右端中间的点全部删除。用并查集维护,并用树状数组统计区间中已经被删除的点的个数即可。

于是方案数量我们很容易就能求出来了。那么长度差的最小值怎么办?我们对于右端点r,肯定是希望找到离r-n/2最近的合法的l。可以用并查集找到每个点左面和右面第一个没被删除的点,判断一下就行。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
#include <queue>
using namespace std;
const int maxn=1000010;
typedef long long ll;
int n,m,ans;
ll sum;
int v[maxn],pos[maxn],f[maxn],s[maxn],siz[maxn];
vector<int> p[maxn];
struct heap
{
priority_queue<int> qa,qb;
inline void push(int x) {qa.push(x);}
inline void erase(int x) {qb.push(x);}
inline int top()
{
while(qb.size()&&qa.top()==qb.top()) qa.pop(),qb.pop();
return qa.size()?qa.top():0;
}
}q;
inline int rd()
{
int ret=0,f=1; char gc=getchar();
while(gc<'0'||gc>'9') {if(gc=='-') f=-f; gc=getchar();}
while(gc>='0'&&gc<='9') ret=ret*10+gc-'0',gc=getchar();
return ret*f;
}
int find(int x)
{
return (f[x]==x)?x:(f[x]=find(f[x]));
}
inline int abs(int x)
{
return x>0?x:-x;
}
inline void updata(int x)
{
for(int i=x;i<=n;i+=i&-i) s[i]++;
}
inline int query(int x)
{
if(x==-1) return 0;
int i,ret=0;
for(i=x;i;i-=i&-i) ret+=s[i];
return ret;
}
int main()
{
n=rd(),m=rd();
int i,j,k;
for(i=1;i<=n;i++)
{
v[i]=rd(),p[v[i]].push_back(i),pos[i]=p[v[i]].size()-1,f[i]=i,siz[i]=1;
}
f[0]=1,f[n+1]=n+1,siz[n+1]=1;
ans=n;
for(i=1;i<n;i++)
{
if(pos[i]) q.erase(p[v[i]][pos[i]-1]);
if(pos[i]==(int)p[v[i]].size()-1)
{
for(j=find(p[v[i]][0]);j<i;j=f[j]) updata(j),siz[find(j+1)]+=siz[j],f[j]=f[j+1];
}
else q.push(i);
k=q.top();
sum+=i-k-(query(i-1)-query(k-1));
j=find(max(k,i-n/2));
if(j<i) ans=min(ans,abs(n-2*(i-j)));
j-=siz[j];
if(j>=k) ans=min(ans,abs(n-2*(i-j)));
}
printf("%lld %d",sum,ans);
return 0;
}