更多 CSP 认证考试题目题解可以前往:CSP-CCF 认证考试真题题解
原题链接: 202305-2 矩阵运算
时间限制: 5.0s
内存限制: 512.0MB
题目背景
S o f t m a x ( Q × K T d ) × V \mathrm{Softmax}(\frac{\mathbf{Q} \times \mathbf{K}^{T}}{\sqrt{d}}) \times \mathbf{V} Softmax(dQ×KT)×V 是 Transformer 中注意力模块的核心算式,其中 Q \mathbf{Q} Q、 K \mathbf{K} K 和 V \mathbf{V} V 均是 n n n 行 d d d 列的矩阵, K T \mathbf{K}^{T} KT 表示矩阵 K \mathbf{K} K 的转置, × \times × 表示矩阵乘法。
问题描述
为了方便计算,顿顿同学将
S
o
f
t
m
a
x
\mathrm{Softmax}
Softmax 简化为了点乘一个大小为
n
n
n 的一维向量
W
\mathbf{W}
W:
(
W
⋅
(
Q
×
K
T
)
)
×
V
\left(\mathbf{W} \cdot (\mathbf{Q} \times \mathbf{K}^{T})\right) \times \mathbf{V}
(W⋅(Q×KT))×V
点乘即对应位相乘,记
W
(
i
)
\mathbf{W}^{(i)}
W(i) 为向量
W
\mathbf{W}
W 的第
i
i
i 个元素,即将
(
Q
×
K
T
)
(\mathbf{Q} \times \mathbf{K}^{T})
(Q×KT) 第
i
i
i 行中的每个元素都与
W
(
i
)
\mathbf{W}^{(i)}
W(i) 相乘。
现给出矩阵 Q \mathbf{Q} Q、 K \mathbf{K} K 和 V \mathbf{V} V 和向量 W \mathbf{W} W,试计算顿顿按简化的算式计算的结果。
输入格式
从标准输入读入数据。
输入的第一行包含空格分隔的两个正整数 n n n 和 d d d,表示矩阵的大小。
接下来依次输入矩阵 Q \mathbf{Q} Q、 K \mathbf{K} K 和 V \mathbf{V} V。每个矩阵输入 n n n 行,每行包含空格分隔的 d d d 个整数,其中第 i i i 行的第 j j j 个数对应矩阵的第 i i i 行、第 j j j 列。
最后一行输入 n n n 个整数,表示向量 W \mathbf{W} W。
输出格式
输出到标准输出中。
输出共 n n n 行,每行包含空格分隔的 d d d 个整数,表示计算的结果。
样例输入
3 2
1 2
3 4
5 6
10 10
-20 -20
30 30
6 5
4 3
2 1
4 0 -5
样例输出
480 240
0 0
-2200 -1100
子任务
70 % 70\% 70% 的测试数据满足: n ≤ 100 n \le 100 n≤100 且 d ≤ 10 d \le 10 d≤10;输入矩阵、向量中的元素均为整数,且绝对值均不超过 30 30 30。
全部的测试数据满足: n ≤ 1 0 4 n \le 10^4 n≤104 且 d ≤ 20 d \le 20 d≤20;输入矩阵、向量中的元素均为整数,且绝对值均不超过 1000 1000 1000。
提示
请谨慎评估矩阵乘法运算后的数值范围,并使用适当数据类型存储矩阵中的整数。
题解
对于矩阵 A = ( a i , j ) m × n \mathbf A=(a_{i,j})_{m\times n} A=(ai,j)m×n, B = ( b i , j ) n × q \mathbf B=(b_{i,j})_{n\times q} B=(bi,j)n×q,矩阵乘法的结果 C = ( c i , j ) m × q \mathbf C=(c_{i,j})_{m\times q} C=(ci,j)m×q 满足 c i , j = ∑ k = 1 n a i , k b k , j c_{i,j}=\sum\limits_{k=1}^na_{i,k}b_{k,j} ci,j=k=1∑nai,kbk,j通过公式可以看出,计算矩阵乘法的复杂度为 O ( m n q ) \mathcal{O}(mnq) O(mnq)。
按照题目中给的计算顺序: ( W ⋅ ( Q × K T ) ) × V \left(\mathbf{W} \cdot (\mathbf{Q} \times \mathbf{K}^{T})\right) \times \mathbf{V} (W⋅(Q×KT))×V,计算单个矩阵乘法的复杂度为 O ( n 2 d ) \mathcal{O}(n^2d) O(n2d),大致为 2 × 1 0 9 2\times 10^9 2×109,无法接受。
运用矩阵乘法的结合律,可以将顺序变为 W ⋅ ( Q × ( K T × V ) ) \mathbf{W} \cdot \left(\mathbf{Q} \times (\mathbf{K}^{T} \times \mathbf{V})\right) W⋅(Q×(KT×V)),这样计算单个矩阵乘法的复杂度为 O ( n d 2 ) \mathcal{O}(nd^2) O(nd2),大致为 4 × 1 0 6 4\times 10^6 4×106,可以接受。
注意数据范围,要开 long long
。
时间复杂度: O ( n d 2 ) \mathcal{O}(nd^2) O(nd2)。
参考代码(203ms,16.33MB)
/*
Created by Pujx on 2024/3/16.
*/
#pragma GCC optimize(2, 3, "Ofast", "inline")
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
//#define int long long
//#define double long double
using i64 = long long;
using ui64 = unsigned long long;
using i128 = __int128;
#define inf (int)0x3f3f3f3f3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#define yn(x) cout << (x ? "yes" : "no") << endl
#define Yn(x) cout << (x ? "Yes" : "No") << endl
#define YN(x) cout << (x ? "YES" : "NO") << endl
#define mem(x, i) memset(x, i, sizeof(x))
#define cinarr(a, n) for (int i = 1; i <= n; i++) cin >> a[i]
#define cinstl(a) for (auto& x : a) cin >> x;
#define coutarr(a, n) for (int i = 1; i <= n; i++) cout << a[i] << " \n"[i == n]
#define coutstl(a) for (const auto& x : a) cout << x << ' '; cout << endl
#define all(x) (x).begin(), (x).end()
#define md(x) (((x) % mod + mod) % mod)
#define ls (s << 1)
#define rs (s << 1 | 1)
#define ft first
#define se second
#define pii pair<int, int>
#ifdef DEBUG
#include ""
#else
#define dbg(...) void(0)
#endif
const int N = 2e5 + 5;
//const int M = 1e5 + 5;
const int mod = 998244353;
//const int mod = 1e9 + 7;
//template <typename T> T ksm(T a, i64 b) { T ans = 1; for (; b; a = 1ll * a * a, b >>= 1) if (b & 1) ans = 1ll * ans * a; return ans; }
//template <typename T> T ksm(T a, i64 b, T m = mod) { T ans = 1; for (; b; a = 1ll * a * a % m, b >>= 1) if (b & 1) ans = 1ll * ans * a % m; return ans; }
int a[N];
int n, m, t, k, q, d;
template <typename T = int> struct Matrix {
int m, n; // m 行 n 列的矩阵
vector<vector<T>> v;
Matrix(): n(0), m(0) { v.resize(0); }
Matrix(int r, int c, T num = T()) : m(r), n(c) {
v.resize(m, vector<T>(n, num));
}
Matrix(int n) : m(n), n(n) { // 单位矩阵的构造函数
v.resize(n, vector<T>(n, 0));
for (int i = 0; i < n; i++) v[i][i] = 1;
}
Matrix& operator += (const Matrix& b) & {
assert(m == b.m && n == b.n);
for (int i = 0; i < m; i++)
for (int j = 0; j < n; j++)
v[i][j] += b[i][j];
return *this;
}
Matrix& operator -= (const Matrix& b) & {
assert(m == b.m && n == b.n);
for (int i = 0; i < m; i++)
for (int j = 0; j < n; j++)
v[i][j] -= b[i][j];
return *this;
}
Matrix& operator *= (const T& b) & {
for (int i = 0; i < m; i++)
for (int j = 0; j < n; j++)
v[i][j] *= b;
return *this;
}
Matrix& operator *= (const Matrix& b) & {
assert(n == b.m);
Matrix ans(m, b.n);
for (int i = 0; i < m; i++)
for (int j = 0; j < b.n; j++)
for (int k = 0; k < n; k++)
ans[i][j] += v[i][k] * b[k][j];
return *this = ans;
}
friend Matrix operator + (const Matrix& a, const Matrix& b) {
Matrix ans = a; ans += b; return ans;
}
friend Matrix operator - (const Matrix& a, const Matrix& b) {
Matrix ans = a; ans -= b; return ans;
}
friend Matrix operator * (const Matrix& a, const T& b) {
Matrix ans = a; a *= b; return ans;
}
friend Matrix operator * (const Matrix& a, const Matrix& b) {
Matrix ans = a; ans *= b; return ans;
}
Matrix trans() const {
Matrix ans(n, m);
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
ans[i][j] = v[j][i];
return ans;
}
Matrix ksm(const long long& x) const {
assert(n == m);
Matrix ans(n), a = *this;
for (long long pw = x; pw; a *= a, pw >>= 1) if (pw & 1) ans *= a;
return ans;
}
vector<T>& operator [] (const int& t) { return v[t]; }
const vector<T>& operator [] (const int& t) const { return v[t]; }
friend bool operator == (const Matrix& a, const Matrix& b) {
assert(a.m == b.m && a.n == b.n);
for (int i = 0; i < a.m; i++)
for (int j = 0; j < a.n; j++)
if (a[i][j] != b[i][j])
return false;
return true;
}
friend bool operator != (const Matrix& a, const Matrix& b) {
return !(a == b);
}
friend istream& operator >> (istream& in, Matrix& x) {
for (int i = 0; i < x.m; i++)
for (int j = 0; j < x.n; j++)
in >> x[i][j];
return in;
}
friend ostream& operator << (ostream& out, const Matrix& x) {
for (int i = 0; i < x.m; i++)
for (int j = 0; j < x.n; j++)
out << x[i][j] << " \n"[j == x.n - 1];
return out;
}
};
void work() {
cin >> n >> d;
Matrix<i64> Q(n, d), K(n, d), V(n, d);
cin >> Q >> K >> V;
Matrix<i64> ans = Q * (K.trans() * V);
for (int i = 0; i < n; i++) {
int w; cin >> w;
for (int j = 0; j < d; j++)
ans[i][j] *= w;
}
cout << ans << endl;
}
signed main() {
#ifdef LOCAL
freopen("C:\\Users\\admin\\CLionProjects\\Practice\\", "r", stdin);
freopen("C:\\Users\\admin\\CLionProjects\\Practice\\", "w", stdout);
#endif
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int Case = 1;
//cin >> Case;
while (Case--) work();
return 0;
}
/*
_____ _ _ _ __ __
| _ \ | | | | | | \ \ / /
| |_| | | | | | | | \ \/ /
| ___/ | | | | _ | | } {
| | | |_| | | |_| | / /\ \
|_| \_____/ \_____/ /_/ \_\
*/
关于代码的亿点点说明:
- 代码的主体部分位于
void work()
函数中,另外会有部分变量申明、结构体定义、函数定义在上方。#pragma ...
是用来开启 O2、O3 等优化加快代码速度。- 中间一大堆
#define ...
是我习惯上的一些宏定义,用来加快代码编写的速度。""
头文件是我用于调试输出的代码,没有这个头文件也可以正常运行(前提是没定义DEBUG
宏),在程序中如果看到dbg(...)
是我中途调试的输出的语句,可能没删干净,但是没有提交上去没有任何影响。ios::sync_with_stdio(false); (0); (0);
这三句话是用于解除流同步,加快输入cin
输出cout
速度(这个输入输出流的速度很慢)。在小数据量无所谓,但是在比较大的读入时建议加这句话,避免读入输出超时。如果记不下来可以换用scanf
和printf
,但使用了这句话后,cin
和scanf
、cout
和printf
不能混用。- 将
main
函数和work
函数分开写纯属个人习惯,主要是为了多组数据。