2025-6-27-对于nn.Embeddding的理解
对于nn.Embeddding的理解
nn.Embedding(num_embeddings, embedding_dim)是一个索引表。根据输入,给出对应索引的向量。
num_embeddings→ 索引的长度。embedding_dim→ 索引idx对应的向量,维度是:embedding_dim x 1。
与nn.Linear的区别
我一开始初看此概念,会觉得比较类似于nn.Linear,现在看来区别非常大?这样。
虽然都是给定输入、得到输出。但是nn.Embedding是去查自己的索引表(本身就有的)给出输出;而nn.Linear则是根据公式$\mathbf{w^T x+b}$计算得到的输出。
| 特点 | nn.Embedding |
nn.Linear |
|---|---|---|
| 输入 | 整数(索引) | 向量 |
| 本质 | 查表(取向量) | 线性变换(矩阵乘法) |
| 输出 | 向量 | 向量 |
结合例子理解
假设有 5 个单词,分别编号为 0~4,现在你想用 nn.Embedding 把它们变成“向量”,然后送进神经网络进行后续处理。首先定义一个 Embedding 层(省略导入库):
1 | embedding = nn.Embedding(num_embeddings=5, embedding_dim=3) |
num_embeddings=5:说明我们最多支持 5 个词(编号 0~4)embedding_dim=3:每个词都表示成长度为 3 的向量
本质上它内部是一个 5×3 的矩阵,就像这样(随机初始化):
1 | [[ 0.1, -0.2, 0.3], # 第0个词 |
现在我们模拟一个句子作为输入的“索引序列”,比如是词编号 [1, 2, 4],用 torch.LongTensor 来表示:
1 | input_ids = torch.LongTensor([1, 2, 4]) |
🎯 注意:Embedding 只能接受整数(Long 类型),表示索引!
我们把这组索引喂进 embedding 层:
1 | output = embedding(input_ids) |
Embedding 层如何更新?
就像 nn.Linear 一样,nn.Embedding 的 weight 也是参数,可以随着训练更新。
1 | print(embedding.weight.requires_grad) # True |
训练时,这个向量表会学到更有意义的表示。