原始题目:LeetGPU - SSM Selective Scan
实现状态空间模型(SSM)选择性扫描的前向传播,这是 Mamba 风格序列模型的核心操作。给定输入序列 u、时间步参数 Δ、状态转移矩阵 A、输入投影 B、输出投影 C 和跳跃连接权重 skip,以 float32 计算输出序列 y。
对每个 batch b、位置 t 和通道 d:
Aˉb,t,d,nBˉb,t,d,nhb,t,d,nyb,t,d=exp(Δb,t,d⋅Ad,n)=Δb,t,d⋅Bb,t,n=Aˉb,t,d,n⋅hb,t−1,d,n+Bˉb,t,d,n⋅ub,t,d=n∑Cb,t,n⋅hb,t,d,n+skipd⋅ub,t,d
初始隐藏状态 hb,−1,d,n=0。所有通道 d 相互独立——它们共享相同的 B、C 投影,但在 A 中有各自独立的状态转移行。
- 实现
solve(u, delta, A, B, C, skip, y, batch, seq_len, d_model, d_state),签名保持不变。 - 不允许使用外部库。
batch=1, seq_len=4, d_model=2, d_state=2
u = [[[1,0],[0,1],[1,1],[0,0]]], delta 全1
A = [[-0.5,-1.0],[-0.5,-1.0]] → A_bar ≈ [[0.607,0.368],[0.607,0.368]]
t=0: y=[1,0], t=1: y=[0,1], t=2: y=[2.368,2.368], t=3: y=[0.599,0.555]
- 1≤batch≤16,1≤seq_len≤8,192,1≤dmodel≤2,048,1≤d_state≤64。
- Δ>0,A<0(确保 Aˉ∈(0,1))。
- 性能测试在 batch=4,seq_len=4,096,dmodel=512,d_state=16 下进行。
SSM 选择性扫描是 Mamba 区别于传统 SSM(S4)的关键——Δ、B、C 是输入相关的(选择性),而非固定的。这破坏了卷积表示,使得必须逐时间步串行计算。与线性递推类似,可以使用并行扫描来处理选择性 SSM:将串行递推转化为关联操作,然后在 O(logL) 并行步中求解。欢迎在 GitHub Discussions 分享你的解法。