Softmax Attention
2026/6/6大约 1 分钟
Softmax Attention
题目描述
编写一个 GPU 程序,计算矩阵的 softmax 注意力操作。给定查询矩阵 ()、键矩阵 ()和值矩阵 (),计算输出矩阵:
其中 softmax 按行应用。
实现要求
- 只允许使用 GPU 原生功能(不允许使用外部库)。
solve函数签名必须保持不变。- 最终结果必须存储在输出矩阵
output中。
示例
示例 1
Input: Q (2×4): [[1,0,0,0], [0,1,0,0]]
K (3×4): [[1,0,0,0], [0,1,0,0], [0,0,1,0]]
V (3×4): [[1,2,3,4], [5,6,7,8], [9,10,11,12]]
Output: (2×4): [[4.29, 5.29, 6.29, 7.29], [5.00, 6.00, 7.00, 8.00]]示例 2
Input: Q (1×2): [[1, 2]]
K (2×2): [[1, 0], [0, 1]]
V (2×2): [[3, 4], [5, 6]]
Output: (1×2): [[4.34, 5.34]]约束条件
- 为 , 和 为 。
- ,。
- 性能测试在 的规模下进行。
解题思路
Softmax Attention 是 Transformer 的核心运算,可分解为三步:(矩阵乘)、、(矩阵乘)。关键挑战在于 softmax 的数值稳定性:需要先按行找 max,再用 "max trick" 计算 softmax。当 较大时, 的结果矩阵可能非常大,需要分块(tiling)计算来节省显存——这就是 FlashAttention 的核心思想。欢迎在 GitHub Discussions 分享你的解法。