原始题目:LeetGPU - Causal Self-Attention
实现因果(掩码)自注意力。给定查询矩阵 Q(M×d)、键矩阵 K(M×d)和值矩阵 V(M×d),计算:
Attentioncausal(Q,K,V)=softmax(mask(dQKT))V
其中因果掩码将当前位置之后的所有键位置设为 −∞:
mask(aij)={aij,−∞,j≤ij>i
softmax 按行应用。所有数据为 float32。
- 不允许使用外部库。
solve 函数签名必须保持不变。
Q (2×4): [[1,0,0,0],[0,1,0,0]]
K (2×4): [[1,0,0,0],[0,1,0,0]]
V (2×4): [[1,2,3,4],[5,6,7,8]]
→ Row 0 只看 pos 0, Row 1 看 pos 0+1
Output (2×4): [[1,2,3,4],[3.49,4.49,5.49,6.49]]
- 1≤M≤10,000,1≤d≤128。
- 所有元素范围 [−100,100]。
- 性能测试在 M=5,000 下进行。
因果自注意力 = 标准自注意力 + 下三角掩码。掩码可以通过在 softmax 前将上三角元素(j>i)设为 −∞ 来实现。对于解码器(仅生成阶段),每步只有 1 个查询 token,因此 Q 是 1×d,QKT 是 1×M 的向量,计算量极小——瓶颈在加载整个 K 和 V(即 KV-cache)。欢迎在 GitHub Discussions 分享你的解法。