P4100-[HEOI2013]钙铁锌硒维生素【矩阵求逆,最大匹配】

时间:2023-03-10 07:26:53
P4100-[HEOI2013]钙铁锌硒维生素【矩阵求逆,最大匹配】

正题

题目链接:https://www.luogu.com.cn/problem/P4100


题目大意

给出\(n\)个线性无关的向量\(A_i\),然后给出\(n\)个向量\(B_i\),求一个字典序最小的排列\(p\)使得将任意的\(A_i\)替换为\(B_{p_i}\)后依旧线性无关。

\(1\leq n\leq 300\)


解题思路

首先因为我们有\(n\)个向量\(A\)线性无关,那么显然这\(n\)个向量能表示任意向量,如果对于一个\(B_{p_i}\)替换为\(A_i\)后依旧线性无关,那么\(B_{p_i}\)与\(A_i\)是等价的(因为\(B_{p_i}\)和\(A_i\)都代表了剩下\(n-1\)个无法表示的部分)。

所以只需考虑每个\(B_j\)能否换到\(A_i\)即可,构建出矩阵\(A=[A_1,A_2...A_n]\)和\(B=[B_1,B_2...B_n]\),考虑一个置换矩阵使得\(AR=B\),那么就是对于每个\(B\)如何用\(A\)进行表示。

那么如果\(R_{i,j}=0\)也就是说\(B\)可以用\(A_j\)以外的其他\(A\)表示所以\(B\)替换到\(A_j\)之后肯定线性有关了,所以不行。

\(R=\frac{B}{A}\),求逆得到\(R\),这样我们就知道哪些\(A\)可以替换哪些\(B\)了,问题就变成了最小字典序匹配。对于这个问题我们可以考虑找一条增广环就好了。

时间复杂度\(O(n^3)\)


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
const int N=310;
const double eps=1e-8;
int n,v[N],link[N];
double A[N][N],B[N][N];
bool GetInv(){
for(int i=1;i<=n;i++){
int z=i;
for(int j=i+1;j<=n;j++)
if(fabs(A[j][i])>fabs(A[z][i]))z=j;
swap(A[i],A[z]);swap(B[i],B[z]);
double x=A[i][i];
if(fabs(x)<eps)return 0;
for(int j=1;j<=n;j++)
A[i][j]/=x,B[i][j]/=x;
for(int j=1;j<=n;j++){
if(i==j)continue;
double rate=-A[j][i];
for(int k=1;k<=n;k++)
A[j][k]+=rate*A[i][k],
B[j][k]+=rate*B[i][k];
}
}
for(int i=n;i>=1;i--)
for(int j=1;j<i;j++){
double rate=-A[j][i];
for(int k=1;k<=n;k++)
A[j][k]+=rate*A[i][k],B[j][k]+=rate*B[i][k];
}
return 1;
}
bool dfs(int x){
for(int i=1;i<=n;i++)
if(!v[i]&&fabs(B[x][i])>=eps){
v[i]=1;
if(!link[i]||dfs(link[i])){
link[i]=x;
return 1;
}
}
return 0;
}
int calc(int x,int top){
for(int i=1;i<=n;i++)
if(!v[i]&&fabs(B[x][i])>=eps){
v[i]=1;
if(link[i]==top||(link[i]>top&&calc(link[i],top))){
link[i]=x;
return i;
}
}
return 0;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
scanf("%lf",&A[j][i]);
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
scanf("%lf",&B[j][i]);
if(!GetInv())return puts("NIE")&0;
for(int i=1;i<=n;i++){
memset(v,0,sizeof(v));
if(!dfs(i))return puts("NIE")&0;
}
puts("TAK");
for(int i=1;i<=n;i++){
memset(v,0,sizeof(v));
printf("%d\n",calc(i,i));
}
return 0;
}