Linear Recurrence
2026/6/6大约 1 分钟
Linear Recurrence
题目描述
给定两个矩阵 和 ,形状均为 (批量大小 序列长度),计算形状为 的线性递推 :
所有值为 float32。此操作是状态空间模型(SSM)(如 Mamba、S4 和 H3)的核心计算原语。
实现要求
- 不允许使用外部库。
solve函数签名必须保持不变。- 结果必须存储在输出张量 中。
示例
示例 1 — 指数衰减 (a=0.5, 单位脉冲):
a = [0.5, 0.5, 0.5, 0.5], x = [1, 0, 0, 0]
h = [1, 0.5, 0.25, 0.125]
示例 2 — 前缀和 (a=1, 全1输入):
a = [1, 1, 1, 1], x = [1, 1, 1, 1]
h = [1, 2, 3, 4]
示例 3 — B=2, L=4:
a = [[0.5,0.5,0.5,0.5],[1,1,1,1]], x = [[1,0,0,0],[1,1,1,1]]
h = [[1,0.5,0.25,0.125],[1,2,3,4]]约束条件
- ,。
- 和 均为 float32。
- 性能测试在 下进行。
解题思路
线性递推的串行依赖是 GPU 并行化的核心难点—— 依赖 使得不能直接全并行。对于小 ,可以使用并行扫描(parallel scan / prefix sum 的关联性推广)在 步骤中求解。对于大 和不同的 ,分段并行+串行边界传递是常用策略。欢迎在 GitHub Discussions 分享你的解法。