Decaying Causal Attention
2026/6/6大约 1 分钟
Decaying Causal Attention
题目描述
实现衰减因果注意力。给定查询矩阵 、键矩阵 和值矩阵 (形状均为 )以及标量衰减因子 ,计算未归一化的因果注意力输出。位置 以权重 关注所有过去位置 :
与标准 softmax 注意力不同,这里没有归一化——权重从当前位置向后呈几何衰减。这是 RetNet(Retention Network)的并行形式,用作序列模型中注意力机制的一种递推友好替代方案。
实现要求
- 实现
solve函数,签名保持不变。 - 不允许使用外部库。
示例
seq_len=2, d_model=4, gamma=0.5
Q=[[1,1,0,0],[1,1,0,0]], K=[[1,0,0,0],[0,1,0,0]], V=[[4,8,12,16],[4,8,12,16]]
Scores QK^T/√4: [[0.5,0.5],[0.5,0.5]]
Decay mask D[n,m]=0.5^(n-m): [[1,0],[0.5,1]]
Weighted A⊙D: [[0.5,0],[0.25,0.5]]
Output (A⊙D)V: [[2,4,6,8],[3,6,9,12]]约束条件
- ,。
- 。
- 性能测试在 下进行。
解题思路
衰减因果注意力与标准 softmax attention 的区别在于用几何衰减替代了 softmax 归一化。计算流程仍是 + 逐元素衰减掩码 + 。衰减掩码是一个下三角矩阵,第 个元素为 ,对角线为 1。可以用递推方式避免 的计算成本——每个新位置 的加权和可以从 递推而来。欢迎在 GitHub Discussions 分享你的解法。