news 2026/5/6 3:24:30

从零实现Transformer:第 6 部分 - 解码器(The Decoder)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
从零实现Transformer:第 6 部分 - 解码器(The Decoder)

从零实现Transformer:第 6 部分 - 解码器(The Decoder)

flyfish

在实现编码器后,本部分将构建 Transformer 的解码器,包含掩码自注意力、编码器-解码器交叉注意力。


Encoder 和 Decoder 展开就是

要实现该图像的右侧部分

编码器输出 + 已生成序列 → [解码器] → 下一个 token 的概率分布

Decoder Block:三层结构(看右边)

将三层结构放到整个Encoder中
这里的注意力都是多头,举个例子
Multi-Headed Cross-Attention = Encoder-Decoder Attention

注意力是多头

单头的样子

多头的样子

数据流向图

输入: 目标序列嵌入 x [batch, tgt_len, d_model] 编码器输出: enc_out [batch, src_len, d_model] │ ▼ ┌─────────────────────────┐ │ 子层1: 掩码多头自注意力 │ ← tgt_mask (前瞻+填充) │ Masked Self-Attention │ ← Q=K=V=x └─────────────────────────┘ │ ▼ ┌─────────────────────────┐ │ Add & Norm │ │ LayerNorm(x + Dropout) │ └─────────────────────────┘ │ ▼ ┌─────────────────────────┐ │ 子层2: 交叉注意力 │ ← src_mask (源序列填充) │ Encoder-Decoder Attention│ ← Q=上一步输出, K/V=enc_out └─────────────────────────┘ │ ▼ ┌─────────────────────────┐ │ Add & Norm │ │ LayerNorm(x + Dropout) │ └─────────────────────────┘ │ ▼ ┌─────────────────────────┐ │ 子层3: 位置前馈网络 │ │ Position-wise FFN │ └─────────────────────────┘ │ ▼ ┌─────────────────────────┐ │ Add & Norm │ │ LayerNorm(x + Dropout) │ └─────────────────────────┘ │ ▼ 输出: [batch, tgt_len, d_model] → 下一个 DecoderBlock

掩码自注意力(Masked Self-Attention)

# 为什么需要前瞻掩码?# 训练时:防止"偷看"未来答案,保证学习因果依赖# 推理时:保证自回归生成,每次只依赖已输出内容# 掩码组合:前瞻掩码 + 填充掩码look_ahead=torch.triu(torch.ones(T,T),diagonal=1).bool()# 上三角为1padding=(tgt==pad_idx).unsqueeze(1).unsqueeze(2)# 填充位置为1tgt_mask=look_ahead|padding# 逻辑或:任一条件满足即屏蔽

交叉注意力(Encoder-Decoder Attention)

# 思想:解码器每个位置"查询"编码器所有位置# Q 来自解码器:当前要预测的 token 表示# K/V 来自编码器:源序列的上下文表示# 类比翻译任务:# 生成英文 "cat" 时,解码器通过交叉注意力"关注"中文输入中的 "猫"

掩码形状与广播规则

# 自注意力掩码:[batch, 1, tgt_len, tgt_len]# - 第3维:查询位置(目标序列)# - 第4维:键位置(目标序列)# 交叉注意力掩码:[batch, 1, 1, src_len]# - 第3维:查询位置(目标序列,广播到 tgt_len)# - 第4维:键位置(源序列)# 广播后都变为:[batch, num_heads, tgt_len, src/tgt_len] ✓

DecoderBlock 代码实现

importtorchimporttorch.nnasnn# MultiHeadAttention# PositionwiseFeedForward# ResidualConnectionclassDecoderBlock(nn.Module):""" Transformer 解码器单块实现 包含:掩码自注意力 + 交叉注意力 + 前馈网络,均包裹在 Add & Norm 中 """def__init__(self,features:int,self_attention_block:MultiHeadAttention,cross_attention_block:MultiHeadAttention,feed_forward_block:PositionwiseFeedForward,dropout:float):super().__init__()self.self_attention_block=self_attention_block# 子层1self.cross_attention_block=cross_attention_block# 子层2self.feed_forward_block=feed_forward_block# 子层3# 三个残差连接模块self.residual_connections=nn.ModuleList([ResidualConnection(features,dropout)for_inrange(3)])defforward(self,x:torch.Tensor,encoder_output:torch.Tensor,src_mask:torch.Tensor,tgt_mask:torch.Tensor):""" Args: x: 目标序列输入 [batch, tgt_len, d_model] encoder_output: 编码器输出 [batch, src_len, d_model] src_mask: 源序列掩码 [batch, 1, 1, src_len] tgt_mask: 目标序列掩码 [batch, 1, tgt_len, tgt_len] Returns: output: [batch, tgt_len, d_model] """# 子层1: 掩码自注意力 + Add & Norm# Q=K=V=x,使用 tgt_mask 屏蔽未来位置和 paddingx=self.residual_connections[0](x,lambdax_res:self.self_attention_block(x_res,x_res,x_res,tgt_mask))# 子层2: 交叉注意力 + Add & Norm# Q=当前输出x, K/V=编码器输出,使用 src_mask 屏蔽源序列 paddingx=self.residual_connections[1](x,lambdax_res:self.cross_attention_block(x_res,encoder_output,encoder_output,src_mask))# 子层3: 前馈网络 + Add & Normx=self.residual_connections[2](x,self.feed_forward_block)returnx

代码解析

# 为什么需要两个不同的 MultiHeadAttention 实例?# 虽然结构相同,但自注意力和交叉注意力的参数需要独立学习self_attention=MultiHeadAttention(...)# 学习目标序列内部关系cross_attention=MultiHeadAttention(...)# 学习目标-源序列对齐关系# lambda 函数的作用?# ResidualConnection 期望 sublayer(x) 单参数接口# 但 MultiHeadAttention 需要 (Q, K, V, mask)# lambda 固定部分参数,创建兼容接口lambdax_res:self.cross_attention_block(x_res,encoder_output,encoder_output,src_mask)# 等价于:# def cross_sublayer(q):# return self.cross_attention_block(q, encoder_output, encoder_output, src_mask)

Decoder Stack:堆叠多个 Block

整体架构

目标嵌入 + 位置编码 │ ▼ ┌────────────────┐ │ DecoderBlock │ ← 第1层 │ (3子层+掩码) │ └────────────────┘ │ ▼ ┌────────────────┐ │ DecoderBlock │ ← 第2层 │ (接收上层输出) │ └────────────────┘ │ ▼ ⋮ │ ▼ ┌────────────────┐ │ DecoderBlock │ ← 第N层 (N=6) └────────────────┘ │ ▼ ┌────────────────┐ │ LayerNorm │ ← 最终归一化 └────────────────┘ │ ▼ 输出: [batch, tgt_len, d_model] → 线性层 → 词汇表概率

Decoder 代码实现

importtorchimporttorch.nnasnn# LayerNormalization# DecoderBlockclassDecoder(nn.Module):""" 完整 Transformer 解码器:堆叠 N 个 DecoderBlock """def__init__(self,features:int,layers:nn.ModuleList):""" Args: features: d_model,用于最终 LayerNorm layers: 预初始化的 DecoderBlock 列表 """super().__init__()self.layers=layers# N 个 DecoderBlockself.norm=LayerNormalization(features)# 最终归一化defforward(self,x:torch.Tensor,encoder_output:torch.Tensor,src_mask:torch.Tensor,tgt_mask:torch.Tensor):""" Args: x: 目标输入 [batch, tgt_len, d_model] encoder_output: 编码器输出 [batch, src_len, d_model] src_mask: 源序列掩码 tgt_mask: 目标序列掩码 Returns: output: [batch, tgt_len, d_model] """# 逐层传递,所有参数透传forlayerinself.layers:x=layer(x,encoder_output,src_mask,tgt_mask)returnself.norm(x)

完整测试代码

importtorchimporttorch.nnasnn# MultiHeadAttention# PositionwiseFeedForward# DecoderBlock# Decoderprint("\n--- 测试 Transformer 解码器 ---")# 参数配置batch_size=4src_len,tgt_len=12,10# 源/目标序列长度d_model,num_heads=512,8d_ff,num_layers=2048,6# 4×d_model, N=6dropout=0.1# 🎲 创建虚拟输入dummy_tgt=torch.randn(batch_size,tgt_len,d_model)# 目标嵌入dummy_enc_out=torch.randn(batch_size,src_len,d_model)# 编码器输出# 创建掩码# 源序列掩码 (填充)src_mask=torch.zeros(batch_size,1,1,src_len,dtype=torch.bool)src_mask[0,:,:,-3:]=True# 示例:屏蔽末尾# 目标序列掩码 (前瞻 + 填充)look_ahead=torch.triu(torch.ones(tgt_len,tgt_len),diagonal=1).bool()look_ahead=look_ahead.unsqueeze(0).unsqueeze(0)# [1, 1, T, T]tgt_padding=torch.zeros(batch_size,1,1,tgt_len,dtype=torch.bool)tgt_padding[0,:,:,-1:]=Truetgt_mask=look_ahead|tgt_padding# 逻辑或组合print(f"目标输入:{dummy_tgt.shape}")# [4, 10, 512]print(f"编码器输出:{dummy_enc_out.shape}")# [4, 12, 512]print(f"src_mask:{src_mask.shape}")# [4, 1, 1, 12]print(f"tgt_mask:{tgt_mask.shape}")# [4, 1, 10, 10]# 🔧 构建组件(注意:自注意力和交叉注意力需要独立实例)self_attn=MultiHeadAttention(d_model,num_heads,dropout)cross_attn=MultiHeadAttention(d_model,num_heads,dropout)# 独立参数!ffn=PositionwiseFeedForward(d_model,d_ff,dropout)# 创建 DecoderBlock 堆叠decoder_layers=nn.ModuleList([DecoderBlock(d_model,self_attn,cross_attn,ffn,dropout)for_inrange(num_layers)])# 实例化解码器decoder=Decoder(d_model,decoder_layers)# 前向传播dec_output=decoder(dummy_tgt,dummy_enc_out,src_mask,tgt_mask)# 验证输出print(f"解码器输出:{dec_output.shape}")# [4, 10, 512] ✓assertdec_output.shape==dummy_tgt.shapeprint("解码器测试通过!")

掩码创建实用函数

defcreate_tgt_mask(tgt,pad_idx=0,device='cpu'):""" 创建解码器目标序列掩码(前瞻 + 填充) Args: tgt: 目标 token IDs, [batch, tgt_len] pad_idx: padding token 索引 Returns: mask: [batch, 1, tgt_len, tgt_len], bool 类型 """batch_size,tgt_len=tgt.shape# 1️前瞻掩码:屏蔽未来位置look_ahead=torch.triu(torch.ones(tgt_len,tgt_len,device=device),diagonal=1).bool().unsqueeze(0).unsqueeze(0)# [1, 1, T, T]# 2️填充掩码:屏蔽 <pad> 位置padding=(tgt==pad_idx).unsqueeze(1).unsqueeze(2)# [B, 1, 1, T]# 3️组合:任一条件满足即屏蔽mask=look_ahead|padding# 广播后 [B, 1, T, T]returnmaskdefcreate_src_mask(src,pad_idx=0):"""创建编码器/交叉注意力的源序列掩码"""return(src==pad_idx).unsqueeze(1).unsqueeze(2)# [B, 1, 1, src_len]

下一步预告:组装完整的 Transformer

第 7 部分 将编码器(The Encoder)和解码器(The Decoder)组装成完整 Transformer

输入/输出嵌入层 + 位置编码 编码器 + 解码器堆叠连接 最终线性层 + Softmax → 词汇表概率 完整的 forward 流程:src → enc → dec → logits

完整模型结构预览

源序列 src │ ▼ [输入嵌入 + 位置编码] → Encoder堆叠 → enc_output │ ▼ 目标序列 tgt (训练时) ──→ [输出嵌入 + 位置编码] │ ▼ Decoder堆叠 (接收 enc_output + 双掩码) │ ▼ [线性层: d_model → vocab_size] │ ▼ [Softmax] → 下一个 token 概率分布
# 伪代码预览classTransformer(nn.Module):def__init__(self,encoder,decoder,src_embed,tgt_embed,generator):self.encoder=encoder#self.decoder=decoder#self.src_embed=src_embed# 源嵌入+位置编码self.tgt_embed=tgt_embed# 目标嵌入+位置编码self.generator=generator# 线性+softmaxdefforward(self,src,tgt,src_mask,tgt_mask):enc_out=self.encoder(self.src_embed(src),src_mask)dec_out=self.decoder(self.tgt_embed(tgt),enc_out,src_mask,tgt_mask)returnself.generator(dec_out)# [batch, tgt_len, vocab_size]
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/6 3:19:32

如何快速上手GI-Model-Importer:原神角色模型自定义终极指南

如何快速上手GI-Model-Importer&#xff1a;原神角色模型自定义终极指南 【免费下载链接】GI-Model-Importer Tools and instructions for importing custom models into a certain anime game 项目地址: https://gitcode.com/gh_mirrors/gi/GI-Model-Importer GI-Model…

作者头像 李华
网站建设 2026/5/6 3:18:36

PCB设计-器件:1.电容

一、基本认识符号&#xff1a;C单位&#xff1a;F法拉封装&#xff1a;二、电气特性&#xff08;一&#xff09;不可突变电容两端的相对电压不能突变&#xff0c;在通电的一瞬间&#xff0c;电容相当于导线。然而&#xff0c;可以在保持相对值不变的同时&#xff0c;一起突变。…

作者头像 李华
网站建设 2026/5/6 3:15:40

自修改策略与PAC学习边界的动态优化实践

1. 项目概述在机器学习领域&#xff0c;自修改策略&#xff08;Self-Modifying Strategies&#xff09;与PAC&#xff08;Probably Approximately Correct&#xff09;学习边界的交叉研究&#xff0c;正逐渐成为算法优化和理论分析的前沿方向。这个看似抽象的组合&#xff0c;实…

作者头像 李华
网站建设 2026/5/6 3:13:37

SVG-EAR技术:无参数线性补偿在视频生成中的应用

1. 项目背景与核心价值在视频内容创作领域&#xff0c;稀疏视频生成技术正逐渐成为提升生产效率的关键手段。这种技术通过智能识别视频中的关键帧&#xff0c;仅对变化部分进行渲染&#xff0c;从而大幅减少计算资源消耗。然而传统方法往往面临两个核心痛点&#xff1a;一是需要…

作者头像 李华
网站建设 2026/5/6 3:10:31

XCursor主题开发指南:从设计到部署的完整实践

1. 项目概述&#xff1a;一个为现代桌面注入灵魂的指针主题如果你和我一样&#xff0c;每天有超过8小时的时间与电脑屏幕为伴&#xff0c;那么桌面环境的每一个细节&#xff0c;都直接影响着你的工作效率和心情。显示器、壁纸、字体&#xff0c;这些我们常常会花心思去调整&…

作者头像 李华