LoRA Linear
2026/6/6大约 1 分钟
LoRA Linear
题目描述
实现 LoRA(Low-Rank Adaptation)线性层的前向传播。给定输入矩阵 ()、基础权重矩阵 ()、LoRA 降维投影矩阵 ()和 LoRA 升维投影矩阵 (),计算:
其中 。所有张量为 float32。
实现要求
- 实现
solve函数,签名保持不变。 - 不允许使用外部库。
- 将结果写入
output。
示例
Input: x = [[1,0,-1,2],[0,1,1,-1]], W = I_3x4 (前3行), A = [[1,0,0,0],[0,1,0,0]], B = [[1,0,0],[0,1,0]]
lora_scale = 0.5
x @ W^T = [[1,0,-1],[0,1,1]]
x @ A^T = [[1,0],[0,1]]
α * (x@A^T) @ B^T = 0.5 * [[1,0,0],[0,1,0]]
Output: [[1.5, 0, -1], [0, 1.5, 1]]约束条件
- ,。
- ,。
- 性能测试在 下进行。
解题思路
LoRA 将全秩权重更新分解为两个低秩矩阵的乘积。计算流程:(GEMM)+ (两次小 GEMM 或一次 batched GEMM)。由于 ,LoRA 分支的计算量远小于主分支,但增加了一次额外的矩阵乘。可以将 视为合并权重来一次性完成计算。欢迎在 GitHub Discussions 分享你的解法。