Schwertlilien
As a recoder: notes and ideas.

2025-4-23-Transformer相关知识

有关于Transformer的一些知识盘点。

Transformer

CNN RNN Transformer
长距离依赖 感受野有限 梯度消失 自注意力直接建模任意位置交互
并行化能力 逐层卷积 顺序计算 全序列并行计算(自注意力 + FFN)
梯度稳定性 BN 等技巧 门控缓解 缩放点积(Scaled Dot-Product)+ 层归一化

Attention时间复杂度:\(O(n^2d)\)

Attention空间复杂度:\(O(n^2)\),权重矩阵

Q K V
交叉注意力层 解码器中因果注意力层的输出向量 编码器输出的注意力向量 编码器输出的注意力向量
因果注意力层 输出序列中的当前位置词向量 输出序列中的所有位置词向量 输出序列中的所有位置词向量
全局自注意力层 输入序列中的当前位置词向量 输入序列中的当前位置词向量 输入序列中的当前位置词向量

KV cache

Transformer 在推理阶段(Inference)*是*逐个 token 生成的(叫做自回归生成)。

如果每生成一个 token 都重新对前面所有的 token 做 attention,会造成大量重复计算

Transformer 的每一层都有注意力机制。在 Decoder 的每一层,都会有这样的结构:

  1. Masked Multi-Head Self-Attention(遮住未来 token) ← 就是这里用 KV Cache!
  2. Encoder-Decoder Attention(使用 Encoder 输出)
  3. Feed-Forward Layer

例子

比如你在生成文本 "The cat sat on",模型每次只输出一个新词:

第一步:输入第一个词 "The"

  1. 假设 token 1 的 embedding 是\(x_1\)
  2. 计算: \[ Q_1 = x_1 W^Q,\quad K_1 = x_1 W^K,\quad V_1 = x_1 W^V\\ \text{softmax}(Q_1 K_1^T) V_1 \rightarrow 输出 token 2 \]
  3. \(K_1, V_1\) 缓存起来(KV Cache)

第二步:输入第二个词 "cat"

  1. 新的输入是 token 2 的 embedding\(x_2\)
  2. 现在我们只需要计算: \[ Q_2 = x_2 W^Q\\ K_{\text{all}} = [K_1, K_2],\quad V_{\text{all}} = [V_1, V_2]\\ \text{softmax}(Q_2 [K_1, K_2]^T) [V_1, V_2] \] 其中 \(K_2 = x_2 W^K,\ V_2 = x_2 W^V\),也会加入缓存。

第三步:输入第三个词 "sat",流程完全一样:

  1. 只计算\(Q_3 = x_3 W^Q\)

  2. 直接使用之前缓存的 \(K_1, K_2, K_3\)\(V_1, V_2, V_3\) 做 Attention: \[ \text{softmax}(Q_3 [K_1, K_2, K_3]^T) [V_1, V_2, V_3] \]

代码实现

MHA

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLA(nn.Module){
def __init__(self,d_model,num_heads):
super.__init__()
assert d_model % num_heads ==0,"d_model 必须被 num_heads 整除"

self.d_model=d_model
self.num_heads=num_heads
self.d_k=d_model//num_heads # 每个head的维度

# QKV的线性变换
self.W_q=nn.Linear(d_model,d_model)
self.W_k=nn.Linear(d_model,d_model)
self.W_v=nn.Linear(d_model,d_model)

# 输出线性层
self.W_0=nn.Linear(d_model,d_model)

def forward(self,q,k,v,mask=None){
# query, key, value: (batch, seq_len, d_model)
B,L_q,_=q.shape
L_k=k.shape[1]

# 线性变换+reshape分头: (B, H, L_q, d_k)
Q=self.W_q(q).view
K=
V=
# 注意力得分:(B, H, L_q, L_k)
scores=torch.matmul(Q,K.transpose(-2,-1))/(self.d_k**0.5)
if Mask is not None:
scores=scores.masked_filled(mask==0,float('-inf'))
# mask = torch.triu(torch.ones(L_q, L_k), diagonal=1).bool() # upper triangular mask


attn=F.softmax(scores,dim=-1) #自注意力权重
context=torch.matmul(attn,V)

context=context.transpose(1,2).contigous().view(B,L_q,self.d_model)
return self.W_o(context)
}


}

# Q from decoder, K/V from encoder
output = mha(decoder_input, encoder_output, encoder_output)

  1. 为什么d_model 必须被 num_heads 整除?

    把总的维度 dmodeld_{}dmodel 平均 分给每个头(head)

  2. Q = self.W_q(query).view(B, L_q, self.num_heads, self.d_k).transpose(1, 2) # (B, H, L_q, d_k)这是在干什么?,K.transpose(-2, -1)中的参数是什么 意思?

    • transpose(1, 2):把 L_qH 互换位置,得到 (B, H, L_q, d_k),这个是注意力标准格式
    • 是对张量的 最后两个维度做转置:假设 K.shape = (B, H, L_k, d_k),那 K.transpose(-2, -1) 变成 (B, H, d_k, L_k)
  3. scores.masked_fill此函数是?

    mask==0 的位置变成 -inf

  4. attn = F.softmax(scores, dim=-1) ,dim=-1是?

    dim=-1 表示在最后一个维度上做 softmax。即:在每个 query 上,对所有的 key 做归一化

  5. 合并 heads context = context.transpose(1, 2).contiguous().view(B, L_q, self.d_model) # (B, L_q, d_model)这个是在?num_heads似乎没有使用。

    • context.shape = (B, H, L_q, d_k)
    • context.transpose(1, 2):(B, L_q, H, d_k)
    • contiguous().view(B, L_q, d_model):合并 H 和 d_k,拼回去
    • 输出是(batch_size, seq_len, d_model),保持与原始输入形状一致

Scaled Dot-Product Attention

1
2
3
4
5
6
7
8
9
10
11
def scaled_dot_product_attention(Q,K,V,mask=None):
# QKV Shape:(batch_size, nums_heads, seq_len, d_k)
scores=torch.matmul(Q,K.transpose(-2,-1))/Q.size(-1)**0.5

if mask is not None:
scores=scores.maskfill(mask==0,float="-inf")

attn=F.softmax(scores,dim=-1)
output = torch.matmul(attn, V)

return output,attn
  1. 为什么Scaled Dot-Product Attention要返回output, attn?
  • output 是最终注意力后的值(输入后续 FFN)
  • attn 是注意力权重(可视化分析、或者做 attention dropout)

交叉熵损失

1
2
3
4
5
6
def cross_entropy(pred,target):
# pred:(batch_size,num_class)
# target:(batch_size,)int类型正确索引
log_probs=pred-torch.logsumexp(pred,dim=1,keepdim=True)
loss=-log_probs[torch.arrange(pred.size(0)),target]
return loss.mean()
  1. 交叉熵损失的代码

(log-probabilities):torch.logsumexp: log softmax \(\sum_i e^{\text{logits}\ {s_i}}\) \[ \log{p_i}=\text{logits}\ {s_i}-\log \sum_j e^{\text{logits}\ {s_j}}\\ \]

搜索
匹配结果数:
未搜索到匹配的文章。