General Matrix Multiplication (GEMM)
2026/6/6大约 1 分钟
General Matrix Multiplication (GEMM)
题目描述
实现一个基本的通用矩阵乘法(GEMM)。给定矩阵 ()、()、输入/输出矩阵 ()以及标量乘数 和 ,计算:
输入矩阵 、 和 的初始状态包含 16 位浮点数(FP16/half),所有矩阵以行优先顺序存储。标量 和 为 32 位浮点数。
实现要求
- 只允许使用原生功能(除 WMMA 外不允许使用外部库)。
solve函数签名必须保持不变。- 乘法过程中的累加应使用 FP32 以获得更好的精度,最终结果转换回 FP16。
- 最终结果必须以 half 类型存储回矩阵 。
示例
Input: A (2×3, FP16): [[1,2,3],[4,5,6]]
B (3×2, FP16): [[1,2],[3,4],[5,6]]
C_init (2×2, FP16): [[1,1],[1,1]]
α = 1.0, β = 0.0
Output: C (2×2, FP16): [[22,28],[49,64]]约束条件
- 。
- 性能测试在 的规模下进行。
解题思路
GEMM 是 GPU 计算中最重要的内核之一,cuBLAS 的大部分性能优化都围绕它展开。核心优化链:朴素全局内存 → 共享内存分块 → 寄存器分块 → 向量化加载(float4)→ 双缓冲 → warp-level 矩阵指令(WMMA/Tensor Core)。FP16 计算可以利用 Tensor Core(Ampere+)获得数倍于 FP32 的吞吐。累加器用 FP32 是标准做法——FP16 的精度不足以安全地累积大量项。欢迎在 GitHub Discussions 分享你的解法。