2025-6-27-对于nn.Embeddding的理解
对于nn.Embeddding
的理解
nn.Embedding(num_embeddings, embedding_dim)
是一个索引表。根据输入,给出对应索引的向量。
num_embeddings
→ 索引的长度。embedding_dim
→ 索引idx对应的向量,维度是:embedding_dim x 1
。
与nn.Linear
的区别
我一开始初看此概念,会觉得比较类似于nn.Linear
,现在看来区别非常大?这样。
虽然都是给定输入、得到输出。但是nn.Embedding
是去查自己的索引表(本身就有的)给出输出;而nn.Linear
则是根据公式\(\mathbf{w^T
x+b}\)计算得到的输出。
特点 | nn.Embedding |
nn.Linear |
---|---|---|
输入 | 整数(索引) | 向量 |
本质 | 查表(取向量) | 线性变换(矩阵乘法) |
输出 | 向量 | 向量 |
结合例子理解
假设有 5 个单词,分别编号为 0~4,现在你想用
nn.Embedding
把它们变成“向量”,然后送进神经网络进行后续处理。首先定义一个
Embedding
层(省略导入库):
1 | embedding = nn.Embedding(num_embeddings=5, embedding_dim=3) |
num_embeddings=5
:说明我们最多支持 5 个词(编号 0~4)embedding_dim=3
:每个词都表示成长度为 3 的向量
本质上它内部是一个 5×3
的矩阵,就像这样(随机初始化):
1 | [[ 0.1, -0.2, 0.3], # 第0个词 |
现在我们模拟一个句子作为输入的“索引序列”,比如是词编号
[1, 2, 4]
,用 torch.LongTensor
来表示:
1 | input_ids = torch.LongTensor([1, 2, 4]) |
🎯 注意:Embedding
只能接受整数(Long
类型),表示索引!
我们把这组索引喂进 embedding
层:
1 | output = embedding(input_ids) |
Embedding 层如何更新?
就像 nn.Linear
一样,nn.Embedding
的 weight
也是参数,可以随着训练更新。
1 | print(embedding.weight.requires_grad) # True |
训练时,这个向量表会学到更有意义的表示。