SwiGLU MLP Block
2026/6/6大约 1 分钟
SwiGLU MLP Block
题目描述
实现 SwiGLU MLP 块——LLaMA、Mistral、Gemma 及大多数现代大语言模型中使用的前馈网络。给定输入矩阵 ()和三个权重矩阵 ()、()和 (),计算:
其中 (sigmoid 门控), 表示逐元素乘法。所有张量为 float32。
实现要求
- 实现
solve函数,签名保持不变。 - 不允许使用外部库。
- 将结果写入
output。
示例
Input: M=2, d_model=2, d_ffn=4
x = [[1,0],[0,1]], W_gate=W_up = [[1,0,0,0],[0,1,0,0]], W_down = [[1,0],[0,1],[0,0],[0,0]]
gate = x @ W_gate = [[1,0,0,0],[0,1,0,0]]
SiLU(gate) ≈ [[0.7311,0,0,0],[0,0.7311,0,0]]
hidden = SiLU(gate) ⊙ up ≈ [[0.7311,0,0,0],[0,0.7311,0,0]]
Output: hidden @ W_down ≈ [[0.7311,0],[0,0.7311]]约束条件
- ,,。
- 性能测试在 下进行。
解题思路
SwiGLU MLP 包含两次矩阵乘法(gate 和 up projection)和一次下游投影。gate 和 up 的投影可以合并为一个矩阵乘法(拼在一起),然后 split+SiLU+multiply,减少 kernel launch 次数。SiLU 是逐元素操作,可以融合到上一个或下一个 kernel 中。 投影是最昂贵的部分( 通常远大于 )。欢迎在 GitHub Discussions 分享你的解法。