Batched Matrix Multiplication
2026/6/6大约 1 分钟
Batched Matrix Multiplication
题目描述
在 FP32 中实现批量矩阵乘法。给定一批形状为 的矩阵 和一批形状为 的矩阵 ,计算输出批次 (形状 ),使得对每个批次索引 :
所有矩阵以行优先顺序存储,使用 32 位浮点数(FP32)。
实现要求
- 不允许使用外部库。
solve函数签名必须保持不变。- 最终结果必须存储在数组 中。
示例
Input: B=2, M=2, K=3, N=2
A[0] = [[1,2,3],[4,5,6]], A[1] = [[7,8,9],[10,11,12]]
B[0] = [[1,2],[3,4],[5,6]], B[1] = [[6,5],[4,3],[2,1]]
Output: C[0] = [[22,28],[49,64]], C[1] = [[92,68],[128,95]]约束条件
- 。
- 。
- 性能测试在 的规模下进行。
解题思路
批量矩阵乘法是对 个独立矩阵对同时执行 GEMM。最简单的做法是串行处理每个 batch,但 GPU 的优势在于并行——可以让每个 batch 由不同的 SM 处理,或者使用 cuBLAS 的 cublasGemmBatched / cublasGemmStridedBatched。手写实现中,关键是将 batch 索引作为第三维扩展到已有的二维分块 GEMM 策略中,确保所有 batch 同时利用 GPU 的并行性。欢迎在 GitHub Discussions 分享你的解法。