从零构建Transformer:用PyTorch实现编码器与解码器的核心逻辑
在自然语言处理领域,Transformer架构已经成为现代AI系统的基石。但很多学习者在理解其工作原理时陷入了一个怪圈——能够背诵自注意力公式,却无法用代码实现最基本的版本;能解释多头注意力的优势,但面对实际项目时依然无从下手。本文将带你用PyTorch从零开始构建一个简化版Transformer,通过动手实践真正掌握编码器(Encoder)和解码器(Decoder)的核心机制。
1. 环境准备与基础组件
1.1 初始化项目环境
首先确保你的Python环境已安装PyTorch 1.8+版本。我们创建一个干净的虚拟环境:
conda create -n transformer python=3.8 conda activate transformer pip install torch torchtext matplotlib1.2 实现基础构建块
Transformer的核心由几个关键组件构成,我们先实现最基础的版本:
import torch import torch.nn as nn import math class EmbeddingLayer(nn.Module): def __init__(self, vocab_size, d_model): super().__init__() self.embedding = nn.Embedding(vocab_size, d_model) self.d_model = d_model def forward(self, x): return self.embedding(x) * math.sqrt(self.d_model)这个简单的嵌入层已经包含了一个重要细节:初始化时将嵌入值乘以√d_model。这个缩放操作能防止后续注意力计算时的数值爆炸问题——这是许多初学者容易忽略的关键点。
2. 位置编码与自注意力机制
2.1 实现正弦位置编码
Transformer没有循环结构,必须显式地注入位置信息。以下是经典的正弦位置编码实现:
class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe) def forward(self, x): return x + self.pe[:x.size(1)]2.2 构建缩放点积注意力
自注意力机制的核心计算单元如下:
def scaled_dot_product_attention(q, k, v, mask=None): d_k = q.size(-1) scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) p_attn = torch.softmax(scores, dim=-1) return torch.matmul(p_attn, v), p_attn注意这里的mask参数对于解码器至关重要——它确保模型在预测当前位置时无法"偷看"未来的信息。
3. 多头注意力实现
3.1 多头机制分解
将注意力分散到多个"头"上,让模型从不同角度学习特征:
class MultiHeadAttention(nn.Module): def __init__(self, h, d_model): super().__init__() assert d_model % h == 0 self.d_k = d_model // h self.h = h self.linears = nn.ModuleList([ nn.Linear(d_model, d_model) for _ in range(4) ]) def forward(self, q, k, v, mask=None): batch_size = q.size(0) # 线性变换并分头 q, k, v = [ lin(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) for lin, x in zip(self.linears, (q, k, v)) ] # 计算注意力 x, attn = scaled_dot_product_attention(q, k, v, mask) # 合并多头输出 x = x.transpose(1, 2).contiguous() \ .view(batch_size, -1, self.h * self.d_k) return self.linears[-1](x)3.2 残差连接与层归一化
Transformer的稳定性很大程度上依赖于这两个组件:
class SublayerConnection(nn.Module): def __init__(self, size, dropout): super().__init__() self.norm = nn.LayerNorm(size) self.dropout = nn.Dropout(dropout) def forward(self, x, sublayer): return x + self.dropout(sublayer(self.norm(x)))这种设计使得深层网络训练成为可能,也是Transformer能够堆叠多层的关键。
4. 编码器与解码器架构
4.1 编码器层实现
class EncoderLayer(nn.Module): def __init__(self, size, self_attn, feed_forward, dropout): super().__init__() self.self_attn = self_attn self.feed_forward = feed_forward self.sublayer = nn.ModuleList([ SublayerConnection(size, dropout) for _ in range(2) ]) self.size = size def forward(self, x, mask): x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) return self.sublayer[1](x, self.feed_forward)4.2 解码器层实现
解码器需要额外的交叉注意力机制来处理编码器输出:
class DecoderLayer(nn.Module): def __init__(self, size, self_attn, src_attn, feed_forward, dropout): super().__init__() self.size = size self.self_attn = self_attn self.src_attn = src_attn self.feed_forward = feed_forward self.sublayer = nn.ModuleList([ SublayerConnection(size, dropout) for _ in range(3) ]) def forward(self, x, memory, src_mask, tgt_mask): m = memory x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) return self.sublayer[2](x, self.feed_forward)5. 完整模型组装与训练
5.1 模型整合
将各个组件组合成完整的Transformer:
class Transformer(nn.Module): def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): super().__init__() self.encoder = encoder self.decoder = decoder self.src_embed = src_embed self.tgt_embed = tgt_embed self.generator = generator def encode(self, src, src_mask): return self.encoder(self.src_embed(src), src_mask) def decode(self, memory, src_mask, tgt, tgt_mask): return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)5.2 训练技巧与参数设置
训练Transformer时需要注意几个关键点:
- 学习率预热:初始阶段线性增加学习率,之后逐步衰减
- 标签平滑:防止模型对预测结果过度自信
- 梯度裁剪:避免梯度爆炸
optimizer = torch.optim.Adam( model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9 ) scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lr_lambda=lambda step: min( (step + 1) ** -0.5, (step + 1) * (warmup_steps ** -1.5) ) )6. 实战:字符级语言模型
为了验证我们的实现,我们构建一个简单的字符级语言模型:
# 数据预处理示例 text = "hello transformer" chars = sorted(list(set(text))) char_to_idx = {ch:i for i, ch in enumerate(chars)} idx_to_char = {i:ch for i, ch in enumerate(chars)} # 创建训练样本 def create_samples(text, seq_len=5): samples = [] for i in range(len(text) - seq_len): sample = text[i:i+seq_len] target = text[i+1:i+seq_len+1] samples.append(( torch.tensor([char_to_idx[c] for c in sample]), torch.tensor([char_to_idx[c] for c in target]) )) return samples训练过程中观察注意力权重的变化特别有启发性——你可以清楚地看到模型如何逐步学会关注输入序列中的相关部分。例如在预测"transformer"中的"m"时,模型会重点关注前面的"for"字符组合。
7. 调试与可视化技巧
7.1 注意力权重可视化
理解模型内部运作的关键是观察注意力分布:
import matplotlib.pyplot as plt def plot_attention(attention, input_tokens): fig = plt.figure(figsize=(10, 10)) ax = fig.add_subplot(111) cax = ax.matshow(attention.numpy(), cmap='bone') fig.colorbar(cax) ax.set_xticks(range(len(input_tokens))) ax.set_yticks(range(len(input_tokens))) ax.set_xticklabels(input_tokens, rotation=90) ax.set_yticklabels(input_tokens) plt.show()7.2 常见问题排查
初学者常遇到的几个典型问题:
- 梯度消失/爆炸:检查层归一化和残差连接是否正确实现
- 过拟合:调整dropout率(通常0.1-0.3之间)
- 训练不稳定:尝试降低学习率或使用预热策略
- 预测结果重复:可能是解码器mask实现有误
8. 性能优化与扩展
8.1 内存效率优化
当处理长序列时,可以优化注意力计算:
# 内存高效的注意力计算 def memory_efficient_attention(q, k, v, mask=None): d_k = q.size(-1) scores = torch.einsum('bhid,bhjd->bhij', q, k) / math.sqrt(d_k) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) p_attn = torch.softmax(scores, dim=-1) return torch.einsum('bhij,bhjd->bhid', p_attn, v)8.2 扩展到实际应用
要将这个简化版Transformer扩展到实际NLP任务,需要考虑:
- 批处理优化:实现padding和masking
- 词汇表处理:使用子词分词(BPE/WordPiece)
- 预训练策略:实现MLM和NSP目标
- 混合精度训练:使用torch.cuda.amp
# 批处理示例 def collate_fn(batch): src_batch, tgt_batch = zip(*batch) src_len = max(len(x) for x in src_batch) tgt_len = max(len(x) for x in tgt_batch) src_padded = torch.zeros(len(batch), src_len).long() tgt_padded = torch.zeros(len(batch), tgt_len).long() for i, (src, tgt) in enumerate(zip(src_batch, tgt_batch)): src_padded[i, :len(src)] = src tgt_padded[i, :len(tgt)] = tgt return src_padded, tgt_padded通过这个从零实现的旅程,你会发现Transformer不再是一个神秘的"黑箱",而是一系列精心设计的组件的有序组合。每个技术选择——从位置编码到残差连接——都有其明确的目的和数学依据。这种深入理解将帮助你在实际项目中灵活应用和调整Transformer架构,而不仅仅是机械地调用现成的库函数。