原始题目:LeetGPU - Multi-Head Attention
实现多头自注意力(Multi-Head Self-Attention)。给定查询矩阵 Q、键矩阵 K 和值矩阵 V(大小均为 N×dmodel),计算:
MultiHead(Q,K,V)=Concat(head1,…,headh)
其中每个头计算:
headi=softmax(dkQiKiT)Vi
dk=dmodel/h,Qi,Ki,Vi 是输入矩阵的第 i 个头分区。
- 不允许使用外部库。
solve 函数签名必须保持不变。
N=2, d_model=4, h=2
Q=[[1,0,2,3],[4,5,6,7]], K=[[1,2,3,4],[5,6,7,8]], V=[[0.5,1,1.5,2],[2.5,3,3.5,4]]
Output (2×4): [[2.39,2.89,3.50,4.00],[2.50,3.00,3.50,4.00]]
- 1≤N≤10,000,2≤dmodel≤1,024,1≤h≤dmodel,dmodelmodh=0。
- 性能测试在 N=1,024,dmodel=1,024 下进行。
MHA 是 Transformer 的核心算子。每个头独立计算 softmax attention,最后拼接输出。GPU 上高效实现的关键是将多个小矩阵乘法 batch 化:将 Q,K,V 的重塑视为 batched GEMM。每个头的数据量较小(dk 一般 64–128),可以多个头并行处理。FlashAttention 通过分块计算 QKT 避免完整 attention 矩阵的显存分配,是生产级实现的标准方案。欢迎在 GitHub Discussions 分享你的解法。