Attention 的秩崩溃问题
Attention 的秩崩溃问题
“Attention Is Not All You Need” —— 这是 Dong 等人在 2021 年发表的一篇论文的标题(ICML 2021)。
该研究揭示了一个关键的理论问题:在没有特定结构约束的情况下,纯 Attention 模型的秩(Rank)会随着深度增加呈双指数级衰减。
简而言之,如果不引入非线性(如 FFN)或残差连接(Residual),深层 Transformer 的输出最终将退化为线性相关的向量,导致模型表达能力丧失。这种现象被称为秩崩溃(Rank Collapse)。
什么是秩(Rank)?
在线性代数中,矩阵的秩代表了它所能张成的空间的维数,反映了信息的丰富程度。
对于一个包含 个 token、维度为 的序列表示矩阵 :
- 满秩(Full Rank):意味着所有 token 向量在特征空间中是线性无关的,包含丰富且多样的语义信息。
- 低秩(Low Rank):意味着 token 之间存在高度的线性相关性,信息冗余度高。
- 秩为 1:意味着所有 token 退化为同一个向量的标量倍(甚至完全相同),即模型输出发生了“同质化(Homogenization)”。
纯 Self-Attention 的收敛性分析
考虑一个不包含残差连接和 FFN 的纯 Self-Attention 网络:
其中 。
若忽略线性变换 (假设它们为单位阵),Attention 操作本质上是对输入向量的凸组合(Convex Combination)。
几何视角:凸包的收缩(Convex Hull Shrinkage)
Attention 矩阵 满足随机矩阵性质(行和为 1,元素非负)。这意味着每一层的输出 是上一层所有 token 的加权平均。
从几何角度看,这意味着 中的每个向量都位于 所构成的凸包(Convex Hull)内部。
随着层数的增加,这个凸包会不断向几何中心收缩。
Dong 等人利用 Lipschitz 常数进行了严格的数学证明:
其中 。这意味着随着层数 ,所有 token 都会指数级地收敛到一个常数向量 。此时,矩阵 的秩退化为 1。
更严重的是,对于 Self-Attention 机制,这种收敛速度是双指数级(Doubly Exponential)的,即误差衰减遵循 。这比普通的均值滤波(Mean Filtering)收敛速度快得多。
残差连接与 FFN 的缓解机制
既然纯 Attention 会导致秩崩溃,为何现代 Transformer 依然有效?
这主要归功于两个关键组件:残差连接(Residual Connection) 和 前馈网络(FFN)。
残差连接:保持特征多样性
残差连接强制保留了上一层的信息 。这意味着每一层的输出不仅包含“趋同”的 Attention 结果,还保留了原始的、多样的 token 特征。
实验表明,引入残差连接后,秩的衰减速度从双指数级降低为多项式级(Polynomial)或更慢,从而使得训练深层网络成为可能。
FFN:引入非线性与升维
FFN 层通常包含非线性激活函数(ReLU/GeLU),并且通常会先进行升维()再降维。
非线性变换打破了线性混合的均值化趋势,增加了特征空间的复杂度和秩。Dong 等人的实验表明,FFN 是阻止秩崩溃的关键因素。如果没有 FFN,即使有残差连接,深层 Transformer 的表现也会显著下降。
理论指导意义
理解秩崩溃问题,不仅有助于理解 Transformer 各组件的必要性,还指导了后续的模型优化。
Talking-Heads Attention
为了对抗秩崩溃,Google 提出了 Talking-Heads Attention,允许不同的 Head 之间进行交互(加权求和),从而增加 Attention 矩阵的秩,提升模型的表达能力。
DeepNet 与初始化策略
微软的 DeepNet 通过调整初始化策略(DeepNorm),使得深层网络的残差连接权重更大,从而更强地抑制了 Attention 带来的秩衰减,成功训练了 1000 层的模型。
这也从理论上解释了 BERT/GPT 等模型必须堆叠大量 FFN 层的原因:Attention 负责“聚合信息”(Global Aggregation),倾向于降低秩;而 FFN 负责“特征变换”(Feature Transformation),倾向于恢复秩。 两者相互制衡,共同维持了深层网络的表达能力。