news 2026/6/9 16:30:11

第P3周:Pytorch实现天气识别

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
第P3周:Pytorch实现天气识别
  • 🍨本文为🔗365天深度学习训练营中的学习记录博客

  • 🍖原作者:K同学啊

目录

一、 前期准备

1. 设置GPU

2. 导入数据

3. 显示图片

4. 划分数据集

二、构建简单的CNN网络

三、 训练模型

1. 设置超参数

2. 编写训练函数

3. 编写测试函数

4. 正式训练

四、 结果可视化

五、 个人总结

过拟合的确认方法

解决方案

1. 正则化措施

2. 数据增强优化

3. 批归一化(BN)应用

4. 早停策略

5. 动态学习率调整

6. 优化器升级

一、 前期准备

1. 设置GPU

import torch import torch.nn as nn import torchvision.transforms as transforms import torchvision from torchvision import transforms, datasets import os,PIL,pathlib,random device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device

2. 导入数据

data_dir = './data/' data_dir = pathlib.Path(data_dir) data_paths = list(data_dir.glob('*')) classeNames = [str(path).split("\\")[1] for path in data_paths] classeNames
  • 第一步:使用pathlib.Path()函数将字符串类型的文件夹路径转换为pathlib.Path对象。
  • 第二步:使用glob()方法获取data_dir路径下的所有文件路径,并以列表形式存储在data_paths中。
  • 第三步:通过split()函数对data_paths中的每个文件路径执行分割操作,获得各个文件所属的类别名称,并存储在classeNames
  • 第四步:打印classeNames列表,显示每个文件所属的类别名称。

3. 显示图片

import matplotlib.pyplot as plt from PIL import Image # 指定图像文件夹路径 image_folder = './data/cloudy/' # 获取文件夹中的所有图像文件 image_files = [f for f in os.listdir(image_folder) if f.endswith((".jpg", ".png", ".jpeg"))] # 创建Matplotlib图像 fig, axes = plt.subplots(3, 8, figsize=(16, 6)) # 使用列表推导式加载和显示图像 for ax, img_file in zip(axes.flat, image_files): img_path = os.path.join(image_folder, img_file) img = Image.open(img_path) ax.imshow(img) ax.axis('off') # 显示图像 plt.tight_layout() plt.show()

total_datadir = './data/' # 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863 train_transforms = transforms.Compose([ transforms.Resize([224, 224]), # 将输入图片resize成统一尺寸 transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间 transforms.Normalize( # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。 ]) total_data = datasets.ImageFolder(total_datadir,transform=train_transforms) total_data

4. 划分数据集

train_size = int(0.8 * len(total_data)) test_size = len(total_data) - train_size train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size]) train_dataset, test_dataset train_size,test_size batch_size = 32 train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True) for X, y in test_dl: print("Shape of X [N, C, H, W]: ", X.shape) print("Shape of y: ", y.shape, y.dtype) break

二、构建简单的CNN网络

import torch.nn.functional as F class Network_bn(nn.Module): def __init__(self): super(Network_bn, self).__init__() """ nn.Conv2d()函数: 第一个参数(in_channels)是输入的channel数量 第二个参数(out_channels)是输出的channel数量 第三个参数(kernel_size)是卷积核大小 第四个参数(stride)是步长,默认为1 第五个参数(padding)是填充大小,默认为0 """ self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0) self.bn1 = nn.BatchNorm2d(12) self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0) self.bn2 = nn.BatchNorm2d(12) self.pool1 = nn.MaxPool2d(2,2) self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0) self.bn4 = nn.BatchNorm2d(24) self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0) self.bn5 = nn.BatchNorm2d(24) self.pool2 = nn.MaxPool2d(2,2) self.fc1 = nn.Linear(24*50*50, len(classeNames)) def forward(self, x): x = F.relu(self.bn1(self.conv1(x))) x = F.relu(self.bn2(self.conv2(x))) x = self.pool1(x) x = F.relu(self.bn4(self.conv4(x))) x = F.relu(self.bn5(self.conv5(x))) x = self.pool2(x) x = x.view(-1, 24*50*50) x = self.fc1(x) return x device = "cuda" if torch.cuda.is_available() else "cpu" print("Using {} device".format(device)) model = Network_bn().to(device) model

三、 训练模型

1. 设置超参数

loss_fn = nn.CrossEntropyLoss() # 创建损失函数 learn_rate = 1e-4 # 学习率 opt = torch.optim.SGD(model.parameters(),lr=learn_rate)

2. 编写训练函数

# 训练循环 def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) # 训练集的大小,一共60000张图片 num_batches = len(dataloader) # 批次数目,1875(60000/32) train_loss, train_acc = 0, 0 # 初始化训练损失和正确率 for X, y in dataloader: # 获取图片及其标签 X, y = X.to(device), y.to(device) # 计算预测误差 pred = model(X) # 网络输出 loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失 # 反向传播 optimizer.zero_grad() # grad属性归零 loss.backward() # 反向传播 optimizer.step() # 每一步自动更新 # 记录acc与loss train_acc += (pred.argmax(1) == y).type(torch.float).sum().item() train_loss += loss.item() train_acc /= size train_loss /= num_batches return train_acc, train_loss

3. 编写测试函数

def test (dataloader, model, loss_fn): size = len(dataloader.dataset) # 测试集的大小,一共10000张图片 num_batches = len(dataloader) # 批次数目,313(10000/32=312.5,向上取整) test_loss, test_acc = 0, 0 # 当不进行训练时,停止梯度更新,节省计算内存消耗 with torch.no_grad(): for imgs, target in dataloader: imgs, target = imgs.to(device), target.to(device) # 计算loss target_pred = model(imgs) loss = loss_fn(target_pred, target) test_loss += loss.item() test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item() test_acc /= size test_loss /= num_batches return test_acc, test_loss

4. 正式训练

epochs = 20 train_loss = [] train_acc = [] test_loss = [] test_acc = [] for epoch in range(epochs): model.train() epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt) model.eval() epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn) train_acc.append(epoch_train_acc) train_loss.append(epoch_train_loss) test_acc.append(epoch_test_acc) test_loss.append(epoch_test_loss) template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}') print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss)) print('Done')

四、 结果可视化

import matplotlib.pyplot as plt #隐藏警告 import warnings warnings.filterwarnings("ignore") #忽略警告信息 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 plt.rcParams['figure.dpi'] = 100 #分辨率 from datetime import datetime current_time = datetime.now() # 获取当前时间 epochs_range = range(epochs) plt.figure(figsize=(12, 3)) plt.subplot(1, 2, 1) plt.plot(epochs_range, train_acc, label='Training Accuracy') plt.plot(epochs_range, test_acc, label='Test Accuracy') plt.legend(loc='lower right') plt.title('Training and Validation Accuracy') plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效 plt.subplot(1, 2, 2) plt.plot(epochs_range, train_loss, label='Training Loss') plt.plot(epochs_range, test_loss, label='Test Loss') plt.legend(loc='upper right') plt.title('Training and Validation Loss') plt.show()

五、 个人总结

过拟合的确认方法

当模型出现疑似过拟合时,可通过以下方法进一步确认:

  1. 增加训练轮次:持续训练时若验证集准确率下降而损失上升,则确认过拟合
  2. 降低模型复杂度:如减少全连接层或通道数后验证指标提升,说明原模型过于复杂
  3. 增强数据多样性:添加数据增强后若验证指标改善,表明原训练数据不足导致模型死记硬背

解决方案

1. 正则化措施

  • Dropout:在全连接层前添加nn.Dropout(0.5),通过随机丢弃神经元强制学习鲁棒特征
  • L2正则化:优化器中设置weight_decay=1e-4,抑制过大权重,防止噪声拟合

2. 数据增强优化

  • 方法:使用torchvision.transforms实现随机裁剪、翻转、颜色变换等
  • 作用:提升数据多样性,促使模型学习通用特征而非样本细节,增强泛化能力

3. 批归一化(BN)应用

  • 原理:对卷积层输出进行归一化处理(均值≈0,方差≈1)后缩放平移
  • 优势
    • 稳定层间输入分布,加速收敛并缓解梯度问题
    • 具有正则化效果,类似"mini-batch级数据增强"

4. 早停策略

  • 实施:持续监控验证集准确率,当性能不再提升(超过patience轮次)时终止训练
  • 价值
    • 防止过度拟合训练噪声
    • 节省计算资源,自动获取最佳泛化模型

5. 动态学习率调整

  • 机制:当验证性能停滞时,按因子(如0.5)降低学习率
  • 效益
    • 实现更精细的参数优化
    • 平缓的参数更新可降低过拟合风险

6. 优化器升级

  • 改进方案:将SGD替换为Adam优化器并配合权重衰减
  • 原因:当前SGD学习率(1e-4)偏低导致收敛缓慢,且缺乏动量和正则化支持
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/5 19:39:16

Python返回数组/List长度的方法

你想了解在 Python 中获取列表(List,也常被俗称 “数组”)长度的方法,这是 Python 基础中高频使用的操作,核心是通过内置函数实现,同时需要掌握不同场景下的使用细节(比如嵌套列表、numpy 数组等…

作者头像 李华
网站建设 2026/6/6 8:21:10

运维转行做什么好?零基础入门到精通,收藏这篇就够了

运维工程师转行网络安全是职业发展路径中比较常见的一种转行,这种转行通常基于以下几个原因和优势: **1.技能相关性:**运维工程师通常负责维护和管理企业的IT基础设施,包括服务器、网络和存储系统。这些工作内容与网络安全领域有…

作者头像 李华
网站建设 2026/6/6 6:46:06

XML 编码:深入解析与实际应用

XML 编码:深入解析与实际应用 引言 XML(可扩展标记语言)是一种用于存储和传输数据的标记语言。它被广泛应用于互联网、企业内部系统以及移动应用中。本文将深入解析XML编码的原理、规范以及在实际应用中的优势。 一、XML编码概述 1.1 XML的起源与发展 XML最早由W3C(万…

作者头像 李华
网站建设 2026/6/6 7:31:38

【Python基础】Python字符串操作全攻略:新手入门必备指南

目录 Python字符串操作全攻略:新手入门必备指南1. 引言:什么是字符串?2. 前置知识3. 字符串的创建与访问3.1 创建字符串3.2 访问字符串元素:索引和切片3.3 字符串的不可变性 4. 常用字符串操作方法4.1 获取字符串长度:…

作者头像 李华
网站建设 2026/6/8 22:45:03

全网十大降AI工具大比拼:知网、维普、万方实测数据公开

家人们,现在学校查得是真严,不仅重复率,还得降ai率,学校规定必须得20%以下... 折腾了半个月,终于把市面上各类方法试了个遍,坑踩了不少,智商税也交了。今天这就把这份十大降AI工具合集掏心窝子…

作者头像 李华
网站建设 2026/6/6 8:23:05

高效过审必备:盘点十大适合中国大学生的降AI工具

家人们,现在学校查得是真严,不仅重复率,还得降ai率,学校规定必须得20%以下... 折腾了半个月,终于把市面上各类方法试了个遍,坑踩了不少,智商税也交了。今天这就把这份十大降AI工具合集掏心窝子…

作者头像 李华