稀疏矩阵 part 2

时间:2023-03-09 08:53:16
稀疏矩阵 part 2

▶ 各种稀疏矩阵数据结构之间的转化

● MAT ←→ CSR

 // ---- MAT <-> CSR ---------------------------------------------------------

/*
 * MATToCSR: convert a dense, row-major MAT into CSR (compressed sparse row).
 *
 * in  : dense matrix; in->data holds in->row * in->col elements, and
 *       in->count is used as the number of non-zeros to allocate for.
 * out : CSR whose data/index hold the non-zero values and their column
 *       numbers in row-major order; ptr[r] .. ptr[r + 1] delimits row r
 *       (ptr is written at in->row + 1 positions).
 *
 * Returns NULL if more than in->count non-zeros are found.
 * NOTE(review): on that path `out` is not released — no CSR destructor is
 * visible in this section; confirm its name and free `out` there.
 */
CSR * MATToCSR(const MAT *in)
{
    checkNULL(in);
    CSR *out = initializeCSR(in->row, in->col, in->count);
    checkNULL(out);
    out->ptr[0] = 0;                                            // row 0 starts at offset 0
    for (int i = 0, j = 0, k = 1; i < in->row * in->col; i++)   // i walks in->data, j counts stored non-zeros, k is the next ptr slot
    {
        if (in->data[i] != 0)                                   // found a non-zero
        {
            if (j == in->count)                                 // more non-zeros than in->count announced
                return NULL;
            out->data[j] = in->data[i];                         // value
            out->index[j] = i % in->col;                        // its column
            j++;
        }
        if ((i + 1) % in->col == 0)                             // end of a row: record where the next row starts
            out->ptr[k++] = j;
    }
    return out;
}

/*
 * CSRToMAT: expand a CSR matrix back into a dense, row-major MAT.
 * The non-zero count handed to initializeMAT is in->ptr[in->row], the end
 * offset of the last row.
 */
MAT * CSRToMAT(const CSR *in)
{
    checkNULL(in);
    MAT *out = initializeMAT(in->row, in->col, in->ptr[in->row]);
    checkNULL(out);
    memset(out->data, 0, sizeof(format) * in->row * in->col);   // positions not stored in the CSR are zero
    for (int i = 0; i < in->row; i++)                             // i walks the rows
    {
        for (int j = in->ptr[i]; j < in->ptr[i + 1]; j++)         // j walks the stored entries of row i
            out->data[i * in->col + in->index[j]] = in->data[j];
    }
    return out;
}

● MAT ←→ ELL

 ELL * MATToELL(const MAT *in)// MAT转ELL
{
checkNULL(in); int i, j, maxElement;
for (i = j = maxElement = ; i < in->row * in->col; i++) // i 遍历 in->data,j 记录该行非零元素数,maxElement 记录一行非零元素最大值
{
if (in->data[i] != ) // 找到非零元
j++;
if ((i + ) % in->col == ) // 行末,更新 maxElement
{
maxElement = MAX(j, maxElement);
j = ; // 开始下一行之前清空 j
}
}
format* temp_data=(format *)malloc(sizeof(format) * in->row * maxElement); // 临时数组,将列数压缩到 maxElement
checkNULL(temp_data);
int* temp_index = (int *)malloc(sizeof(int) * in->row * maxElement);
checkNULL(temp_index);
memset(temp_data, , sizeof(format) * in->row * maxElement);
memset(temp_index, , sizeof(int) * in->row * maxElement);
for (i = j = ; i < in->row * in->col; i++) // i 遍历 in->data,j 记录该行非零元素数,把 in 中每行的元素往左边推
{
if (in->data[i] != ) // 找到非零元
{
temp_data[i / in->col * maxElement + j] = in->data[i]; // 存放元素
temp_index[i / in->col * maxElement + j] = i % in->col; // 记录所在的列号
j++;
}
if ((i + ) % in->col == ) // 行末,将剩余位置的下标记作 -1,即无效元素
{
for (j += i / in->col * in->col; j < maxElement * (i / in->col + ); j++) // 使得 j 指向本行最后一个非零元素之后的元素,再开始填充
temp_index[j] = -;
j = ; // 开始下一行之前清空 j
}
}
ELL *out = initializeELL(maxElement, in->row, in->col); // 最终输出,如果不转置的话不要这部分
checkNULL(out);
for (i = ; i < out->row * out->col; i++) // 将 temp_data 和 temp_index 转置以提高缓存利用
{
out->data[i] = temp_data[i % out->col * out->row + i / out->col];
out->index[i] = temp_index[i % out->col * out->row + i / out->col];
}
free(temp_data);
free(temp_index);
return out;
} MAT * ELLToMAT(const ELL *in) // ELL转MAT
{
checkNULL(in);
MAT *out = initializeMAT(in->col, in->colOrigin);
checkNULL(out); for (int i = ; i < in->row * in->col; i++) // i 遍历 out->data
{
if (in->index[i] < ) // 注意跳过无效元素
continue;
out->data[i % in->col * in->colOrigin + in->index[i]] = in->data[i];
}
COUNT_MAT(out);
return out;
}

● MAT ←→ COO

 COO * MATToCOO(const MAT *in)                               // MAT转COO
{
checkNULL(in);
COO *out = initializeCOO(in->row, in->col, in->count); for (int i=, j = ; i < in->row * in->col; i++)
{
if (in->data[i] != )
{
out->data[j] = in->data[i];
out->rowIndex[j] = i / in->col;
out->colIndex[j] = i % in->col;
j++;
}
}
return out;
} MAT * COOToMAT(const COO *in) // COO转MAT
{
checkNULL(in);
MAT *out = initializeMAT(in->row, in->col, in->count);
checkNULL(out); for (int i = ; i < in->row * in->col; i++)
out->data[i] = ;
for (int i = ; i < in->count; i++)
out->data[in->rowIndex[i] * in->col + in->colIndex[i]] = in->data[i];
return out;
}

● MAT ←→ DIA

 DIA * MATToDIA(const MAT *in)                                       // MAT转DIA
{
checkNULL(in); int *index = (int *)malloc(sizeof(int)*(in->row + in->col - ));
for (int diff = in->row - ; diff > ; diff--) // 左侧零对角线情况
{
int flagNonZero = ;
for (int i = ; i < in->col && i + diff < in->row; i++) // i 沿着对角线方向遍历 in->data,flagNonZero 记录该对角线是否全部为零元
{
#ifdef INT
if (in->data[(i + diff) * in->col + i] != )
#else
if (fabs(in->data[(i + diff) * in->col + i]) > EPSILON)
#endif
flagNonZero = ;
}
index[in->row - - diff] = flagNonZero; // 标记该对角线上有非零元
}
for (int diff = in->col - ; diff >= ; diff--) // 右侧零对角线情况
{
int flagNonZero = ;
for (int j = ; j < in->row && j + diff < in->col; j++)
{
#ifdef INT
if (in->data[j * in->col + j + diff] != )
#else
if (fabs(in->data[j * in->col + j + diff]) > EPSILON)
#endif
flagNonZero = ;
}
index[in->row - + diff] = flagNonZero; // 标记该对角线上有非零元
}
int *prefixSumIndex = (int *)malloc(sizeof(int)*(in->row + in->col - ));
prefixSumIndex[] = index[];
for (int i = ; i < in->row + in->col - ; i++) // 闭前缀和,prefixSumIndex[i] 表示原矩阵第 0 ~ i 条对角线*有多少条非零对角线(含)
prefixSumIndex[i] = prefixSumIndex[i-] + index[i]; // index[in->row + in->col -2] 表示原矩阵非零对角线条数,等于 DIA 矩阵列数
DIA *out = initializeDIA(in->row, prefixSumIndex[in->row + in->col - ], in->col);
checkNULL(out); memset(out->data, , sizeof(int)*out->row * out->col);
for (int i = ; i < in->row + in->col - ; i++)
out->index[i] = index[i]; // index 搬进 out
for (int i = ; i < in->row; i++) // i,j 遍历原矩阵,将元素搬进 out
{
for (int j = ; j < in->col; j++)
{
int temp = j - i + in->row - ;
if (index[temp] == )
continue;
out->data[i * out->col + (temp > ? prefixSumIndex[temp - ] : )] = in->data[i * in->col + j]; // 第 row - 1 行第 0 列元素 temp == 0,单独处理
}
}
free(index);
free(prefixSumIndex);
return out;
} MAT * DIAToMAT(const DIA *in) // DIA转MAT
{
checkNULL(in);
MAT *out = initializeMAT(in->row, in->colOrigin);
checkNULL(out); int * inverseIndex = (int *)malloc(sizeof(int) * in->col);
for (int i = , j = ; i < in->row + in->col - ; i++) // 求一个 index 的逆,即 DIA 中第 i 列对应原矩阵第 inverseIndex[i] 对角线
{ // 原矩阵对角线编号 (row-1, 0) 为第 0 条,(0, 0) 为第 row - 1 条,(col-1, 0) 为第 row + col - 2 条
if (in->index[i] == )
{
inverseIndex[j] = i;
j++;
}
}
for (int i = ; i < in->row; i++) // i 遍历 in->data 行,j 遍历 in->data 列
{
for (int j = ; j < in->col; j++)
{
if (i < in->row - - inverseIndex[j] || i > inverseIndex[in->col - ] - inverseIndex[j]) // 跳过两边呈三角形的无效元素
continue;
out->data[i * in->col + inverseIndex[j] - in->row + ] = in->data[i * in->col + j]; // 利用 inverseIndex 来找钙元素在原距震中的位置
}
}
free(inverseIndex);
return out;
}