注意力(Attention)是 Transformer 捕获长距离依赖的核心机制。这篇文章从数学推导到代码实现,逐步拆解,每一步都给出具体数值。

缩放点积注意力

给定 Query $Q$、Key $K$、Value $V$,注意力输出为:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

用一个小型示例来追踪。设 $d_k = 4$:

Q = [[1.0, 0.0, 0.5, 0.2]]   # shape (1, 4)
K = [[0.8, 0.1, 0.3, 0.6],   # shape (3, 4)
     [0.2, 0.9, 0.1, 0.4],
     [0.5, 0.5, 0.5, 0.5]]
V = [[1.0, 0.0],             # shape (3, 2)
     [0.0, 1.0],
     [0.5, 0.5]]

Step 1:计算 $QK^T$。对第一个 Key:$1.0 \times 0.8 + 0.0 \times 0.1 + 0.5 \times 0.3 + 0.2 \times 0.6 = 0.8 + 0 + 0.15 + 0.12 = \mathbf{1.07}$。

为什么要除以 √d_k?

随着 $d_k$ 增大,点积的方差也会增大。假设 $Q = [0.1, 0.2, 0.3, 0.4]$,$K = [0.1, 0.2, 0.3, 0.4]$,点积为 $0.01 + 0.04 + 0.09 + 0.16 = 0.30$。但当 $d_k = 1280$(如 GPT-3),点积的方差约为 1280,会把 softmax 推入梯度消失的区域。

$\sqrt{d_k}$ 缩放使得无论维度多大,方差都保持在约 1,确保训练过程中 softmax 梯度稳定。

多头注意力

不是只做一个注意力函数,而是将 $Q, K, V$ 投影到 $h$ 个不同的子空间并行计算:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$

以 $d_{\text{model}} = 512$、$h = 8$ 为例,每个头处理 $d_k = 512 / 8 = 64$ 维。这允许模型同时关注来自不同表示子空间的信息。

为什么要多个头?

不同的头学习不同的模式:有的头捕获语法关系,有的关注语义关联。研究(Clark et al., 2019)表明某些头专门做共指消解,另一些头专注于关注前一个 token。

数值计算完整示例

继续上面的计算。$d_k = 4$,缩放因子 $\sqrt{4} = 2$:

scores = QK^T / 2 = [1.07, 0.82, 1.0] / 2 = [0.535, 0.410, 0.500]
weights = softmax(scores) = [0.349, 0.310, 0.341]
output  = weights @ V
        = 0.349*[1,0] + 0.310*[0,1] + 0.341*[0.5,0.5]
        = [0.349+0+0.171, 0+0.310+0.171]
        = [0.520, 0.481]

实现代码

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """缩放点积注意力
    
    Args:
        Q: Query tensor, shape (..., seq_len_q, d_k)
        K: Key tensor,   shape (..., seq_len_k, d_k)
        V: Value tensor,  shape (..., seq_len_k, d_v)
        mask: 可选,用于 decoder 的因果遮罩
    
    Returns:
        output: 注意力输出, shape (..., seq_len_q, d_v)
        weights: 注意力权重, shape (..., seq_len_q, seq_len_k)
    """
    d_k = Q.size(-1)
    # Step 1: 计算注意力分数
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    # Step 2: 应用遮罩(如果有)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # Step 3: softmax 归一化
    weights = F.softmax(scores, dim=-1)
    # Step 4: 加权求和
    return torch.matmul(weights, V), weights