CF Manthan, Codefest 16 G. Yash And Trees 线段树+bitset

时间:2023-03-09 19:02:50
CF Manthan, Codefest 16 G. Yash And Trees 线段树+bitset

题目链接:http://codeforces.com/problemset/problem/633/G

大意是一棵树两种操作,第一种是某一节点子树所有值+v,第二种问子树中节点模m出现了多少种m以内的质数。

第一种操作非常熟悉了,把每个节点dfs过程中的pre和post做出来,对序列做线段树。维护取模也不是问题。第二种操作,可以利用bitset记录质数出现情况。所以整个线段树需要维护bitset的信息。

对于某一个bitset x,如果子树所有值需要加y,则x=(x<<y)|(x>>(m-y))

一开始写挂了几次,有一点没注意到,因为我bitset直接全都是1000,而不是m,所以上面式子左移会有问题,解决方法是做一个0到m每位都是1的全集,或者表示质数集的bitset上限做到m。

 #include <iostream>
#include <vector>
#include <algorithm>
#include <string>
#include <string.h>
#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <ctime>
#include <numeric>
#include <bitset>
#include <cassert> using namespace std;
const int N=;
int a[N];
vector<int>edge[N];
bitset<>bs[N<<];
bitset<>prime;
int mark[N<<];
int L[N],R[N];
int dfn=;
int m;
int id[N];
void dfs(int u,int f) {
L[u]=++dfn;
id[dfn]=u;
for (int i=;i<edge[u].size();i++) {
int v=edge[u][i];
if (v==f)
continue;
dfs(v,u);
}
R[u]=dfn;
}
void up(int rt) {
bs[rt]=(bs[rt<<]|bs[rt<<|]);
}
void add(int rt,int x) {
bs[rt]=(((bs[rt]<<x))|(bs[rt]>>(m-x)));
}
void down(int rt) {
if (mark[rt]) {
add(rt<<,mark[rt]);
add(rt<<|,mark[rt]);
mark[rt<<]=(mark[rt<<]+mark[rt])%m;
mark[rt<<|]=(mark[rt<<|]+mark[rt])%m;
mark[rt]=;
}
}
void build(int l,int r,int rt) {
mark[rt]=;
if (l==r) {
int x=a[id[l]];
bs[rt].set(x);
return;
}
int m=(l+r)>>;
build(l,m,rt<<);
build(m+,r,rt<<|);
up(rt);
} void upd(int L,int R,int x,int l,int r,int rt) {
if (L<=l&&r<=R) {
add(rt,x);
mark[rt]=(mark[rt]+x)%m;
return;
}
down(rt);
int m=(l+r)>>;
if (L<=m)
upd(L,R,x,l,m,rt<<);
if (R>m)
upd(L,R,x,m+,r,rt<<|);
up(rt);
}
bitset<> ans;
void ask(int L,int R,int l,int r,int rt) {
if (L<=l&&r<=R) {
ans|=bs[rt];
return;
}
down(rt);
int m=(l+r)>>;
if (L<=m)
ask(L,R,l,m,rt<<);
if (R>m)
ask(L,R,m+,r,rt<<|);
} int main () {
int n;
scanf("%d %d",&n,&m);
for (int i=;i<m;i++) {
bool isp=true;
for (int j=;isp&&j*j<=i;j++) {
if (i%j==)
isp=false;
}
if (isp){
prime.set(i);
}
}
for (int i=;i<=n;i++){
scanf("%d",a+i);
a[i]%=m;
}
for (int i=;i<n;i++) {
int u,v;
scanf("%d %d",&u,&v);
edge[u].push_back(v);
edge[v].push_back(u);
}
dfn=;
dfs(,-);
build(,n,);
int Q;
scanf("%d",&Q);
while (Q--) {
int op;
scanf("%d",&op);
if (op==) {
int u,x;
scanf("%d %d",&u,&x);
x%=m;
upd(L[u],R[u],x,,n,);
}
else {
int u;
scanf("%d",&u);
ans=;
ask(L[u],R[u],,n,);
ans&=prime;
int ret=ans.count();
printf("%d\n",ret);
}
}
return ;
}