news 2026/5/11 21:58:09

04_残差网络

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
04_残差网络

描述

残差网络是现代卷积神经网络的一种,有效的抑制了深层神经网络的梯度弥散和梯度爆炸现象,使得深度网络训练不那么困难。

下面以cifar-10-batches-py数据集,实现一个ResNet18的残差网络,通过继承nn.Module实现残差块(Residual Block),网络模型类。

定义Block

ResNetBlock派生至nn.Module,需要自己实现forward函数。

torch.nn.Module是nn中十分重要的类,包含网络各层的定义及forward方法,可以从这个类派生自己的模型类。

nn.Module重要的函数:

  • forward(self,*input):forward函数为前向传播函数,需要自己重写,它用来实现模型的功能,并实现各个层的连接关系;
  • __call__(self, *input, **kwargs): __call__()的作用是使class实例能够像函数一样被调用,以“对象名()”的形式使用;
  • __repr__(self):__repr__函数为Python的一个内置函数,它能把一个对象用字符串的形式表达出来;
  • __init__(self):构造函数,自定义模型的网络层对象一般在这个函数中定义。
classResNetBlock(nn.Module):def__init__(self,input_channels,num_channels,stride=1):''' 构造函数:定义网络层 '''super().__init__()self.conv1=nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1,stride=stride)self.btn1=nn.BatchNorm2d(num_channels)self.conv2=nn.Conv2d(num_channels,num_channels,kernel_size=3,padding=1,stride=1)self.btn2=nn.BatchNorm2d(num_channels)ifstride!=1:self.downsample=nn.Conv2d(input_channels,num_channels,kernel_size=1,stride=stride)else:self.downsample=lambdax:xdefforward(self,X):''' 实现反向传播 '''Y=self.btn1(self.conv1(X))Y=nn.functional.relu(Y)Y=self.btn2(self.conv2(Y))Y+=self.downsample(X)returnnn.functional.relu(Y)

定义模型

ResNet同样派生于nn.Module,与ResNetBlock类似,需要实现forward。

torch.nn.Sequential是PyTorch 中一个用于构建顺序神经网络模型的容器类,它将多个神经网络层或模块按顺序组合在一起,简化模型搭建过程。‌Sequential器会严格按照添加的顺序执行内部的子模块,前向传播时自动传递数据,适用于简单神经网络的构建。

classResNet(nn.Module):def__init__(self,layer_dism,num_class=10):''' 构造函数:定义预处理model;构建block层 '''super(ResNet,self).__init__()# 预处理self.stem=nn.Sequential(nn.Conv2d(3,64,3,1),# 3x30x30nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2,2)# 64x15x15)self.layer1=self.build_resblock(64,64,layer_dism[0])self.layer2=self.build_resblock(64,128,layer_dism[1],2)self.layer3=self.build_resblock(128,256,layer_dism[2],2)self.layer4=self.build_resblock(256,512,layer_dism[3],2)self.avgpool=nn.AvgPool2d(1,1)self.btn=nn.Flatten()self.fc=nn.Linear(512,num_class)defbuild_resblock(self,input_channels,num_channels,block,stride=1):res_block=nn.Sequential()res_block.append(ResNetBlock(input_channels,num_channels,stride))for_inrange(1,block):res_block.append(ResNetBlock(num_channels,num_channels,stride))returnres_blockdefforward(self,X):out=self.stem(X)out=self.layer1(out)out=self.layer2(out)out=self.layer3(out)out=self.layer4(out)out=self.avgpool(out)returnself.fc(self.btn(out))

模型训练

加载数据

使用torchvision.datasets加载本地数据,如果本地没有数据,可以设置download=True自动下载。

# 定义数据转换transform=transforms.Compose([transforms.ToTensor(),# 将PIL图像转换为Tensortransforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))# 归一化])# 加载CIFAR-10训练集trainset=torchvision.datasets.CIFAR10(root=r'D:\dwload',train=True,download=False,transform=transform)trainloader=th.utils.data.DataLoader(trainset,batch_size=16,shuffle=False,num_workers=2)# 加载CIFAR-10测试集testset=torchvision.datasets.CIFAR10(root=r'D:\dwload',train=False,download=False,transform=transform)testloader=th.utils.data.DataLoader(testset,batch_size=16,shuffle=False,num_workers=2)

模型初始化

模型初始化是确保网络能够有效学习的关键步骤,一个好的初始值,会使模型收敛速度提高,使模型准确率更精确。

torch.nn.init模块提供了一系列的权重初始化函数:

  • torch.nn.init.uniform_ :均匀分布
  • torch.nn.init.normal_ :正态分布
  • torch.nn.init.constant_:初始化为指定常数
  • torch.nn.init.kaiming_uniform_:凯明均匀分布
  • torch.nn.init.kaiming_normal_:凯明正态分布
  • torch.nn.init.xavier_uniform_:Xavier均匀分布
  • torch.nn.init.xavier_normal_:Xavier正态分布

在初始化时,最好不要将模型的参数初始化为0,因为这样会导致梯度消失,进而影响训练效果。可以将模型初始化为一个很小的值,如0.01,0.001等。

definitialize_weight(m):ifisinstance(m,nn.Conv2d)orisinstance(m,nn.Linear):nn.init.kaiming_normal_(m.weight,mode='fan_out',nonlinearity='relu')# mode:权重方差计算方式,可选 'fan_in' 或 'fan_out'(输入、输出神经元数量)# nonlinearity:激活函数类型,用于调整计算公式 ,一般是relu、leaky_reluifm.biasisnotNone:nn.init.constant_(m.bias,0)

[2,2,2,2] 参数分别代表四个block的中的残差块数量(可以仔细看一下build_resblock函数)

resnet_18=ResNet([2,2,2,2])resnet_18.apply(initialize_weight)# 初始化模型loss_cross=nn.CrossEntropyLoss()trainer=th.optim.SGD(resnet_18.parameters())

训练

训练过程比较漫长,这里训练只有20轮,测试精度0.51。如果有N卡加持的话,可以适当调高epoch,精度能进一步提高。

forepochinrange(0,20):running_loss=0.0forinputs,labelsintrainloader:trainer.zero_grad()outputs=resnet_18(inputs)loss=loss_cross(outputs,labels)loss.backward()trainer.step()running_loss+=loss.item()print(f'[{epoch+1}] ev loss:{running_loss/3125}')running_loss=0.0
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/11 21:57:23

MATLAB脑网络分析终极指南:5步掌握GRETNA核心功能

MATLAB脑网络分析终极指南:5步掌握GRETNA核心功能 【免费下载链接】GRETNA A Graph-theoretical Network Analysis Toolkit in MATLAB 项目地址: https://gitcode.com/gh_mirrors/gr/GRETNA GRETNA作为MATLAB环境下的图论网络分析工具包,为神经科…

作者头像 李华
网站建设 2026/5/8 22:45:04

Windows微信机器人开发实战:零基础构建智能自动化助手

Windows微信机器人开发实战:零基础构建智能自动化助手 【免费下载链接】puppet-xp Wechaty Puppet WeChat Windows Protocol 项目地址: https://gitcode.com/gh_mirrors/pu/puppet-xp 还在为微信消息手动回复而烦恼?想要实现智能客服、群管理自动…

作者头像 李华
网站建设 2026/5/10 23:12:33

3、Linux基础操作与常用命令全解析

Linux基础操作与常用命令全解析 1. 系统电源控制 在Linux系统中,正确地开启和关闭系统至关重要,错误的操作可能会导致数据丢失或损坏。 - 启动系统 :开启系统电源即启动系统,这一过程被称为“引导”(booting)。在Linux内核引导时,屏幕会显示许多信息,之后会出现登…

作者头像 李华
网站建设 2026/5/10 0:56:23

Snipe-IT资产标签系统深度解析与实战应用

Snipe-IT资产标签系统深度解析与实战应用 【免费下载链接】snipe-it A free open source IT asset/license management system 项目地址: https://gitcode.com/GitHub_Trending/sn/snipe-it 你是否曾经在资产盘点时遇到过这样的困扰?面对数百台设备&#xff…

作者头像 李华
网站建设 2026/5/9 2:45:54

20、Linux系统下音频光盘与声音文件处理全攻略

Linux系统下音频光盘与声音文件处理全攻略 1. 音频光盘的使用 在安装了CD驱动器和声卡的系统上,就可以播放音频光盘。在Linux系统中,可以通过命令行的软件工具来控制音频CD的播放,其控制方式与传统CD播放器类似。同时,也有工具可以从CD读取音频数据并将其写入文件,这些文…

作者头像 李华
网站建设 2026/5/10 6:01:30

MCP MS-720 Agent集成资源稀缺泄露:资深架构师的私藏配置模板

第一章:MCP MS-720 Agent集成概述MCP MS-720 Agent 是现代监控平台中用于设备状态采集与远程控制的核心组件,专为边缘计算环境设计,支持多协议接入与动态配置更新。该代理程序能够在资源受限的设备上稳定运行,实现与中心管理平台的…

作者头像 李华