2025-4-23-Transformer相关知识
有关于Transformer的一些知识盘点。
Transformer
CNN | RNN | Transformer | |
---|---|---|---|
长距离依赖 | 感受野有限 | 梯度消失 | 自注意力直接建模任意位置交互 |
并行化能力 | 逐层卷积 | 顺序计算 | 全序列并行计算(自注意力 + FFN) |
梯度稳定性 | BN 等技巧 | 门控缓解 | 缩放点积(Scaled Dot-Product)+ 层归一化 |
Attention时间复杂度:\(O(n^2d)\)
Attention空间复杂度:\(O(n^2)\),权重矩阵
Q | K | V | |
---|---|---|---|
交叉注意力层 | 解码器中因果注意力层的输出向量 | 编码器输出的注意力向量 | 编码器输出的注意力向量 |
因果注意力层 | 输出序列中的当前位置词向量 | 输出序列中的所有位置词向量 | 输出序列中的所有位置词向量 |
全局自注意力层 | 输入序列中的当前位置词向量 | 输入序列中的当前位置词向量 | 输入序列中的当前位置词向量 |
KV cache
Transformer 在推理阶段(Inference)*是*逐个 token 生成的(叫做自回归生成)。
如果每生成一个 token 都重新对前面所有的 token 做 attention,会造成大量重复计算!
Transformer 的每一层都有注意力机制。在 Decoder 的每一层,都会有这样的结构:
- Masked Multi-Head Self-Attention(遮住未来 token) ← 就是这里用 KV Cache!
- Encoder-Decoder Attention(使用 Encoder 输出)
- Feed-Forward Layer
例子
比如你在生成文本
"The cat sat on"
,模型每次只输出一个新词:
第一步:输入第一个词 "The"
- 假设 token 1 的 embedding 是\(x_1\)
- 计算: \[ Q_1 = x_1 W^Q,\quad K_1 = x_1 W^K,\quad V_1 = x_1 W^V\\ \text{softmax}(Q_1 K_1^T) V_1 \rightarrow 输出 token 2 \]
- 把 \(K_1, V_1\) 缓存起来(KV Cache)
第二步:输入第二个词 "cat"
- 新的输入是 token 2 的 embedding\(x_2\)
- 现在我们只需要计算: \[ Q_2 = x_2 W^Q\\ K_{\text{all}} = [K_1, K_2],\quad V_{\text{all}} = [V_1, V_2]\\ \text{softmax}(Q_2 [K_1, K_2]^T) [V_1, V_2] \] 其中 \(K_2 = x_2 W^K,\ V_2 = x_2 W^V\),也会加入缓存。
第三步:输入第三个词 "sat"
,流程完全一样:
只计算\(Q_3 = x_3 W^Q\)
直接使用之前缓存的 \(K_1, K_2, K_3\) 和 \(V_1, V_2, V_3\) 做 Attention: \[ \text{softmax}(Q_3 [K_1, K_2, K_3]^T) [V_1, V_2, V_3] \]
代码实现
MHA
1 | import torch |
为什么d_model 必须被 num_heads 整除?
把总的维度 dmodeld_{}dmodel 平均 分给每个头(head)。
Q = self.W_q(query).view(B, L_q, self.num_heads, self.d_k).transpose(1, 2) # (B, H, L_q, d_k)这是在干什么?,K.transpose(-2, -1)中的参数是什么 意思?
transpose(1, 2)
:把L_q
和H
互换位置,得到(B, H, L_q, d_k)
,这个是注意力标准格式- 是对张量的 最后两个维度做转置:假设
K.shape = (B, H, L_k, d_k)
,那K.transpose(-2, -1)
变成(B, H, d_k, L_k)
scores.masked_fill此函数是?
把
mask==0
的位置变成-inf
attn = F.softmax(scores, dim=-1) ,dim=-1是?
dim=-1
表示在最后一个维度上做 softmax。即:在每个 query 上,对所有的 key 做归一化合并 heads context = context.transpose(1, 2).contiguous().view(B, L_q, self.d_model) # (B, L_q, d_model)这个是在?num_heads似乎没有使用。
- context.shape = (B, H, L_q, d_k)
- context.transpose(1, 2):(B, L_q, H, d_k)
- contiguous().view(B, L_q, d_model):合并 H 和 d_k,拼回去
- 输出是
(batch_size, seq_len, d_model)
,保持与原始输入形状一致
Scaled Dot-Product Attention
1 | def scaled_dot_product_attention(Q,K,V,mask=None): |
- 为什么Scaled Dot-Product Attention要返回output, attn?
output
是最终注意力后的值(输入后续 FFN)attn
是注意力权重(可视化分析、或者做 attention dropout)
交叉熵损失
1 | def cross_entropy(pred,target): |
- 交叉熵损失的代码
(log-probabilities):
torch.logsumexp
: log softmax \(\sum_i e^{\text{logits}\ {s_i}}\) \[ \log{p_i}=\text{logits}\ {s_i}-\log \sum_j e^{\text{logits}\ {s_j}}\\ \]