为什么 Attention 需要除以 :从数值稳定性到梯度流
为什么 Attention 需要除以 :从数值稳定性到梯度流
在 Transformer 模型中,Scaled Dot-Product Attention 的核心公式为:
其中 是 Key 向量的维度。初学者常有的疑问是:为什么必须除以 ?如果不除会怎样?除以 行不行?
本文将从统计学性质、梯度传播机制以及数值实验三个维度,深入剖析这一缩放因子的必要性。
统计学推导:点积的方差爆炸
假设前提:从特殊到一般
你可能会质疑:现实中的 Query 和 Key 向量真的服从标准正态分布吗?如果不服从,这个结论还成立吗?
这是一个非常好的问题。实际上,我们并不需要严格的正态分布假设。
我们只需要假设 Query 和 Key 的各个分量 是独立同分布的随机变量,且满足:
- 均值为 :
- 方差为 :
初始化的作用
在初始化阶段,我们通常会将 设为 或类似量级;在训练过程中,Layer Normalization 也会强制拉回这一分布。
点积的均值与方差
它们的点积 是一个标量。我们来计算 的统计特征。
均值:
由于 相互独立且均值为 0:方差:
首先计算单个乘积项 的方差。根据方差公式 :点积 是 个独立变量之和,因此其方差为:
结论:在输入向量各个分量独立同分布且均值为 0 的假设下,点积 的方差与维度 呈线性关系。如果不进行缩放,随着维度 的增加,点积的数值范围将显著扩大,导致数值不稳定性。
为了将输出方差控制在常数级别(与输入同量级),除以 是必要的标准化手段:
梯度分析:Softmax 的饱和区
方差增大对模型训练的直接负面影响在于 Softmax 函数的梯度性质。
Softmax 函数 对输入数值的尺度非常敏感。当输入 的方差变大时,元素之间的差值也会随之拉大。由于指数函数的放大效应,Softmax 的输出分布将迅速趋向于 One-hot 分布(即最大值对应的概率接近 1,其余接近 0)。
数值示例:
- 方差较小:假设输入为 。
- 。
- Softmax 输出约为 。分布相对平滑,梯度信息丰富。
- 方差较大(扩大 10 倍):假设输入为 。
- 。
- 此时 。
- Softmax 输出约为 。
在这种饱和(Saturated)状态下,让我们观察 Softmax 的 Jacobian 矩阵:
- 当 时,,导致主对角线梯度趋近于 0。
- 当 时,,导致非对角线梯度趋近于 0。
梯度消失的核心机制
一旦 Softmax 进入饱和区,Jacobian 矩阵的绝大多数元素都会趋近于 0。这意味着反向传播时,梯度无法有效通过 Attention 层传递到更底层的参数,导致梯度消失(Gradient Vanishing),模型训练将陷入停滞。
通过除以 ,我们将点积 的方差标准化,使其分布保持在 Softmax 的非饱和区(即线性区附近),从而保证了梯度的有效流动和训练的稳定性。
现实中的分布与 Layer Normalization
一个常见的质疑是:“虽然理论上假设了独立同分布和零均值,但深度神经网络中的实际分布真的满足这些条件吗?”
这正是 Layer Normalization (LN) 发挥关键作用的地方。
在 Transformer 的标准结构中,Attention 层的输入通常经过了 LayerNorm 处理(或者是 Pre-LN 结构中的直接输入)。LayerNorm 的定义如下:
LayerNorm 强制将输入向量标准化为均值为 0、方差为 1 的分布(忽略 的后续仿射变换)。这意味着:
- 输入 和 的各个维度在统计上接近于均值为 0、方差为 1 的分布。
- 因此,第 1 节中的统计学假设在 Transformer 架构中是具有高度现实意义的近似。
正是 LayerNorm 与 Scaled Dot-Product Attention 的配合,共同维护了深层网络中信号传输的数值稳定性。如果没有 LayerNorm,随着网络层数加深,输入 的方差可能会发生漂移,此时仅仅除以 可能就不够了。
总结来说, 解决了“宽度”(维度)带来的方差问题,而 Layer Normalization 解决了“深度”(层数)带来的方差问题。两者相辅相成。
4. 代码验证
我们可以写一段简单的 Python 代码来验证这一现象。我们不再使用正态分布,而是使用均匀分布来模拟,看看结论是否依然成立。
import numpy as np
def softmax(x):
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum()
def verify_variance(d_k=512, n_samples=10000):
print(f"--- Dimension (d_k): {d_k} ---")
# 模拟 Q 和 K,使用均匀分布 Uniform(-sqrt(3), sqrt(3))
# 这样的均匀分布均值为 0,方差为 1,满足我们的假设
limit = np.sqrt(3)
Q = np.random.uniform(-limit, limit, (n_samples, d_k))
K = np.random.uniform(-limit, limit, (n_samples, d_k))
# 计算点积
dot_products = np.sum(Q * K, axis=1)
print(f"Theoretical Variance: {d_k}")
print(f"Actual Variance: {np.var(dot_products):.2f}")
# 验证缩放效果
scaled_dot_products = dot_products / np.sqrt(d_k)
print(f"Scaled Variance (div sqrt(d_k)): {np.var(scaled_dot_products):.2f}")
# 模拟 Softmax 输出分布
sample_logits = dot_products[:1] # 取一个样本
probs_unscaled = softmax(sample_logits) # 此时只有一个值其实没有意义,这里只是演示量级
# 为了演示 Softmax 饱和,我们需要看一个向量内部的分布
# 我们模拟一个 Query 和 一组 Keys 的点积结果
# 假设有 m 个 Keys,每个 Key 维度 d_k
m_keys = 100
one_Q = np.random.uniform(-limit, limit, (1, d_k))
many_K = np.random.uniform(-limit, limit, (m_keys, d_k))
# (1, d_k) @ (d_k, m_keys) -> (1, m_keys)
logits = np.matmul(one_Q, many_K.T).flatten()
print(f"Logits Std Dev: {np.std(logits):.2f}")
probs_unscaled = softmax(logits)
probs_scaled = softmax(logits / np.sqrt(d_k))
print(f"Unscaled Max Probability: {np.max(probs_unscaled):.4f} (Peaky)")
print(f"Scaled Max Probability: {np.max(probs_scaled):.4f} (Smoother)\n")
# 运行验证
# verify_variance()实验结果解读
无论输入分布是正态分布还是均匀分布,只要方差被 LayerNorm 控制住:
- 未缩放:Logits 的标准差都会随着 增大。
- 缩放后:Logits 的分布都会被拉回到常数级别,保证 Softmax 的平滑度。
5. 为什么不是其他缩放因子?
为什么不除以 ?
如果除以 ,方差变为 。当 很大时,方差趋近于 0,所有 Logits 趋近于 0。此时 Softmax 输出趋向于均匀分布(Uniform Distribution),即每个位置概率为 。这也是一种“梯度消失”,因为此时模型无法区分 Key 的重要性,Attention 机制退化为平均池化(Average Pooling)。
总结
Attention 除以 是为了对抗高维空间中的方差爆炸。
- 数值稳定性:保持 Softmax 输入的方差为常数级别()。
- 梯度保护:防止 Softmax 进入饱和区导致梯度消失。
- 初始化一致性:使得模型在不同 设置下,初始训练状态保持一致。
这看似简单的一步除法,实则是 Transformer 能够训练深层、宽层网络的关键细节之一。