@浙大疏锦行
一、图像数据格式:灰度 vs 彩色
图像数据的核心是「通道数」和「张量维度」,PyTorch 中需遵循固定格式才能被模型正确处理。
1. 基础概念
| 类型 | 核心特征 | 取值范围 | 典型应用 |
|---|---|---|---|
| 灰度图 | 单通道,仅包含亮度信息,无色彩;每个像素只有 1 个数值 | 0-255(8 位) | 手写数字识别、医学影像 |
| 彩色图 | 主流为 RGB 三通道(红 / 绿 / 蓝),每个通道对应 1 个亮度值,三值叠加形成色彩 | 0-255(每通道) | 图像分类、目标检测 |
2. 张量格式(PyTorch 标准)
PyTorch 中图像张量必须是(Batch, Channel, Height, Width)(BCHW)格式,与 OpenCV/Pillow 的(Height, Width, Channel)(HWC)格式不同,需手动转换。
| 图像类型 | 单张图(HWC) | 单张图张量(CHW) | 批量图张量(BCHW) |
|---|---|---|---|
| 灰度 | (H, W, 1) | (1, H, W) | (B, 1, H, W) |
| 彩色 | (H, W, 3) | (3, H, W) | (B, 3, H, W) |
3. 实战:读取 + 格式转换
import torch import cv2 from PIL import Image import numpy as np # ========== 1. 灰度图处理 ========== # PIL读取灰度图(L表示灰度模式) gray_img = Image.open("gray_digit.png").convert('L') gray_np = np.array(gray_img) # 形状:(28, 28)(手写数字MNIST尺寸) # 转换为PyTorch张量(CHW):新增通道维度 gray_tensor = torch.from_numpy(gray_np).unsqueeze(0).float() / 255.0 # 归一化到0-1 print("灰度图张量形状(CHW):", gray_tensor.shape) # torch.Size([1, 28, 28]) # ========== 2. 彩色图处理 ========== # OpenCV读取(默认BGR格式,需转为RGB) color_img = cv2.imread("cat.jpg") # 形状:(480, 640, 3)(HWC,BGR) color_rgb = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB) # 转为RGB # 转换为PyTorch张量(CHW):HWC → CHW color_tensor = torch.from_numpy(color_rgb).permute(2, 0, 1).float() / 255.0 print("彩色图张量形状(CHW):", color_tensor.shape) # torch.Size([3, 480, 640]) # ========== 3. 批量图像(BCHW) ========== batch_gray = torch.stack([gray_tensor]*8) # 8张灰度图,形状(8, 1, 28, 28) batch_color = torch.stack([color_tensor]*8) # 8张彩色图,形状(8, 3, 480, 640) print("批量灰度张量:", batch_gray.shape) print("批量彩色张量:", batch_color.shape)二、图像模型的定义
图像任务核心用卷积神经网络(CNN),需继承nn.Module,核心层适配 4 维图像张量(BCHW),以下是规范的定义模板。
1. 通用 CNN 模型定义(兼容灰度 / 彩色)
import torch.nn as nn import torch.nn.functional as F class ImageClassifier(nn.Module): """ 图像分类CNN模型(适配灰度/彩色) :param in_channels: 输入通道数(灰度=1,彩色=3) :param num_classes: 分类类别数(如MNIST=10,猫狗分类=2) :param img_size: 输入图像尺寸(H=W,如28/224) """ def __init__(self, in_channels=1, num_classes=10, img_size=28): super().__init__() # 卷积块1:Conv → ReLU → MaxPool(下采样,尺寸减半) self.conv1 = nn.Conv2d( in_channels=in_channels, out_channels=16, kernel_size=3, # 3×3卷积核 padding=1 # 保持尺寸不变(padding=(kernel_size-1)/2) ) self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 卷积块2 self.conv2 = nn.Conv2d(16, 32, 3, padding=1) self.pool2 = nn.MaxPool2d(2, 2) # 计算全连接层输入维度:两次池化后尺寸为 img_size/4 fc_input_dim = 32 * (img_size//4) * (img_size//4) # 全连接层(分类头) self.fc1 = nn.Linear(fc_input_dim, 128) self.dropout = nn.Dropout(0.2) # 防止过拟合 self.fc2 = nn.Linear(128, num_classes) def forward(self, x): """前向传播:输入BCHW张量,输出分类概率""" # 卷积块1:(B, C, H, W) → (B, 16, H/2, W/2) x = self.pool1(F.relu(self.conv1(x))) # 卷积块2:→ (B, 32, H/4, W/4) x = self.pool2(F.relu(self.conv2(x))) # 展平:4维特征 → 2维(B, 特征数) x = x.view(x.size(0), -1) # 全连接层 x = self.dropout(F.relu(self.fc1(x))) x = self.fc2(x) # 输出logits(未归一化的概率) return F.softmax(x, dim=1) # 转为0-1概率 # ========== 实例化模型 ========== # 灰度图模型(MNIST手写数字) mnist_model = ImageClassifier(in_channels=1, num_classes=10, img_size=28) # 彩色图模型(猫狗分类) cat_dog_model = ImageClassifier(in_channels=3, num_classes=2, img_size=224) # 测试输入 mnist_input = torch.randn(8, 1, 28, 28) # 8张灰度图 cat_dog_input = torch.randn(8, 3, 224, 224) # 8张彩色图 # 前向传播 mnist_output = mnist_model(mnist_input) cat_dog_output = cat_dog_model(cat_dog_input) print("MNIST模型输出形状:", mnist_output.shape) # (8, 10) print("猫狗模型输出形状:", cat_dog_output.shape) # (8, 2)2. 模型定义核心要点
- 卷积层
nn.Conv2d的in_channels必须匹配图像通道数(灰度 = 1,彩色 = 3); - 池化层会下采样图像尺寸,需准确计算全连接层的输入维度(避免形状不匹配);
x.view(x.size(0), -1)是关键:将 4 维卷积特征展平为 2 维,适配全连接层。
三、显存占用的 5 个核心来源
训练时 GPU 显存消耗主要来自以下 5 部分(按占用大小排序),每一部分都有明确的优化方法:
| 显存来源 | 核心原理 | 简化计算方式(float32) | 优化手段 |
|---|---|---|---|
| 1. 批量数据(BCHW) | 输入的图像批次张量占用显存(训练 / 验证都需要) | 批量大小 × 通道数 × 高 × 宽 ×4 字节 | 减小 batch size、降低图像分辨率、归一化到 0-1(不影响显存,但避免数值溢出) |
| 2. 神经元中间状态 | 前向传播中各层的输出张量(如卷积层 / 池化层输出) | 各层输出尺寸 ×4 字节,累加 | 验证 / 推理时用torch.no_grad()、梯度检查点(checkpoint)、减少网络深度 |
| 3. 模型参数 | 模型中可训练参数(卷积核、全连接层权重) | 总参数数 ×4 字节 | 模型轻量化(如 MobileNet)、减少卷积通道数、量化(int8) |
| 4. 梯度参数 | 每个模型参数对应的梯度张量(形状与参数完全一致) | 与模型参数显存相等 | 梯度累积(小 batch 累加多轮再更新)、梯度裁剪、只训练部分层 |
| 5. 优化器参数 | 优化器维护的状态(如 Adam 的动量 / 方差,每个参数对应 2 个张量) | Adam:参数数 ×8 字节;SGD:参数数 ×4 字节 | 用 SGD 代替 Adam、清空优化器缓存 |
实战:显存优化关键代码
import torch.cuda.amp as amp # 混合精度训练(核心优化) # 1. 混合精度训练(将float32转为float16,显存减半) scaler = amp.GradScaler() # 梯度缩放器(避免float16下溢) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = mnist_model.to(DEVICE) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) criterion = nn.CrossEntropyLoss() # 训练循环中使用混合精度 model.train() for x_batch, y_batch in train_loader: x_batch = x_batch.to(DEVICE) y_batch = y_batch.to(DEVICE) with amp.autocast(): # 自动将张量转为float16 outputs = model(x_batch) loss = criterion(outputs, y_batch) # 反向传播(缩放梯度避免下溢) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() optimizer.zero_grad() # 2. 验证/推理时关闭梯度(减少中间状态+梯度显存) model.eval() with torch.no_grad(): outputs = model(x_batch) # 无中间状态/梯度显存占用 # 3. 梯度累积(用小batch模拟大batch) accumulation_steps = 4 # 累积4轮梯度 = 等效batch size×4 for i, (x, y) in enumerate(train_loader): x, y = x.to(DEVICE), y.to(DEVICE) outputs = model(x) loss = criterion(outputs, y) / accumulation_steps # 归一化损失 loss.backward() # 每累积4轮更新一次参数 if (i+1) % accumulation_steps == 0: optimizer.step() optimizer.zero_grad()四、batch size 与训练的核心关系
batch size(批次大小)是训练中最关键的超参数,直接影响显存、速度、模型效果,核心关系如下:
1. batch size ↔ 显存
- 正相关:batch size 越大,批量数据和中间状态占用的显存越多;
- 极限:超过显存上限会报
CUDA out of memory (OOM); - 建议:从 8/16 开始逐步增大,用
torch.cuda.max_memory_allocated()监控显存占用。
2. batch size ↔ 训练速度
- 正相关(有上限):batch size 越大,GPU 并行计算效率越高,每轮训练时间越短;
- 饱和点:当 batch size 占满 GPU 核心时,继续增大不会提速(甚至因数据传输耗时增加变慢)。
3. batch size ↔ 训练效果
| batch size 大小 | 收敛特点 | 泛化能力 | 适用场景 |
|---|---|---|---|
| 小(8/16/32) | 梯度更新频繁,震荡但易收敛到最优解;训练轮次多,总时间长 | 强 | 小数据集、复杂模型(CNN) |
| 大(64/128/256) | 梯度更新稳定,训练轮次少,总时间短;易陷入局部最优,需调大学习率 | 弱 | 大数据集、简单模型(MLP) |
| 极小(1,纯 SGD) | 梯度噪声大,收敛最慢;但泛化能力最优(学术研究常用) | 最优 | 小样本、追求极致泛化 |
4. batch size ↔ 学习率
- 适配原则:batch size 增大时,学习率需按比例增大(如 batch size 翻倍,学习率也翻倍);
- 原因:大 batch 的梯度估计更稳定,可承受更大的学习率,避免收敛过慢。
5. 合理选择 batch size 的建议
- 显存优先:先确定不 OOM 的最大 batch size(如 32),再根据效果调整;
- 效果优先:小数据集 / 复杂模型选小 batch(16/32),大数据集选大 batch(64/128);
- 折中方案:显存不足时,用「小 batch + 梯度累积」模拟大 batch(如 batch=8,累积 4 轮 = 等效 batch=32)。