欢迎关注『youcans动手学 AI』系列
【动手学UNet】(1)创建 UNet项目
【动手学UNet】(2)数据加载
【动手学UNet】(7)主程序
【动手学UNet】(11)创建Unet_V2 项目
【动手学UNet】(12)Unet_V2 模型实现
【动手学UNet】(12)Unet_V2 模型实现
- 15.4 数据加载模块(data_utils_v2.py)
- 15.5 UNet v2 模型实现(Encoder + SDI + Decoder)
欢迎来到【动手学UNet】系列教程!
本系列将带你从零开始,一步步深入理解和实现经典的UNet图像分割模型。无论你是深度学习初学者,还是有一定经验的开发者,这个系列都将为你提供全面而实用的UNet知识。
15.4 数据加载模块(data_utils_v2.py)
1、图像预处理模块(utils/preprocess_v2.py)
preprocess_v2.py是所有 Retina 图像预处理操作的集中模块(绿色通道 + CLAHE),提供统一的医学图像预处理。模型训练和推理(包括单张推理)可以复用相同的预处理逻辑。
- 提供一个独立预处理类:RetinalPreprocessorV2;
- 支持 “绿色通道提取”;
- 支持 CLAHE(局部直方图均衡);
- 参数可调,可由 config_v2 控制;
- 最终输出统一为单通道 PIL.Image。
2、数据加载模块(utils/data_utils_v2.py)
data_utils.py是 UNetv2_Retina 的数据入口,自动完成:文件读取 → 预处理 → resize → 归一化 → Tensor 输出。
- 创建 RetinaDatasetV2:通用的数据加载类;
- 自动按照文件名匹配图像与掩膜;
- 调用 RetinalPreprocessorV2(绿色通道 + CLAHE);
- 控制预处理、通道数、图像大小;
- 预留 transform 钩子(兼容 Albumentations);
- Hermetic 设计:返回训练可直接使用的 Tensor。
此前已经建立了data_utils.py脚本用于占位和测试,现在用如下脚本替换原来的测试版本。
3、测试数据加载(test3.py)
为了测试data_utils.py,在项目根目录编写一个测试程序test3.py。
- 构建训练集 Dataset
- 从 dataset/train/images & dataset/train/masks 中取出一张样本,检查形状
- 构建 DataLoader
- 从 DataLoader 中取出一个 batch
运行测试程序时,要在dataset/train/images/和dataset/train/masks/目录下至少各放一张图片。
15.5 UNet v2 模型实现(Encoder + SDI + Decoder)
目标:在 model/ 目录下实现 U-Net v2 的三大模块 & 总体网络封装,清晰体现 多层特征提取 + SDI 融合 + 解码重建 的流程。
1、多层级编码器(model/encoder_v2.py)encoder_v2.py负责实现 UNet v2 的编码器部分,是整张视网膜图像特征提取的入口模块,用于逐层提取由浅入深的多尺度特征图,为后续 SDI 注入和解码器重建提供语义与细节信息基础。
- 提供多层级编码器类:EncoderV2;
- 采用与经典 U-Net 一致的结构:DoubleConv 卷积块 + Down 下采样块;
- 从输入图像中逐级提取特征,输出 [f1, f2, f3, f4, f5] 五个层级的特征图;
4. 保持高分辨率浅层特征用于细节保留、低分辨率深层特征用于语义表达; - 将各层通道数记录在 self.channels 中,为 SDI 模块和 Decoder 提供统一的通道配置。
输入:x, [B, C_in, H, W]
输出:特征图列表 [f1, f2, f3, f4, f5],分别对应从浅到深的不同层级。
# model/encoder_v2.py""" UNet v2 编码器(EncoderV2) - 与经典 U-Net 类似,采用多层级卷积 + 下采样结构 - 输出一系列特征图 [f1, f2, f3, f4, f5],从浅到深 - youcans@qq.com """fromtypingimportListimporttorchimporttorch.nnasnnclassDoubleConv(nn.Module):""" 两次 3x3 卷积 + BN + ReLU 的基本卷积块。 """def__init__(self,in_ch:int,out_ch:int):super().__init__()self.net=nn.Sequential(nn.Conv2d(in_ch,out_ch,kernel_size=3,padding=1,bias=False),nn.BatchNorm2d(out_ch),nn.ReLU(inplace=True),nn.Conv2d(out_ch,out_ch,kernel_size=3,padding=1,bias=False),nn.BatchNorm2d(out_ch),nn.ReLU(inplace=True),)defforward(self,x:torch.Tensor)->torch.Tensor:returnself.net(x)classDown(nn.Module):""" 下采样块:MaxPool2d(2) + DoubleConv """def__init__(self,in_ch:int,out_ch:int):super().__init__()self.pool=nn.MaxPool2d(kernel_size=2,stride=2)self.conv=DoubleConv(in_ch,out_ch)defforward(self,x:torch.Tensor)->torch.Tensor:x=self.pool(x)x=self.conv(x)returnxclassEncoderV2(nn.Module):""" UNet v2 编码器: 输入:x, [B, C_in, H, W] 输出:一个特征图列表 [f1, f2, f3, f4, f5] - f1: 最浅层(高分辨率,细节丰富) - f5: 最深层(低分辨率,高语义) """def__init__(self,in_channels:int=1,base_channels:int=64):super().__init__()# 通道设置:经典 U-Net 风格c1=base_channels c2=base_channels*2c3=base_channels*4c4=base_channels*8c5=base_channels*16self.inc=DoubleConv(in_channels,c1)self.down1=Down(c1,c2)self.down2=Down(c2,c3)self.down3=Down(c3,c4)self.down4=Down(c4,c5)self.channels=[c1,c2,c3,c4,c5]defforward(self,x:torch.Tensor)->List[torch.Tensor]:f1=self.inc(x)# [B, c1, H, W ]f2=self.down1(f1)# [B, c2, H/2, W/2]f3=self.down2(f2)# [B, c3, H/4, W/4]f4=self.down3(f3)# [B, c4, H/8, W/8]f5=self.down4(f4)# [B, c5, H/16,W/16]return[f1,f2,f3,f4,f5]2、SDI模块(model/sdi_module.py)
sdi_module.py实现语义与细节注入模块(SDI,Semantic & Detail Infusion),通过从高层特征中提炼语义信息并注入到浅层/中层特征中,使网络在保持细节分辨率的同时增强语义表达能力。
- 提供 SDI 模块类:SDIModule;
- 使用最高层特征 f5 作为语义源;
- 对于指定层级 i,将 f5 经过 1x1 卷积 + 上采样到与 f_i 相同尺寸;
- 通过 Sigmoid生成语义门控语义图sem_gate;
- 采用哈达玛积(Hadamard product)对指定层级的特征进行语义增强:
f_i’ = f_i * (1 + sem_gate)
6. 是否启用、作用于哪些层级等行为,可以通过 feat_channels、sdi_levels、mode 进行配置,并由 config_v2 统一管理。
# model/sdi_module.py""" UNet v2 的 SDI 模块(Semantic & Detail Infusion) - 使用最高层特征 f5 作为语义源 - 对于指定层级 i,将 f5 经过 1x1 卷积 + 上采样到与 f_i 相同尺寸 - 生成语义门控语义图 sem_gate = sigmoid(conv(upsampled_f5)) - 利用哈达玛积(Hadamard product)进行增强: f_i' = f_i * (1 + sem_gate) - 未启用的层级保持原样返回 - youcans(at)qq.com """fromtypingimportList,Sequenceimporttorchimporttorch.nnasnnimporttorch.nn.functionalasFclassSDIModule(nn.Module):def__init__(self,feat_channels:Sequence[int],sdi_levels:Sequence[int]=(1,2,3,4),mode:str="hadamard",):""" :param feat_channels: 每层特征图通道数列表,例如 [c1, c2, c3, c4, c5] :param sdi_levels: 需要进行 SDI 注入的层级索引(1-based),如 (1,2,3,4) :param mode: 当前保留占位,默认为 'hadamard' """super().__init__()assertlen(feat_channels)>=2,"feat_channels 至少需要两个层级"self.feat_channels=list(feat_channels)self.sdi_levels=list(sdi_levels)self.mode=mode# 最高层特征 f5 的通道数self.top_channels=feat_channels[-1]# 为每个层级 i 创建一个 1x1 卷积,将 top feature 映射到该层通道数self.proj_convs=nn.ModuleDict()fori,cinenumerate(feat_channels[:-1],start=1):name=f"proj_to_l{i}"self.proj_convs[name]=nn.Conv2d(self.top_channels,c,kernel_size=1)defforward(self,feats:List[torch.Tensor])->List[torch.Tensor]:""" :param feats: [f1, f2, f3, f4, f5] :return: [f1',f2',f3',f4',f5],其中部分层经过 SDI 增强 """assertlen(feats)==len(self.feat_channels),"输入特征层数与 feat_channels 不匹配"f_top=feats[-1]# f5B,C_top,H_top,W_top=f_top.shape out_feats:List[torch.Tensor]=[]fori,finenumerate(feats,start=1):ifi==len(feats):# 最深层(f5)通常作为语义源,可直接保留out_feats.append(f)continueifinotinself.sdi_levels:# 不在 SDI 作用层级中,保持原样out_feats.append(f)continue# ---- SDI 注入过程 ----# 1) 将 f_top 投影到与 f_i 相同通道数proj_conv=self.proj_convs[f"proj_to_l{i}"]sem_feat=proj_conv(f_top)# [B, c_i, H_top, W_top]# 2) 上采样到 f_i 相同空间尺寸_,_,H_i,W_i=f.shape sem_feat_up=F.interpolate(sem_feat,size=(H_i,W_i),mode="bilinear",align_corners=False)# 3) 生成语义 gatesem_gate=torch.sigmoid(sem_feat_up)# [B, c_i, H_i, W_i]ifself.mode=="hadamard":# Hadamard product + 残差增强f_enhanced=f*(1.0+sem_gate)else:# 其他模式可以后续扩展,这里简单做加法占位f_enhanced=f+sem_gate out_feats.append(f_enhanced)returnout_feats3、解码器(model/decoder_v2.py)
decoder_v2.py实现 UNet v2 的解码器模块,负责将多层级特征逐步上采样并融合浅层跳跃连接,实现空间分辨率的重建和最终的像素级分割输出,整体风格与经典 U-Net 解码路径保持一致。
- 提供上采样块 Up 和解码器类 DecoderV2;
- 通过转置卷积(ConvTranspose2d)逐级将特征图放大 2 倍;
- 在每一层上采样后,与对应编码器特征(如 f4、f3、f2、f1)进行 skip-connection 拼接;
- 使用 DoubleConv 对拼接后的特征进行融合,逐层恢复细节与结构信息;
- 最终通过 1×1 卷积输出分割 logits(通道数等于 num_classes),用于后续 Sigmoid / Softmax 等激活与损失计算。
# model/decoder_v2.py""" UNet v2 解码器(DecoderV2) 与经典 U-Net 解码路径类似: - 逐步上采样 - 与对应层级的 encoder 特征做 skip-connection - 通过 DoubleConv 融合特征 - youcans(at)qq.com """fromtypingimportListimporttorchimporttorch.nnasnnimporttorch.nn.functionalasFfrom.encoder_v2importDoubleConvclassUp(nn.Module):""" 上采样块: - 采用 ConvTranspose2d 做 2x 上采样 - 与对应 encoder 特征拼接(channel 方向) - 再通过 DoubleConv 融合 """def__init__(self,in_ch:int,out_ch:int):""" :param in_ch: 解码特征的通道数(上采样前) :param out_ch: 上采后 + 拼接后,再经卷积得到的输出通道数 """super().__init__()# 上采样:通道数减半,尺寸 ×2self.up=nn.ConvTranspose2d(in_ch,in_ch//2,kernel_size=2,stride=2)# 上采样后会与 skip 特征拼接,拼接后通道数 = in_ch//2 + skip_ch# DoubleConv 里具体 in/out 通道设置在 forward 前根据实际拼接情况决定# 这里采用经典 U-Net 的写法:用 in_ch 作为 DoubleConv 的输入通道self.conv=DoubleConv(in_ch,out_ch)defforward(self,x:torch.Tensor,skip:torch.Tensor)->torch.Tensor:x=self.up(x)# [B, in_ch//2, H*2, W*2]# 对齐尺寸(某些情况下可能因奇偶问题产生 1 像素差异)diff_y=skip.size(2)-x.size(2)diff_x=skip.size(3)-x.size(3)ifdiff_y!=0ordiff_x!=0:x=F.pad(x,[diff_x//2,diff_x-diff_x//2,diff_y//2,diff_y-diff_y//2])# 通道维拼接x=torch.cat([skip,x],dim=1)# [B, skip_ch + in_ch//2, H, W]x=self.conv(x)returnxclassDecoderV2(nn.Module):""" UNet v2 解码器: 输入:来自 Encoder (+ SDI) 的多层特征 [f1', f2', f3', f4', f5'] 输出:分割 logits [B, num_classes, H, W] """def__init__(self,base_channels:int=64,num_classes:int=1):super().__init__()c1=base_channels c2=base_channels*2c3=base_channels*4c4=base_channels*8c5=base_channels*16# 对应 Encoder 的通道设置:# f1: c1, f2: c2, f3: c3, f4: c4, f5: c5self.up1=Up(c5,c4)# f5 -> f4self.up2=Up(c4,c3)# ...self.up3=Up(c3,c2)self.up4=Up(c2,c1)self.out_conv=nn.Conv2d(c1,num_classes,kernel_size=1)defforward(self,feats:List[torch.Tensor])->torch.Tensor:""" :param feats: [f1', f2', f3', f4', f5'] :return: logits [B, num_classes, H, W] """assertlen(feats)==5,"DecoderV2 目前假定有 5 个层级特征"f1,f2,f3,f4,f5=feats# 注意:f1 分辨率最高,f5 最低x=self.up1(f5,f4)# -> 类似 f4 空间尺寸x=self.up2(x,f3)x=self.up3(x,f2)x=self.up4(x,f1)logits=self.out_conv(x)returnlogits4、UNetV2 模型(unetv2.py)
unetv2.py将 Encoder、SDI 模块与 Decoder 组合成一个完整的 UNet v2 模型类 UNetV2,是训练与推理阶段真正被调用的主网络结构,实现端到端的“图像输入 → 分割输出”。
- 提供整体网络类:UNetV2;
- 内部集成 EncoderV2、SDIModule(可选)与 DecoderV2,形成统一的前向计算流程;
- 支持从 config_v2 读取模型关键参数(输入通道数、类别数、base_channels、是否启用 SDI、SDI 层级等),实现配置集中管理;
- 前向过程中先通过编码器获得多层特征,再选择性通过 SDI 模块进行语义与细节注入,最后交由解码器重建输出;
- 输出为大小与输入一致、通道数为 num_classes 的分割 logits,可直接接入损失函数和评估指标,用于视网膜血管等医学图像分割任务。
forward(x) 内部流程如下:
(1)使用 Encoder 提取特征:[f1, f2, f3, f4, f5]
(2)使用 SDI 模块生成融合特征:[f1’, f2’, f3’, f4’, f5’]
(3)使用 Decoder 从高层到低层逐步重建高分辨率分割图;
最终输出 [B, num_classes, H, W] 的分割概率/ logit。
# model/unetv2.py""" UNetV2 主模型: - EncoderV2: 提取多层级特征 [f1..f5] - SDIModule: 在中间对特征进行语义注入(可选) - DecoderV2: 逐级上采样,输出分割结果 - youcans(at)qq.com """fromtypingimportSequence,Listimporttorchimporttorch.nnasnnfromcore.config_v2importcfg_v2from.encoder_v2importEncoderV2from.sdi_moduleimportSDIModulefrom.decoder_v2importDecoderV2classUNetV2(nn.Module):def__init__(self,in_channels:int=None,num_classes:int=None,base_channels:int=None,use_sdi:bool=None,sdi_levels:Sequence[int]=None,):super().__init__()# 从 config_v2 中读取默认值(方便后续统一配置)self.in_channels=in_channelsifin_channelsisnotNoneelsecfg_v2.IN_CHANNELS self.num_classes=num_classesifnum_classesisnotNoneelsecfg_v2.NUM_CLASSES self.base_channels=base_channelsifbase_channelsisnotNoneelsecfg_v2.BASE_CHANNELS self.use_sdi=use_sdiifuse_sdiisnotNoneelsecfg_v2.USE_SDI self.sdi_levels=sdi_levelsifsdi_levelsisnotNoneelsecfg_v2.SDI_LEVELS# 1. 编码器self.encoder=EncoderV2(in_channels=self.in_channels,base_channels=self.base_channels,)# 2. SDI 模块(可选)feat_channels=self.encoder.channels# [c1, c2, c3, c4, c5]ifself.use_sdi:self.sdi_module=SDIModule(feat_channels=feat_channels,sdi_levels=self.sdi_levels,mode=cfg_v2.SDI_FUSION_MODE,)else:self.sdi_module=None# 3. 解码器self.decoder=DecoderV2(base_channels=self.base_channels,num_classes=self.num_classes,)defforward(self,x:torch.Tensor)->torch.Tensor:""" :param x: [B, C_in, H, W] :return: logits [B, num_classes, H, W] """feats:List[torch.Tensor]=self.encoder(x)# [f1..f5]ifself.sdi_moduleisnotNone:feats=self.sdi_module(feats)# [f1'..f5']logits=self.decoder(feats)returnlogits5、测试UNetV2模型(test4.py)
为了测试 UNetV2模型,在项目根目录编写一个测试程序test4.py。
- 使用 cfg_v2 中的 IMG_SIZE & IN_CHANNELS 生成随机输入;
- 构建 UNetV2 模型(默认使用 SDI);
- 前向推理一次,检查输出形状是否为 [1, NUM_CLASSES, H, W];
- 自动给出 OK / ERROR 提示。
# test4.py""" 测试 UNetV2 模型结构是否正确: - 使用随机生成的一张图像(batch_size=1) - 大小为 cfg_v2.IMG_SIZE,通道数为 cfg_v2.IN_CHANNELS - 前向运行 UNetV2,检查输出形状是否正确 """importtorchfromcore.config_v2importcfg_v2frommodel.unetv2importUNetV2defmain():print("=== test4.py: 测试 UNetV2 forward ===\n")device=torch.device("cpu")# 测试脚本用 CPU 即可in_channels=cfg_v2.IN_CHANNELS num_classes=cfg_v2.NUM_CLASSES H,W=cfg_v2.IMG_SIZEprint(f"[INFO] IN_CHANNELS ={in_channels}")print(f"[INFO] NUM_CLASSES ={num_classes}")print(f"[INFO] IMG_SIZE ={cfg_v2.IMG_SIZE}\n")# 1. 构建模型model=UNetV2().to(device)model.eval()# 2. 构造随机输入x=torch.randn(1,in_channels,H,W,device=device)# 3. 前向推理withtorch.no_grad():y=model(x)print(f"[INFO] 输入张量形状: x.shape ={x.shape}")print(f"[INFO] 输出张量形状: y.shape ={y.shape}")# 4. 自动检查形状expected_shape=(1,num_classes,H,W)iftuple(y.shape)==expected_shape:print(f"[OK] UNetV2 forward 通过,输出形状符合预期:{expected_shape}🎉")else:print(f"[ERROR] UNetV2 输出形状不符合预期!")print(f" 实际:{tuple(y.shape)}, 期望:{expected_shape}")print("\n=== test4.py: 结束 ===")if__name__=="__main__":main()结果如下图所示,说明 Encoder + SDI + Decoder + UNetV2 总体结构已经跑通 forward。
【本节完】
版权声明:
欢迎关注『youcans动手学 AI』系列
转发请注明原文链接:
【动手学UNet】(12)Unet_V2 模型实现
Copyright 2025 youcans
Crated:2025-12