news 2026/5/5 6:26:34

【动手学UNet】(12)Unet_V2 模型实现

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【动手学UNet】(12)Unet_V2 模型实现

欢迎关注『youcans动手学 AI』系列
【动手学UNet】(1)创建 UNet项目
【动手学UNet】(2)数据加载
【动手学UNet】(7)主程序
【动手学UNet】(11)创建Unet_V2 项目
【动手学UNet】(12)Unet_V2 模型实现


【动手学UNet】(12)Unet_V2 模型实现

    • 15.4 数据加载模块(data_utils_v2.py)
    • 15.5 UNet v2 模型实现(Encoder + SDI + Decoder)

欢迎来到【动手学UNet】系列教程!
本系列将带你从零开始,一步步深入理解和实现经典的UNet图像分割模型。无论你是深度学习初学者,还是有一定经验的开发者,这个系列都将为你提供全面而实用的UNet知识。


15.4 数据加载模块(data_utils_v2.py)

1、图像预处理模块(utils/preprocess_v2.py)

preprocess_v2.py是所有 Retina 图像预处理操作的集中模块(绿色通道 + CLAHE),提供统一的医学图像预处理。模型训练和推理(包括单张推理)可以复用相同的预处理逻辑。

  1. 提供一个独立预处理类:RetinalPreprocessorV2;
  2. 支持 “绿色通道提取”;
  3. 支持 CLAHE(局部直方图均衡);
  4. 参数可调,可由 config_v2 控制;
  5. 最终输出统一为单通道 PIL.Image。

2、数据加载模块(utils/data_utils_v2.py)

data_utils.py是 UNetv2_Retina 的数据入口,自动完成:文件读取 → 预处理 → resize → 归一化 → Tensor 输出。

  1. 创建 RetinaDatasetV2:通用的数据加载类;
  2. 自动按照文件名匹配图像与掩膜;
  3. 调用 RetinalPreprocessorV2(绿色通道 + CLAHE);
  4. 控制预处理、通道数、图像大小;
  5. 预留 transform 钩子(兼容 Albumentations);
  6. Hermetic 设计:返回训练可直接使用的 Tensor。

此前已经建立了data_utils.py脚本用于占位和测试,现在用如下脚本替换原来的测试版本。


3、测试数据加载(test3.py)

为了测试data_utils.py,在项目根目录编写一个测试程序test3.py

  1. 构建训练集 Dataset
  2. 从 dataset/train/images & dataset/train/masks 中取出一张样本,检查形状
  3. 构建 DataLoader
  4. 从 DataLoader 中取出一个 batch
    运行测试程序时,要在dataset/train/images/dataset/train/masks/目录下至少各放一张图片。


15.5 UNet v2 模型实现(Encoder + SDI + Decoder)

目标:在 model/ 目录下实现 U-Net v2 的三大模块 & 总体网络封装,清晰体现 多层特征提取 + SDI 融合 + 解码重建 的流程。

1、多层级编码器(model/encoder_v2.py)
encoder_v2.py负责实现 UNet v2 的编码器部分,是整张视网膜图像特征提取的入口模块,用于逐层提取由浅入深的多尺度特征图,为后续 SDI 注入和解码器重建提供语义与细节信息基础。

  1. 提供多层级编码器类:EncoderV2;
  2. 采用与经典 U-Net 一致的结构:DoubleConv 卷积块 + Down 下采样块;
  3. 从输入图像中逐级提取特征,输出 [f1, f2, f3, f4, f5] 五个层级的特征图;
    4. 保持高分辨率浅层特征用于细节保留、低分辨率深层特征用于语义表达;
  4. 将各层通道数记录在 self.channels 中,为 SDI 模块和 Decoder 提供统一的通道配置。
    输入:x, [B, C_in, H, W]
    输出:特征图列表 [f1, f2, f3, f4, f5],分别对应从浅到深的不同层级。
# model/encoder_v2.py""" UNet v2 编码器(EncoderV2) - 与经典 U-Net 类似,采用多层级卷积 + 下采样结构 - 输出一系列特征图 [f1, f2, f3, f4, f5],从浅到深 - youcans@qq.com """fromtypingimportListimporttorchimporttorch.nnasnnclassDoubleConv(nn.Module):""" 两次 3x3 卷积 + BN + ReLU 的基本卷积块。 """def__init__(self,in_ch:int,out_ch:int):super().__init__()self.net=nn.Sequential(nn.Conv2d(in_ch,out_ch,kernel_size=3,padding=1,bias=False),nn.BatchNorm2d(out_ch),nn.ReLU(inplace=True),nn.Conv2d(out_ch,out_ch,kernel_size=3,padding=1,bias=False),nn.BatchNorm2d(out_ch),nn.ReLU(inplace=True),)defforward(self,x:torch.Tensor)->torch.Tensor:returnself.net(x)classDown(nn.Module):""" 下采样块:MaxPool2d(2) + DoubleConv """def__init__(self,in_ch:int,out_ch:int):super().__init__()self.pool=nn.MaxPool2d(kernel_size=2,stride=2)self.conv=DoubleConv(in_ch,out_ch)defforward(self,x:torch.Tensor)->torch.Tensor:x=self.pool(x)x=self.conv(x)returnxclassEncoderV2(nn.Module):""" UNet v2 编码器: 输入:x, [B, C_in, H, W] 输出:一个特征图列表 [f1, f2, f3, f4, f5] - f1: 最浅层(高分辨率,细节丰富) - f5: 最深层(低分辨率,高语义) """def__init__(self,in_channels:int=1,base_channels:int=64):super().__init__()# 通道设置:经典 U-Net 风格c1=base_channels c2=base_channels*2c3=base_channels*4c4=base_channels*8c5=base_channels*16self.inc=DoubleConv(in_channels,c1)self.down1=Down(c1,c2)self.down2=Down(c2,c3)self.down3=Down(c3,c4)self.down4=Down(c4,c5)self.channels=[c1,c2,c3,c4,c5]defforward(self,x:torch.Tensor)->List[torch.Tensor]:f1=self.inc(x)# [B, c1, H, W ]f2=self.down1(f1)# [B, c2, H/2, W/2]f3=self.down2(f2)# [B, c3, H/4, W/4]f4=self.down3(f3)# [B, c4, H/8, W/8]f5=self.down4(f4)# [B, c5, H/16,W/16]return[f1,f2,f3,f4,f5]

2、SDI模块(model/sdi_module.py)

sdi_module.py实现语义与细节注入模块(SDI,Semantic & Detail Infusion),通过从高层特征中提炼语义信息并注入到浅层/中层特征中,使网络在保持细节分辨率的同时增强语义表达能力。

  1. 提供 SDI 模块类:SDIModule;
  2. 使用最高层特征 f5 作为语义源;
  3. 对于指定层级 i,将 f5 经过 1x1 卷积 + 上采样到与 f_i 相同尺寸;
  4. 通过 Sigmoid生成语义门控语义图sem_gate;
  5. 采用哈达玛积(Hadamard product)对指定层级的特征进行语义增强:
    f_i’ = f_i * (1 + sem_gate)
    6. 是否启用、作用于哪些层级等行为,可以通过 feat_channels、sdi_levels、mode 进行配置,并由 config_v2 统一管理。
# model/sdi_module.py""" UNet v2 的 SDI 模块(Semantic & Detail Infusion) - 使用最高层特征 f5 作为语义源 - 对于指定层级 i,将 f5 经过 1x1 卷积 + 上采样到与 f_i 相同尺寸 - 生成语义门控语义图 sem_gate = sigmoid(conv(upsampled_f5)) - 利用哈达玛积(Hadamard product)进行增强: f_i' = f_i * (1 + sem_gate) - 未启用的层级保持原样返回 - youcans(at)qq.com """fromtypingimportList,Sequenceimporttorchimporttorch.nnasnnimporttorch.nn.functionalasFclassSDIModule(nn.Module):def__init__(self,feat_channels:Sequence[int],sdi_levels:Sequence[int]=(1,2,3,4),mode:str="hadamard",):""" :param feat_channels: 每层特征图通道数列表,例如 [c1, c2, c3, c4, c5] :param sdi_levels: 需要进行 SDI 注入的层级索引(1-based),如 (1,2,3,4) :param mode: 当前保留占位,默认为 'hadamard' """super().__init__()assertlen(feat_channels)>=2,"feat_channels 至少需要两个层级"self.feat_channels=list(feat_channels)self.sdi_levels=list(sdi_levels)self.mode=mode# 最高层特征 f5 的通道数self.top_channels=feat_channels[-1]# 为每个层级 i 创建一个 1x1 卷积,将 top feature 映射到该层通道数self.proj_convs=nn.ModuleDict()fori,cinenumerate(feat_channels[:-1],start=1):name=f"proj_to_l{i}"self.proj_convs[name]=nn.Conv2d(self.top_channels,c,kernel_size=1)defforward(self,feats:List[torch.Tensor])->List[torch.Tensor]:""" :param feats: [f1, f2, f3, f4, f5] :return: [f1',f2',f3',f4',f5],其中部分层经过 SDI 增强 """assertlen(feats)==len(self.feat_channels),"输入特征层数与 feat_channels 不匹配"f_top=feats[-1]# f5B,C_top,H_top,W_top=f_top.shape out_feats:List[torch.Tensor]=[]fori,finenumerate(feats,start=1):ifi==len(feats):# 最深层(f5)通常作为语义源,可直接保留out_feats.append(f)continueifinotinself.sdi_levels:# 不在 SDI 作用层级中,保持原样out_feats.append(f)continue# ---- SDI 注入过程 ----# 1) 将 f_top 投影到与 f_i 相同通道数proj_conv=self.proj_convs[f"proj_to_l{i}"]sem_feat=proj_conv(f_top)# [B, c_i, H_top, W_top]# 2) 上采样到 f_i 相同空间尺寸_,_,H_i,W_i=f.shape sem_feat_up=F.interpolate(sem_feat,size=(H_i,W_i),mode="bilinear",align_corners=False)# 3) 生成语义 gatesem_gate=torch.sigmoid(sem_feat_up)# [B, c_i, H_i, W_i]ifself.mode=="hadamard":# Hadamard product + 残差增强f_enhanced=f*(1.0+sem_gate)else:# 其他模式可以后续扩展,这里简单做加法占位f_enhanced=f+sem_gate out_feats.append(f_enhanced)returnout_feats

3、解码器(model/decoder_v2.py)

decoder_v2.py实现 UNet v2 的解码器模块,负责将多层级特征逐步上采样并融合浅层跳跃连接,实现空间分辨率的重建和最终的像素级分割输出,整体风格与经典 U-Net 解码路径保持一致。

  1. 提供上采样块 Up 和解码器类 DecoderV2;
  2. 通过转置卷积(ConvTranspose2d)逐级将特征图放大 2 倍;
  3. 在每一层上采样后,与对应编码器特征(如 f4、f3、f2、f1)进行 skip-connection 拼接;
  4. 使用 DoubleConv 对拼接后的特征进行融合,逐层恢复细节与结构信息;
  5. 最终通过 1×1 卷积输出分割 logits(通道数等于 num_classes),用于后续 Sigmoid / Softmax 等激活与损失计算。
# model/decoder_v2.py""" UNet v2 解码器(DecoderV2) 与经典 U-Net 解码路径类似: - 逐步上采样 - 与对应层级的 encoder 特征做 skip-connection - 通过 DoubleConv 融合特征 - youcans(at)qq.com """fromtypingimportListimporttorchimporttorch.nnasnnimporttorch.nn.functionalasFfrom.encoder_v2importDoubleConvclassUp(nn.Module):""" 上采样块: - 采用 ConvTranspose2d 做 2x 上采样 - 与对应 encoder 特征拼接(channel 方向) - 再通过 DoubleConv 融合 """def__init__(self,in_ch:int,out_ch:int):""" :param in_ch: 解码特征的通道数(上采样前) :param out_ch: 上采后 + 拼接后,再经卷积得到的输出通道数 """super().__init__()# 上采样:通道数减半,尺寸 ×2self.up=nn.ConvTranspose2d(in_ch,in_ch//2,kernel_size=2,stride=2)# 上采样后会与 skip 特征拼接,拼接后通道数 = in_ch//2 + skip_ch# DoubleConv 里具体 in/out 通道设置在 forward 前根据实际拼接情况决定# 这里采用经典 U-Net 的写法:用 in_ch 作为 DoubleConv 的输入通道self.conv=DoubleConv(in_ch,out_ch)defforward(self,x:torch.Tensor,skip:torch.Tensor)->torch.Tensor:x=self.up(x)# [B, in_ch//2, H*2, W*2]# 对齐尺寸(某些情况下可能因奇偶问题产生 1 像素差异)diff_y=skip.size(2)-x.size(2)diff_x=skip.size(3)-x.size(3)ifdiff_y!=0ordiff_x!=0:x=F.pad(x,[diff_x//2,diff_x-diff_x//2,diff_y//2,diff_y-diff_y//2])# 通道维拼接x=torch.cat([skip,x],dim=1)# [B, skip_ch + in_ch//2, H, W]x=self.conv(x)returnxclassDecoderV2(nn.Module):""" UNet v2 解码器: 输入:来自 Encoder (+ SDI) 的多层特征 [f1', f2', f3', f4', f5'] 输出:分割 logits [B, num_classes, H, W] """def__init__(self,base_channels:int=64,num_classes:int=1):super().__init__()c1=base_channels c2=base_channels*2c3=base_channels*4c4=base_channels*8c5=base_channels*16# 对应 Encoder 的通道设置:# f1: c1, f2: c2, f3: c3, f4: c4, f5: c5self.up1=Up(c5,c4)# f5 -> f4self.up2=Up(c4,c3)# ...self.up3=Up(c3,c2)self.up4=Up(c2,c1)self.out_conv=nn.Conv2d(c1,num_classes,kernel_size=1)defforward(self,feats:List[torch.Tensor])->torch.Tensor:""" :param feats: [f1', f2', f3', f4', f5'] :return: logits [B, num_classes, H, W] """assertlen(feats)==5,"DecoderV2 目前假定有 5 个层级特征"f1,f2,f3,f4,f5=feats# 注意:f1 分辨率最高,f5 最低x=self.up1(f5,f4)# -> 类似 f4 空间尺寸x=self.up2(x,f3)x=self.up3(x,f2)x=self.up4(x,f1)logits=self.out_conv(x)returnlogits

4、UNetV2 模型(unetv2.py)

unetv2.py将 Encoder、SDI 模块与 Decoder 组合成一个完整的 UNet v2 模型类 UNetV2,是训练与推理阶段真正被调用的主网络结构,实现端到端的“图像输入 → 分割输出”。

  1. 提供整体网络类:UNetV2;
  2. 内部集成 EncoderV2、SDIModule(可选)与 DecoderV2,形成统一的前向计算流程;
  3. 支持从 config_v2 读取模型关键参数(输入通道数、类别数、base_channels、是否启用 SDI、SDI 层级等),实现配置集中管理;
  4. 前向过程中先通过编码器获得多层特征,再选择性通过 SDI 模块进行语义与细节注入,最后交由解码器重建输出;
  5. 输出为大小与输入一致、通道数为 num_classes 的分割 logits,可直接接入损失函数和评估指标,用于视网膜血管等医学图像分割任务。
    forward(x) 内部流程如下:
    (1)使用 Encoder 提取特征:[f1, f2, f3, f4, f5]
    (2)使用 SDI 模块生成融合特征:[f1’, f2’, f3’, f4’, f5’]
    (3)使用 Decoder 从高层到低层逐步重建高分辨率分割图;
    最终输出 [B, num_classes, H, W] 的分割概率/ logit。
# model/unetv2.py""" UNetV2 主模型: - EncoderV2: 提取多层级特征 [f1..f5] - SDIModule: 在中间对特征进行语义注入(可选) - DecoderV2: 逐级上采样,输出分割结果 - youcans(at)qq.com """fromtypingimportSequence,Listimporttorchimporttorch.nnasnnfromcore.config_v2importcfg_v2from.encoder_v2importEncoderV2from.sdi_moduleimportSDIModulefrom.decoder_v2importDecoderV2classUNetV2(nn.Module):def__init__(self,in_channels:int=None,num_classes:int=None,base_channels:int=None,use_sdi:bool=None,sdi_levels:Sequence[int]=None,):super().__init__()# 从 config_v2 中读取默认值(方便后续统一配置)self.in_channels=in_channelsifin_channelsisnotNoneelsecfg_v2.IN_CHANNELS self.num_classes=num_classesifnum_classesisnotNoneelsecfg_v2.NUM_CLASSES self.base_channels=base_channelsifbase_channelsisnotNoneelsecfg_v2.BASE_CHANNELS self.use_sdi=use_sdiifuse_sdiisnotNoneelsecfg_v2.USE_SDI self.sdi_levels=sdi_levelsifsdi_levelsisnotNoneelsecfg_v2.SDI_LEVELS# 1. 编码器self.encoder=EncoderV2(in_channels=self.in_channels,base_channels=self.base_channels,)# 2. SDI 模块(可选)feat_channels=self.encoder.channels# [c1, c2, c3, c4, c5]ifself.use_sdi:self.sdi_module=SDIModule(feat_channels=feat_channels,sdi_levels=self.sdi_levels,mode=cfg_v2.SDI_FUSION_MODE,)else:self.sdi_module=None# 3. 解码器self.decoder=DecoderV2(base_channels=self.base_channels,num_classes=self.num_classes,)defforward(self,x:torch.Tensor)->torch.Tensor:""" :param x: [B, C_in, H, W] :return: logits [B, num_classes, H, W] """feats:List[torch.Tensor]=self.encoder(x)# [f1..f5]ifself.sdi_moduleisnotNone:feats=self.sdi_module(feats)# [f1'..f5']logits=self.decoder(feats)returnlogits

5、测试UNetV2模型(test4.py)

为了测试 UNetV2模型,在项目根目录编写一个测试程序test4.py

  1. 使用 cfg_v2 中的 IMG_SIZE & IN_CHANNELS 生成随机输入;
  2. 构建 UNetV2 模型(默认使用 SDI);
  3. 前向推理一次,检查输出形状是否为 [1, NUM_CLASSES, H, W];
  4. 自动给出 OK / ERROR 提示。
# test4.py""" 测试 UNetV2 模型结构是否正确: - 使用随机生成的一张图像(batch_size=1) - 大小为 cfg_v2.IMG_SIZE,通道数为 cfg_v2.IN_CHANNELS - 前向运行 UNetV2,检查输出形状是否正确 """importtorchfromcore.config_v2importcfg_v2frommodel.unetv2importUNetV2defmain():print("=== test4.py: 测试 UNetV2 forward ===\n")device=torch.device("cpu")# 测试脚本用 CPU 即可in_channels=cfg_v2.IN_CHANNELS num_classes=cfg_v2.NUM_CLASSES H,W=cfg_v2.IMG_SIZEprint(f"[INFO] IN_CHANNELS ={in_channels}")print(f"[INFO] NUM_CLASSES ={num_classes}")print(f"[INFO] IMG_SIZE ={cfg_v2.IMG_SIZE}\n")# 1. 构建模型model=UNetV2().to(device)model.eval()# 2. 构造随机输入x=torch.randn(1,in_channels,H,W,device=device)# 3. 前向推理withtorch.no_grad():y=model(x)print(f"[INFO] 输入张量形状: x.shape ={x.shape}")print(f"[INFO] 输出张量形状: y.shape ={y.shape}")# 4. 自动检查形状expected_shape=(1,num_classes,H,W)iftuple(y.shape)==expected_shape:print(f"[OK] UNetV2 forward 通过,输出形状符合预期:{expected_shape}🎉")else:print(f"[ERROR] UNetV2 输出形状不符合预期!")print(f" 实际:{tuple(y.shape)}, 期望:{expected_shape}")print("\n=== test4.py: 结束 ===")if__name__=="__main__":main()

结果如下图所示,说明 Encoder + SDI + Decoder + UNetV2 总体结构已经跑通 forward。


【本节完】


版权声明:
欢迎关注『youcans动手学 AI』系列
转发请注明原文链接:
【动手学UNet】(12)Unet_V2 模型实现

Copyright 2025 youcans
Crated:2025-12


版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/17 20:39:23

多模态突破:AI规模化应用的关键密码

2025年末的AI行业,正上演一场以多模态为核心的竞速赛。从豆包1.8实现视频理解能力的跨越式升级,到谷歌Gemini3强化跨模态交互,再到OpenAI获得迪士尼巨额投资深耕影视生成,多模态已成为衡量大模型竞争力的核心标尺。这种能够统一理…

作者头像 李华
网站建设 2026/5/3 9:08:52

DingTalkRevokeMsgPatcher终极指南:飞书消息防撤回完全解决方案

DingTalkRevokeMsgPatcher终极指南:飞书消息防撤回完全解决方案 【免费下载链接】DingTalkRevokeMsgPatcher 钉钉消息防撤回补丁PC版(原名:钉钉电脑版防撤回插件,也叫:钉钉防撤回补丁、钉钉消息防撤回补丁)…

作者头像 李华
网站建设 2026/5/3 15:31:27

Linux 内核驱动-中断

Linux 内核驱动--中断 概述 中断是计算机系统中一种重要的异步事件处理机制,它允许外部设备在需要处理器注意时暂停当前执行的程序,转而去处理设备的需求,处理完成后再返回原程序继续执行。 中断的主要作用包括: • 提高CPU利用率:避免CPU轮询等待外部设备。 • 实现实…

作者头像 李华
网站建设 2026/4/22 20:30:11

5分钟玩转BilibiliDown:解锁B站音频下载的实用技巧

还在为喜欢的B站背景音乐无处下载而烦恼吗?想将UP主精心制作的音频内容永久保存,却苦于找不到合适的工具?今天,就让我带你全面了解这款备受好评的B站音频下载工具——BilibiliDown,它不仅能下载视频,更是一…

作者头像 李华
网站建设 2026/4/28 9:24:15

MCP不是API替代品!AI Agent开发者的避坑指南,建议收藏细读

MCP是AI与外部工具交互的通用适配器,与API是互补而非替代关系。MCP适合AI自主决策、多工具协作和快速原型验证,而API则在性能敏感、复杂数据操作、安全合规和固定流程场景中更优。开发者应避免盲目滥用MCP,应根据场景精准搭配:用M…

作者头像 李华
网站建设 2026/5/2 18:34:16

大模型代理幻觉全解析:五大类型、十八种触发原因与十种解决方案

这篇文章全面综述了基于LLM的代理幻觉问题,创新性地将代理幻觉分为推理、执行、感知、记忆和通信五种类型,深入分析了十八种触发原因,并总结了十种有效缓解方法(知识利用、范式改进、事后验证等)。研究为理解LLM代理幻…

作者头像 李华