快速沃尔什变换(FWT)学习笔记

时间:2023-03-09 18:35:59
快速沃尔什变换(FWT)学习笔记

概述

FWT的大体思路就是把要求的 C(x)=A(x)×B(x)  即 \( c[i]=\sum\limits_{j?k=i} (a[j]*b[k]) \) 变换成这样的:\( c^{'}[i]=a^{'}[i]*b^{'}[i] \)。

只要知道 c'[ i ] 和 c[ i ] 的关系,就能把 A(x)、B(x) 变成 A'(x)、B'(x) ,从而算出 C'(x) ,再把 C'(x) 变成 C(x)。

或卷积

定义\( c^{'}[i]=\sum\limits_{j | i=i} c[j] \),则\( c^{'}[i]=a^{'}[i]*b^{'}[i] \)

  证明:\( c^{'}[i]=\sum\limits_{j | i=i} c[j] \)

     因为\( c[i]=\sum\limits_{j|k=i}a[j]*b[k] \)

     所以 \( c^{'}[i]=\sum\limits_{(j|k)|i=i}a[j]*b[k] \)

            \( =\sum\limits_{(j|k)|i=i}a[j]*b[k]=\sum\limits_{j|i=i}a[j] * \sum\limits_{k|i=i}b[k] \)

     又有:\( a^{'}[i]=\sum\limits_{j|i=i}a[j] \)  \( b^{'}[i]=\sum\limits_{j|i=i}b[j] \)

     所以\( c^{'}[i]=a^{'}[i]*b^{'}[i] \)

接下来考虑怎么把 A(x) 变成 A'(x) 。

考虑按位来做,比如从低位到高位枚举,则每一部分左边一半的该位全是0、右边一半的该位全是1;记左边为A0,右边为A1。

如:00001111

  00110011

  01010101

  如果已经算好了A0和A1,考虑用它们求出A。比如算好了第 1~2 位置的值A0和第 3~4 位置的值A1,想求第 1~4 位置的值A;(那么现在是枚举到了二进制第二位了)

  此时的A0里没有A1位置对它的贡献,A1里也没有A0位置对它的贡献;考虑两部分位置的值怎样互相贡献;

  考虑左边和右边的对应位置,它们只有最高位一个是0一个是1的不同;则是左边对应位置的子集的位置一定也是右边对应位置的子集,可以这样做:A0' = A0,A1' = A0+A1

  所以模仿FFT的框架写一个就行了。(感觉这里求 \( \sum\limits_{j|i=i}a[j] \) 的思路和高维前缀和很像)

  但不用弄那个 r[ ] 来换位置(因为不是弄偶数项和奇数项,而是真的前半部分和后半部分);不过就算换了位置也可以!

    因为那样换一下位置相当于是每个位置的角标被翻转了,比如上面那8个位置的角标会变成:

    01010101

    00110011

    00001111

    这样的话,自己“从低位到高位枚举”可以看作从高位到低位枚举,一切就没问题了。主要是因为位运算每一位是独立的嘛。

它的逆变换是这样想:因为 A0' = A0,A1' = A0+A1;所以 A0 = A0',A1 = A1'-A0 = A1'-A0'。刚才是从低位到高位枚举的话,现在要从高位到低位枚举。

  但其实还是从低位到高位枚举也是对的!

  考虑一个位置k,它加上的那些 “对应位置” j 的特点是 j 只和 k 有一位不同。比如从低到高枚举到第3位的时候 k 位置的值加上了 j 位置的值,说明二进制第3位上 j 是0、k是1,第3位之前 j 和 k 一样(因为“对应”嘛),而第3位之后 j 和 k 其实也一样(因为第3位之后 j 和 k 就变成“一块"里的了,再高的位会一起变成0或1之类的);

  从 A'(x) 变回 A(x) 的过程中,比如第一步的时候,每个 a[ i ] 都记录着所有 角标是 i 子集的a的权值和 ;

  从低到高枚举到第一个 k 是1的位置(除了最低位),比如是第3位,则此时 a'[ k ] - a'[ j ] 减去的值是 “角标第3位是0、其余部分是 k 的子集” 的那些位置的值;剩下的就是 “角标第3位是1、其余部分是 k 的子集” 的值。

  接下来枚举到下一个 k 是1的位置,比如是第5位;因为 j 的其它位上的值都和 k 一样,所以此时 j 也是经历过第3位时的一番操作;则此时 a'[ k ] - a'[ j ] 减去的值是“角标第3位是1、第5位是0、其余部分是 k 的子集”的那些位置的值;则 a'[ k ] 剩下的值是 “角标第3位是1、第5位是1、其余部分是 k 的子集” 的位置的值;

  这样一直枚举到最后,剩下的就是 “角标在 k 是1的位上是1、其余位上是 k 的自己” 位置的值,即只剩正好的 a[ k ] 了,于是此时 a'[ k ] = a[ k ] 。

与卷积

和或卷积一样。变换:A0'=A0+A1,A1'=A1  逆变换:A0 = A0'-A1 = A0'-A1',A1=A1'

异或卷积

定义 \( c^{'}[i]=\sum\limits_{j \& i有偶数个1} c[j] - \sum\limits_{j \& i有奇数个1} c[j] \)

考虑证明 \( c^{'}[i]=a^{'}[i]*b^{'}[i] \)

  证明:因为 \( c[i]=\sum\limits_{j \otimes k=i} a[j]*b[k] \)

     所以 \( c^{'}[i]=\sum\limits_{ (j \otimes k)与 i 有偶数个1重合 } a[j]*b[k] - \sum\limits_{ (j \otimes k)与 i 有奇数个1重合 } a[j]*b[k] \)

     又 \( a^{'}[i]*b^{'}[i] = ( \sum\limits_{j \& i有偶数个1}a[j] - \sum\limits_{j \& i有奇数个1}a[j] ) * ( \sum\limits_{j \& i有偶数个1}b[j] - \sum\limits_{j \& i有奇数个1}b[j] ) \)

              \( = \sum\limits_{j \& i有偶数个1}a[j]*b[j] + \sum\limits_{j \& i有奇数个1}a[j]*b[j] - \sum\limits_{j \& i有偶数个1,k \& i有奇数个1}a[j]*b[k] - \sum\limits_{j \& i有奇数个1,k \& i有偶数个1}a[j]*b[k] \)

              \( = \sum\limits_{j \& i与k \& i的1的个数奇偶性相同}a[j]*b[k] - \sum\limits_{j \& i与k \& i的1的个数奇偶性不同}a[j]*b[k] \)

              \( = \sum\limits_{(j \otimes k)与 i 有偶数个1重合}a[j]*b[k] - \sum\limits_{(j \otimes k)与 i 有奇数个1重合}a[j]*b[k] \)

     (这一步等价是因为异或的时候,如果 j 和 k 有公共位置的1,那么一次会消掉2个1;所以 ( (j&i)的1的个数 + (k&i)的1的个数 ) 在 j 和 k 异或之后奇偶性不会变)

     所以 \( c^{'}[i]=a^{'}[i]*b^{'}[i] \)

接下来考虑怎么把A(x)变成A'(x)。

  还是有前一半的A0和后一半的A1。对应位置 & 起来之后,那个最高位还是0;

  所以对于A0里的一个a[ i ]来说,记和它 & 起来的那些位置 j (其实 j 遍历了所有A0里的位置)在A1里的对应位置为 j' ,则 j & i == j' & i;所以A0'=A0+A1;

  而对于A1里的一个a[ i ]来说,算A1的时候A1的标号的最高位还没被考虑(即视作0),合并的时候A1的最高位变成1了;设 i 在A1的 & 起来的那些位置 j 在A0里的对应位置为 j',则 j & i 比 j' & i 多了一个1(最高位即当前枚举到的位),所以当 i 和A1里的 j 匹配时,单独算A1时算好的 a[ i ] = sigma - sigma 里的两个 sigma 的位置换了一下,也就是符号变了;所以A1'=A0-A1。

它的逆变换就是:A0=(A0'+A1')/2,A1=(A0'-A1')/2。

关于实现方法的讨论就和或卷积一样。

模板

洛谷4717 【模板】快速沃尔什变换

题目:https://www.luogu.org/problemnew/show/P4717

不知为何跑得很慢。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const int N=(<<)+,mod=;
int n,a[][N],b[][N],c[][N],len,r[N],inv;
int rdn()
{
int ret=;bool fx=;char ch=getchar();
while(ch>''||ch<''){if(ch=='-')fx=;ch=getchar();}
while(ch>=''&&ch<='') ret=(ret<<)+(ret<<)+ch-'',ch=getchar();
return fx?ret:-ret;
}
int g[];
void wrt(int x)
{
if(x<)putchar('-'),x=-x;
if(!x){printf("0 ");return;}
int t=;while(x)g[++t]=x%,x/=;
while(t)putchar(g[t]+''),t--;putchar(' ');
}
void upd(int &x){x>=mod?x-=mod:;}
void fwt0(int *a,bool fx)
for(int R=;R<=len;R<<=)
for(int i=,m=R>>;i<len;i+=R)
for(int j=;j<m;j++)
(fx?a[i+m+j]+=mod-a[i+j]:a[i+m+j]+=a[i+j]),upd(a[i+m+j]);
}
void fwt1(int *a,bool fx)
for(int R=;R<=len;R<<=)
for(int i=,m=R>>;i<len;i+=R)
for(int j=;j<m;j++)
(fx?a[i+j]+=mod-a[i+m+j]:a[i+j]+=a[i+m+j]),upd(a[i+j]);
}
void fwt2(int *a,bool fx)
for(int R=;R<=len;R<<=)
{
for(int i=,m=R>>;i<len;i+=R)
for(int j=;j<m;j++)
{
int x=a[i+j]+a[i+m+j],y=a[i+j]+mod-a[i+m+j];
upd(x); upd(y);
fx?(x=(ll)x*inv%mod,y=(ll)y*inv%mod):;
a[i+j]=x; a[i+m+j]=y;
}
}
}
int main()
{
n=(<<rdn());
for(int i=;i<n;i++)a[][i]=a[][i]=a[][i]=rdn();
for(int i=;i<n;i++)b[][i]=b[][i]=b[][i]=rdn();
len=n<<;
for(int i=;i<len;i++)r[i]=(r[i>>]>>)+((i&)?len>>:);
int k=mod-,tmp=;inv=;
while(k){if(k&)inv=(ll)inv*tmp%mod;tmp=(ll)tmp*tmp%mod;k>>=;}
fwt0(a[],); fwt0(b[],); fwt1(a[],); fwt1(b[],); fwt2(a[],); fwt2(b[],);
for(int t=;t<;t++)
for(int i=;i<len;i++)
c[t][i]=(ll)a[t][i]*b[t][i]%mod;
fwt0(c[],); fwt1(c[],); fwt2(c[],);
for(int t=;t<;t++,puts(""))
for(int i=;i<n;i++)wrt(c[t][i]);
return ;
}