CCF-CSP认证考试 202305-2 矩阵运算 100分题解

时间:2025-04-18 14:48:18

更多 CSP 认证考试题目题解可以前往:CSP-CCF 认证考试真题题解


原题链接: 202305-2 矩阵运算

时间限制: 5.0s
内存限制: 512.0MB

题目背景

S o f t m a x ( Q × K T d ) × V \mathrm{Softmax}(\frac{\mathbf{Q} \times \mathbf{K}^{T}}{\sqrt{d}}) \times \mathbf{V} Softmax(d Q×KT)×V 是 Transformer 中注意力模块的核心算式,其中 Q \mathbf{Q} Q K \mathbf{K} K V \mathbf{V} V 均是 n n n d d d 列的矩阵, K T \mathbf{K}^{T} KT 表示矩阵 K \mathbf{K} K 的转置, × \times × 表示矩阵乘法。

问题描述

为了方便计算,顿顿同学将 S o f t m a x \mathrm{Softmax} Softmax 简化为了点乘一个大小为 n n n 的一维向量 W \mathbf{W} W ( W ⋅ ( Q × K T ) ) × V \left(\mathbf{W} \cdot (\mathbf{Q} \times \mathbf{K}^{T})\right) \times \mathbf{V} (W(Q×KT))×V
点乘即对应位相乘,记 W ( i ) \mathbf{W}^{(i)} W(i) 为向量 W \mathbf{W} W 的第 i i i 个元素,即将 ( Q × K T ) (\mathbf{Q} \times \mathbf{K}^{T}) (Q×KT) i i i 行中的每个元素都与 W ( i ) \mathbf{W}^{(i)} W(i) 相乘。

现给出矩阵 Q \mathbf{Q} Q K \mathbf{K} K V \mathbf{V} V 和向量 W \mathbf{W} W,试计算顿顿按简化的算式计算的结果。

输入格式

从标准输入读入数据。

输入的第一行包含空格分隔的两个正整数 n n n d d d,表示矩阵的大小。

接下来依次输入矩阵 Q \mathbf{Q} Q K \mathbf{K} K V \mathbf{V} V。每个矩阵输入 n n n 行,每行包含空格分隔的 d d d 个整数,其中第 i i i 行的第 j j j 个数对应矩阵的第 i i i 行、第 j j j 列。

最后一行输入 n n n 个整数,表示向量 W \mathbf{W} W

输出格式

输出到标准输出中。

输出共 n n n 行,每行包含空格分隔的 d d d 个整数,表示计算的结果。

样例输入

3 2
1 2
3 4
5 6
10 10
-20 -20
30 30
6 5
4 3
2 1
4 0 -5

样例输出

480 240
0 0
-2200 -1100

子任务

70 % 70\% 70% 的测试数据满足: n ≤ 100 n \le 100 n100 d ≤ 10 d \le 10 d10;输入矩阵、向量中的元素均为整数,且绝对值均不超过 30 30 30

全部的测试数据满足: n ≤ 1 0 4 n \le 10^4 n104 d ≤ 20 d \le 20 d20;输入矩阵、向量中的元素均为整数,且绝对值均不超过 1000 1000 1000

提示

请谨慎评估矩阵乘法运算后的数值范围,并使用适当数据类型存储矩阵中的整数。


题解

对于矩阵 A = ( a i , j ) m × n \mathbf A=(a_{i,j})_{m\times n} A=(ai,j)m×n B = ( b i , j ) n × q \mathbf B=(b_{i,j})_{n\times q} B=(bi,j)n×q,矩阵乘法的结果 C = ( c i , j ) m × q \mathbf C=(c_{i,j})_{m\times q} C=(ci,j)m×q 满足 c i , j = ∑ k = 1 n a i , k b k , j c_{i,j}=\sum\limits_{k=1}^na_{i,k}b_{k,j} ci,j=k=1nai,kbk,j通过公式可以看出,计算矩阵乘法的复杂度为 O ( m n q ) \mathcal{O}(mnq) O(mnq)

按照题目中给的计算顺序: ( W ⋅ ( Q × K T ) ) × V \left(\mathbf{W} \cdot (\mathbf{Q} \times \mathbf{K}^{T})\right) \times \mathbf{V} (W(Q×KT))×V,计算单个矩阵乘法的复杂度为 O ( n 2 d ) \mathcal{O}(n^2d) O(n2d),大致为 2 × 1 0 9 2\times 10^9 2×109,无法接受。

运用矩阵乘法的结合律,可以将顺序变为 W ⋅ ( Q × ( K T × V ) ) \mathbf{W} \cdot \left(\mathbf{Q} \times (\mathbf{K}^{T} \times \mathbf{V})\right) W(Q×(KT×V)),这样计算单个矩阵乘法的复杂度为 O ( n d 2 ) \mathcal{O}(nd^2) O(nd2),大致为 4 × 1 0 6 4\times 10^6 4×106,可以接受。

注意数据范围,要开 long long

时间复杂度: O ( n d 2 ) \mathcal{O}(nd^2) O(nd2)

参考代码(203ms,16.33MB)

/*
    Created by Pujx on 2024/3/16.
*/
#pragma GCC optimize(2, 3, "Ofast", "inline")
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
//#define int long long
//#define double long double
using i64 = long long;
using ui64 = unsigned long long;
using i128 = __int128;
#define inf (int)0x3f3f3f3f3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#define yn(x) cout << (x ? "yes" : "no") << endl
#define Yn(x) cout << (x ? "Yes" : "No") << endl
#define YN(x) cout << (x ? "YES" : "NO") << endl
#define mem(x, i) memset(x, i, sizeof(x))
#define cinarr(a, n) for (int i = 1; i <= n; i++) cin >> a[i]
#define cinstl(a) for (auto& x : a) cin >> x;
#define coutarr(a, n) for (int i = 1; i <= n; i++) cout << a[i] << " \n"[i == n]
#define coutstl(a) for (const auto& x : a) cout << x << ' '; cout << endl
#define all(x) (x).begin(), (x).end()
#define md(x) (((x) % mod + mod) % mod)
#define ls (s << 1)
#define rs (s << 1 | 1)
#define ft first
#define se second
#define pii pair<int, int>
#ifdef DEBUG
    #include ""
#else
    #define dbg(...) void(0)
#endif

const int N = 2e5 + 5;
//const int M = 1e5 + 5;
const int mod = 998244353;
//const int mod = 1e9 + 7;
//template <typename T> T ksm(T a, i64 b) { T ans = 1; for (; b; a = 1ll * a * a, b >>= 1) if (b & 1) ans = 1ll * ans * a; return ans; }
//template <typename T> T ksm(T a, i64 b, T m = mod) { T ans = 1; for (; b; a = 1ll * a * a % m, b >>= 1) if (b & 1) ans = 1ll * ans * a % m; return ans; }

int a[N];
int n, m, t, k, q, d;

template <typename T = int> struct Matrix {
    int m, n; // m 行 n 列的矩阵
    vector<vector<T>> v;

    Matrix(): n(0), m(0) { v.resize(0); }
    Matrix(int r, int c, T num = T()) : m(r), n(c) {
        v.resize(m, vector<T>(n, num));
    }
    Matrix(int n) : m(n), n(n) { // 单位矩阵的构造函数
        v.resize(n, vector<T>(n, 0));
        for (int i = 0; i < n; i++) v[i][i] = 1;
    }

    Matrix& operator += (const Matrix& b) & {
        assert(m == b.m && n == b.n);
        for (int i = 0; i < m; i++)
            for (int j = 0; j < n; j++)
                v[i][j] += b[i][j];
        return *this;
    }
    Matrix& operator -= (const Matrix& b) & {
        assert(m == b.m && n == b.n);
        for (int i = 0; i < m; i++)
            for (int j = 0; j < n; j++)
                v[i][j] -= b[i][j];
        return *this;
    }
    Matrix& operator *= (const T& b) & {
        for (int i = 0; i < m; i++)
            for (int j = 0; j < n; j++)
                v[i][j] *= b;
        return *this;
    }
    Matrix& operator *= (const Matrix& b) & {
        assert(n == b.m);
        Matrix ans(m, b.n);
        for (int i = 0; i < m; i++)
            for (int j = 0; j < b.n; j++)
                for (int k = 0; k < n; k++)
                    ans[i][j] += v[i][k] * b[k][j];
        return *this = ans;
    }
    friend Matrix operator + (const Matrix& a, const Matrix& b) {
        Matrix ans = a; ans += b; return ans;
    }
    friend Matrix operator - (const Matrix& a, const Matrix& b) {
        Matrix ans = a; ans -= b; return ans;
    }
    friend Matrix operator * (const Matrix& a, const T& b) {
        Matrix ans = a; a *= b; return ans;
    }
    friend Matrix operator * (const Matrix& a, const Matrix& b) {
        Matrix ans = a; ans *= b; return ans;
    }
    Matrix trans() const {
        Matrix ans(n, m);
        for (int i = 0; i < n; i++)
            for (int j = 0; j < m; j++)
                ans[i][j] = v[j][i];
        return ans;
    }
    Matrix ksm(const long long& x) const {
        assert(n == m);
        Matrix ans(n), a = *this;
        for (long long pw = x; pw; a *= a, pw >>= 1) if (pw & 1) ans *= a;
        return ans;
    }
    vector<T>& operator [] (const int& t) { return v[t]; }
    const vector<T>& operator [] (const int& t) const { return v[t]; }

    friend bool operator == (const Matrix& a, const Matrix& b) {
        assert(a.m == b.m && a.n == b.n);
        for (int i = 0; i < a.m; i++)
            for (int j = 0; j < a.n; j++)
                if (a[i][j] != b[i][j])
                    return false;
        return true;
    }
    friend bool operator != (const Matrix& a, const Matrix& b) {
        return !(a == b);
    }

    friend istream& operator >> (istream& in, Matrix& x) {
        for (int i = 0; i < x.m; i++)
            for (int j = 0; j < x.n; j++)
                in >> x[i][j];
        return in;
    }
    friend ostream& operator << (ostream& out, const Matrix& x) {
        for (int i = 0; i < x.m; i++)
            for (int j = 0; j < x.n; j++)
                out << x[i][j] << " \n"[j == x.n - 1];
        return out;
    }
};

void work() {
    cin >> n >> d;
    Matrix<i64> Q(n, d), K(n, d), V(n, d);
    cin >> Q >> K >> V;
    Matrix<i64> ans = Q * (K.trans() * V);
    for (int i = 0; i < n; i++) {
        int w; cin >> w;
        for (int j = 0; j < d; j++)
            ans[i][j] *= w;
    }
    cout << ans << endl;
}

signed main() {
#ifdef LOCAL
    freopen("C:\\Users\\admin\\CLionProjects\\Practice\\", "r", stdin);
    freopen("C:\\Users\\admin\\CLionProjects\\Practice\\", "w", stdout);
#endif
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int Case = 1;
    //cin >> Case;
    while (Case--) work();
    return 0;
}
/*
     _____   _   _       _  __    __
    |  _  \ | | | |     | | \ \  / /
    | |_| | | | | |     | |  \ \/ /
    |  ___/ | | | |  _  | |   }  {
    | |     | |_| | | |_| |  / /\ \
    |_|     \_____/ \_____/ /_/  \_\
*/

关于代码的亿点点说明:

  1. 代码的主体部分位于 void work() 函数中,另外会有部分变量申明、结构体定义、函数定义在上方。
  2. #pragma ... 是用来开启 O2、O3 等优化加快代码速度。
  3. 中间一大堆 #define ... 是我习惯上的一些宏定义,用来加快代码编写的速度。
  4. "" 头文件是我用于调试输出的代码,没有这个头文件也可以正常运行(前提是没定义 DEBUG 宏),在程序中如果看到 dbg(...) 是我中途调试的输出的语句,可能没删干净,但是没有提交上去没有任何影响。
  5. ios::sync_with_stdio(false); (0); (0); 这三句话是用于解除流同步,加快输入 cin 输出 cout 速度(这个输入输出流的速度很慢)。在小数据量无所谓,但是在比较大的读入时建议加这句话,避免读入输出超时。如果记不下来可以换用 scanfprintf,但使用了这句话后,cinscanfcoutprintf 不能混用。
  6. main 函数和 work 函数分开写纯属个人习惯,主要是为了多组数据。