import torch
import torch.nn as nn
from torch.nn import TransformerEncoderLayer


class TransformerEncoder(nn.Module):
    """Thin wrapper around nn.TransformerEncoder that uses batch-first inputs."""

    def __init__(self, d_model, nhead, num_layers, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True
        )
        # Use the qualified name so that this wrapper class does not shadow
        # torch's own TransformerEncoder.
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)

    def forward(self, src):
        return self.encoder(src)
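
# Usage sketch for the encoder wrapper (the shapes and hyperparameters here are
# illustrative assumptions, not part of the original code): with batch_first=True
# the (batch, seq, d_model) layout is preserved, e.g.
#     enc = TransformerEncoder(d_model=256, nhead=8, num_layers=4)
#     memory = enc(torch.randn(2, 10, 256))   # -> (2, 10, 256)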


class RLTransformerDecoder(nn.Module):
    """Reinforcement-learning-augmented Transformer decoder.

    When updating the queries, each decoder layer produces the parameters of a
    Gaussian distribution and samples the new queries from it.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            RLTransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        ])
        self.num_layers = num_layers

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Collect every layer's sampled queries and Gaussian parameters.
        intermediate = []
        for layer in self.layers:
            tgt, mu, log_var = layer(tgt, memory, tgt_mask, memory_mask,
                                     tgt_key_padding_mask, memory_key_padding_mask)
            intermediate.append((tgt, mu, log_var))
        return tgt, intermediate


class RLTransformerDecoderLayer(nn.Module):
    """Single decoder layer consisting of:

    - self-attention
    - cross-attention
    - a feed-forward network (FFN)

    followed by heads that produce Gaussian parameters (mu, log_var) from
    which a new query is sampled.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = nn.ReLU()
        # Heads that map the layer output to the mean and log-variance of a
        # Gaussian over the new queries.
        self.mu_head = nn.Linear(d_model, d_model)
        self.log_var_head = nn.Linear(d_model, d_model)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Self-attention over the queries.
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # Cross-attention against the encoder memory.
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # Feed-forward network.
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        # Gaussian parameters for the new queries.
        mu = self.mu_head(tgt)
        log_var = self.log_var_head(tgt)

        if self.training:
            # Reparameterization trick: sample new queries during training.
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            new_tgt = mu + eps * std
        else:
            # Use the mean deterministically at evaluation time.
            new_tgt = mu
        return new_tgt, mu, log_var
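

# Minimal end-to-end usage sketch. The hyperparameters and tensor shapes below
# are illustrative assumptions, not values taken from the original code.
if __name__ == "__main__":
    d_model = 256
    encoder = TransformerEncoder(d_model=d_model, nhead=8, num_layers=4)
    decoder = RLTransformerDecoder(d_model=d_model, nhead=8, num_layers=4)

    src = torch.randn(2, 10, d_model)       # (batch, src_len, d_model)
    queries = torch.randn(2, 5, d_model)    # (batch, num_queries, d_model)

    memory = encoder(src)                           # (2, 10, d_model)
    out, intermediate = decoder(queries, memory)    # out: (2, 5, d_model)

    # intermediate holds (tgt, mu, log_var) from every layer; the per-layer
    # Gaussian parameters could feed, e.g., a KL or log-probability term in an
    # RL-style training objective (an assumption, not defined in this file).
    print(out.shape, len(intermediate))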