Grouped Query Attention
2026/6/6大约 2 分钟
Grouped Query Attention
题目描述
实现分组查询注意力(GQA),这是 LLaMA-3、Mistral、Gemma 等现代大语言模型使用的注意力机制。GQA 通过在查询头组之间共享键和值头来减少推理时的 KV-cache 内存占用。
给定查询张量 ( 个头)和键/值张量 、(各 个头),计算缩放点积注意力,其中每 个连续查询头共享同一个键值头:
所有张量使用 float32。
实现要求
- 实现
solve(Q, K, V, output, num_q_heads, num_kv_heads, seq_len, head_dim)。 - 不允许使用外部库。
- 始终可被 整除。
示例
num_q_heads=4, num_kv_heads=2 (每组2), seq_len=3, head_dim=4
Q0: [[1,0,0,1],[0,1,1,0],[1,1,0,0]] → Q0,Q1 attend to K0,V0
Q1: [[0,1,0,1],[1,0,1,0],[0,0,1,1]] → Q2,Q3 attend to K1,V1
...
Output: (4 heads × 3 positions × 4 dims, 结果取2位小数)约束条件
- 。
- ,(8 的倍数)。
- 性能测试在 ,,, 下进行。
解题思路
GQA 的关键区别在于 KV 头的广播——每个 KV 头被多个 Q 头共享。实现中可以在 batch 维度上复制 KV 头以匹配 Q 头数量,然后执行标准的 MHA;更高效的做法是直接在注意力计算中处理分组映射,避免额外的内存拷贝。KV-cache 友好的内存布局是生产级实现的重点。欢迎在 GitHub Discussions 分享你的解法。