FP16 Batched Matrix Multiplication
2026/6/6大约 1 分钟
FP16 Batched Matrix Multiplication
题目描述
在 FP16 中实现批量矩阵乘法。给定一批形状为 的矩阵 和一批形状为 的矩阵 (均为 FP16/half 类型),计算输出批次 (形状 ):
累加过程中使用 FP32 以获得更好的精度,最终结果转换回 FP16。所有矩阵以行优先顺序存储。
实现要求
- 不允许使用外部库。
solve函数签名必须保持不变。- 累加使用 FP32,最终结果以 half 存储在 中。
示例
Input: B=2, M=2, K=3, N=2
A[0]=[[1,2,3],[4,5,6]], A[1]=[[7,8,9],[10,11,12]]
B[0]=[[1,2],[3,4],[5,6]], B[1]=[[6,5],[4,3],[2,1]]
Output: C[0]=[[22,28],[49,64]], C[1]=[[92,68],[128,95]]约束条件
- ,。
- 性能测试在 的规模下进行。
解题思路
与 FP32 批量矩阵乘法结构相同,但使用 FP16 存储和 FP32 累加。利用 Tensor Core(Ampere+)的 FP16 矩阵乘法指令可以显著提升吞吐。关键点:FP16 的精度有限(约 3.3 位十进制有效数字),累加器必须用 FP32 以避免舍入误差累积。__half 类型和 __hmul/__hadd 内置函数是标准工具。欢迎在 GitHub Discussions 分享你的解法。