news 2026/6/10 1:51:52

Day 44 Dataset和Dataloader类

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day 44 Dataset和Dataloader类

@浙大疏锦行

在遇到大规模数据集时,显存常常无法一次性存储所有数据,所以需要使用分批训练的方法。为此,PyTorch提供了DataLoader类,该类可以自动将数据集切分为多个批次batch,并支持多线程加载数据。此外,还存在Dataset类,该类可以定义数据集的读取方式和预处理方式。

1. DataLoader类:决定数据如何加载

2. Dataset类:告诉程序去哪里找数据,如何读取单个样本,以及如何预处理。

为了引入这些概念,我们现在接触一个新的而且非常经典的数据集:MNIST手写数字数据集。该数据集包含60000张训练图片和10000张测试图片,每张图片大小为28*28像素,共包含10个类别。因为每个数据的维度比较小,所以既可以视为结构化数据,用机器学习、MLP训练,也可以视为图像数据,用卷积神经网络训练。

import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader , Dataset # DataLoader 是 PyTorch 中用于加载数据的工具 from torchvision import datasets, transforms # torchvision 是一个用于计算机视觉的库,datasets 和 transforms 是其中的模块 import matplotlib.pyplot as plt # 设置随机种子,确保结果可复现 torch.manual_seed(42) # 1. 数据预处理,该写法非常类似于管道pipeline # transforms 模块提供了一系列常用的图像预处理操作 # 先归一化,再标准化 transform = transforms.Compose([ transforms.ToTensor(), # 转换为张量并归一化到[0,1] transforms.Normalize((0.1307,), (0.3081,)) # MNIST数据集的均值和标准差,这个值很出名,所以直接使用 ]) # 2. 加载MNIST数据集,如果没有会自动下载 train_dataset = datasets.MNIST( root='./data', train=True, download=True, transform=transform ) test_dataset = datasets.MNIST( root='./data', train=False, transform=transform )

一、Dataset类

现在我们想要取出来一个图片,看看长啥样,因为datasets.MNIST本质上集成了torch.utils.data.Dataset,所以自然需要有对应的方法。

import matplotlib.pyplot as plt # 随机选择一张图片,可以重复运行,每次都会随机选择 sample_idx = torch.randint(0, len(train_dataset), size=(1,)).item() # 随机选择一张图片的索引 # len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字 image, label = train_dataset[sample_idx] # 获取图片和标签

这里很难理解,为什么train_dataset[sample_idx]可以获取到图片和标签,是因为 datasets.MNIST这个类继承了torch.utils.data.Dataset类,这个类中有一个方法__getitem__,这个方法会返回一个tuple,tuple中第一个元素是图片,第二个元素是标签。

我们来详细介绍下torch.utils.data.Dataset类

PyTorch 的torch.utils.data.Dataset是一个抽象基类,所有自定义数据集都需要继承它并实现两个核心方法:

- __len__():返回数据集的样本总数。

- __getitem__(idx):根据索引idx返回对应样本的数据和标签。

PyTorch 要求所有数据集必须实现__getitem__和__len__,这样才能被DataLoader等工具兼容。这是一种接口约定,类似函数参数的规范。这意味着,如果你要创建一个自定义数据集,你需要实现这两个方法,否则PyTorch将无法识别你的数据集。

在 Python 中,__getitem__和__len__ 是类的特殊方法(也叫魔术方法 ),它们不是像普通函数那样直接使用,而是需要在自定义类中进行定义,来赋予类特定的行为。以下是关于这两个方法具体的使用方式:

__getitem__方法用于让对象支持索引操作,当使用[]语法访问对象元素时,Python 会自动调用该方法。

# 示例代码 class MyList: def __init__(self): self.data = [10, 20, 30, 40, 50] def __getitem__(self, idx): return self.data[idx] # 创建类的实例 my_list_obj = MyList() # 此时可以使用索引访问元素,这会自动调用__getitem__方法 print(my_list_obj[2]) # 输出:30

通过定义__getitem__方法,让MyList类的实例能够像 Python 内置的列表一样使用索引获取元素。

__len__方法用于返回对象中元素的数量,当使用内置函数len()作用于对象时,Python 会自动调用该方法。

class MyList: def __init__(self): self.data = [10, 20, 30, 40, 50] def __len__(self): return len(self.data) # 创建类的实例 my_list_obj = MyList() # 使用len()函数获取元素数量,这会自动调用__len__方法 print(len(my_list_obj)) # 输出:5

这里定义的__len__方法,使得MyList类的实例可以像普通列表一样被len()函数调用获取长度。

# minist数据集的简化版本 class MNIST(Dataset): def __init__(self, root, train=True, transform=None): # 初始化:加载图片路径和标签 self.data, self.targets = fetch_mnist_data(root, train) # 这里假设 fetch_mnist_data 是一个函数,用于加载 MNIST 数据集的图片路径和标签 self.transform = transform # 预处理操作 def __len__(self): return len(self.data) # 返回样本总数 def __getitem__(self, idx): # 获取指定索引的样本 # 获取指定索引的图像和标签 img, target = self.data[idx], self.targets[idx] # 应用图像预处理(如ToTensor、Normalize) if self.transform is not None: # 如果有预处理操作 img = self.transform(img) # 转换图像格式 # 这里假设 img 是一个 PIL 图像对象,transform 会将其转换为张量并进行归一化 return img, target # 返回处理后的图像和标签

- Dataset = 厨师(准备单个菜品)

- DataLoader = 服务员(将菜品按订单组合并上桌)

预处理(如切菜、调味)属于厨师的工作,而非服务员。所以在dataset就需要添加预处理步骤。

# 可视化原始图像(需要反归一化) def imshow(img): img = img * 0.3081 + 0.1307 # 反标准化 npimg = img.numpy() plt.imshow(npimg[0], cmap='gray') # 显示灰度图像 plt.show() print(f"Label: {label}") imshow(image)

二、Dataloader类

# 3. 创建数据加载器 train_loader = DataLoader( train_dataset, batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关 shuffle=True # 随机打乱数据 ) test_loader = DataLoader( test_dataset, batch_size=1000 # 每个批次1000张图片 # shuffle=False # 测试时不需要打乱数据 )

三、总结

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/9 14:44:07

docker安装mongodb

一、前期准备 1.在服务器上面创建三个目录做为挂载到docker (/mongo/data,/mongo/logs,/mongo/conf ) 由于我们要把宿主的配置文件同步容器中,所以要在/mogo/conf创建mogodb的配置文件 mongod.conf ,内容如下: # 系统日志 systemLog:destination: fil…

作者头像 李华
网站建设 2026/6/8 13:47:04

我的网络安全实战学习笔记:记录从零到熟练的每个关键步骤与工具

01 什么是网络安全 网络安全可以基于攻击和防御视角来分类,我们经常听到的 “红队”、“渗透测试” 等就是研究攻击技术,而“蓝队”、“安全运营”、“安全运维”则研究防御技术。 无论网络、Web、移动、桌面、云等哪个领域,都有攻与防两面…

作者头像 李华
网站建设 2026/6/9 15:02:13

Langchain-Chatchat结合OCR技术处理扫描版PDF

Langchain-Chatchat 结合 OCR 技术处理扫描版 PDF 在企业知识管理日益智能化的今天,一个常见的难题浮出水面:大量历史文档以扫描图像的形式沉睡在档案库中。这些 PDF 文件看似清晰可读,实则对计算机而言是一片“黑盒”——没有文本层&#xf…

作者头像 李华
网站建设 2026/6/9 20:08:14

Langchain-Chatchat问答系统可解释性增强方法探索

Langchain-Chatchat问答系统可解释性增强方法探索 在企业知识管理日益复杂的今天,一个看似简单的问题——“年假是多少天?”——却可能牵出一连串的信任危机:员工不相信AI的回答是否准确,法务部门质疑其来源是否合规,I…

作者头像 李华
网站建设 2026/6/9 18:42:31

Langchain-Chatchat能否接入语音识别实现语音问答?

Langchain-Chatchat能否接入语音识别实现语音问答? 在企业知识管理日益智能化的今天,越来越多组织希望构建一个既能保障数据隐私、又能提供自然交互体验的本地化问答系统。Langchain-Chatchat 作为当前开源社区中“本地知识库 大语言模型”架构的代表作…

作者头像 李华
网站建设 2026/6/9 18:42:33

Langchain-Chatchat能否接入外部数据库作为知识源?

Langchain-Chatchat 能否接入外部数据库作为知识源? 在企业智能化转型的浪潮中,一个常见的痛点浮出水面:我们拥有海量的结构化数据——从 CRM 系统中的客户记录,到 ERP 中的订单流水,再到内部 Wiki 和产品手册。但这些…

作者头像 李华