news 2026/4/6 11:38:41

Day 39 模型可视化与推理

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day 39 模型可视化与推理

@浙大疏锦行

一、nn.Module核心自带方法

nn.Module封装了模型的核心逻辑,以下是高频使用的自带方法,按功能分类:

1. 模型状态控制(训练 / 评估模式)

方法作用
model.train()切换为训练模式:启用 Dropout、BatchNorm 等层的训练行为(默认模式)
model.eval()切换为评估模式:关闭 Dropout、固定 BatchNorm 均值 / 方差,用于推理 / 验证
model.training属性,返回布尔值:True= 训练模式,False= 评估模式

示例

import torch import torch.nn as nn class SimpleCNN(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d(3, 16, 3) self.dropout = nn.Dropout(0.5) # 训练时随机失活,评估时关闭 def forward(self, x): x = self.conv(x) x = self.dropout(x) return x model = SimpleCNN() print(model.training) # True(默认训练模式) model.eval() print(model.training) # False(评估模式,dropout失效) model.train() print(model.training) # True(切回训练模式)

2. 设备迁移(CPU/GPU)

方法作用
model.to(device)将模型所有参数 / 缓冲区移到指定设备(cuda/cpu/mps),返回模型实例
model.cuda()快捷方式:移到默认 GPU(等价于model.to('cuda')
model.cpu()快捷方式:移到 CPU(等价于model.to('cpu')

示例

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) # 模型移到GPU/CPU # 验证设备 print(next(model.parameters()).device) # 输出:cuda:0 或 cpu

3. 参数管理(查看 / 遍历参数)

方法作用
model.parameters()返回生成器:包含所有可训练参数(nn.Parameter类型)
model.named_parameters()返回生成器:(参数名,参数张量),便于定位参数
model.named_parameters()返回生成器:(参数名,参数张量),便于定位参数
model.state_dict()返回字典:{参数名:参数值},用于保存模型参数
model.load_state_dict()加载参数字典,用于恢复模型

示例

# 查看所有参数名称和形状 for name, param in model.named_parameters(): print(f"参数名:{name},形状:{param.shape},设备:{param.device}") # 统计总参数量(手动实现,无第三方库时用) total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"总参数:{total_params},可训练参数:{trainable_params}")

4. 结构遍历(查看模型层)

方法作用
model.children()返回生成器:仅包含直接子层(如 Sequential 内的第一层),不递归
model.named_children()返回生成器:(层名,子层),仅直接子层
model.modules()返回生成器:递归包含所有层(包括嵌套层)
model.named_modules()返回生成器:(层名,层),递归所有层

示例

# 定义嵌套模型 class NestedModel(nn.Module): def __init__(self): super().__init__() self.block1 = nn.Sequential( nn.Conv2d(3, 16, 3), nn.ReLU() ) self.block2 = nn.Linear(16*30*30, 10) model = NestedModel() # children():仅直接子层(block1、block2) print("=== children() ===") for name, layer in model.named_children(): print(name, layer) # modules():递归所有层(包括Sequential内的Conv2d、ReLU) print("\n=== modules() ===") for name, layer in model.named_modules(): print(name, layer)

5. 前向传播与梯度

方法作用
model.forward(x)手动调用前向传播(不推荐),建议直接model(x)(调用__call__
model(x)等价于model.__call__(x),自动执行 forward + 钩子(hook)逻辑
model.zero_grad()清空所有参数的梯度(训练时反向传播前必须调用)

示例

x = torch.randn(1, 3, 32, 32).to(device) output = model(x) # 推荐:调用__call__,等价于model.forward(x) + 钩子 model.zero_grad() # 清空梯度 output.sum().backward() # 反向传播计算梯度

二、torchsummary库的summary方法

torchsummary是早期轻量库,核心功能是快速打印模型层结构、输出形状、总参数量,仅支持单输入模型,对嵌套模型 / 多输入支持差,维护较少。

1. 安装与基本用法

pip install torchsummary
from torchsummary import summary # 定义模型(输入:3通道32×32图像) class SimpleCNN(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3, padding=1) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(16, 32, 3, padding=1) self.fc1 = nn.Linear(32 * 8 * 8, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.pool(nn.functional.relu(self.conv1(x))) x = self.pool(nn.functional.relu(self.conv2(x))) x = x.view(-1, 32 * 8 * 8) x = nn.functional.relu(self.fc1(x)) x = self.fc2(x) return x # 设备配置 + 模型初始化 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = SimpleCNN().to(device) # 调用summary:参数(模型,输入形状(通道,高,宽),batch_size可选) summary(model, input_size=(3, 32, 32), batch_size=1)

2. 输出解读

---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ Conv2d-1 [1, 16, 32, 32] 448 MaxPool2d-2 [1, 16, 16, 16] 0 Conv2d-3 [1, 32, 16, 16] 4,640 MaxPool2d-4 [1, 32, 8, 8] 0 Linear-5 [1, 128] 262,272 Linear-6 [1, 10] 1,290 ================================================================ Total params: 268,650 Trainable params: 268,650 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.01 Forward/backward pass size (MB): 0.29 Params size (MB): 1.02 Estimated Total Size (MB): 1.32 ----------------------------------------------------------------

3. 优缺点

优点缺点
极简、无多余依赖仅支持单输入模型
输出简洁、易理解对嵌套模型 / 多分支模型支持差
快速查看参数量 / 形状无批次维度、无内存占用细分
支持 GPU/CPU维护停滞,仅兼容 PyTorch 旧版本

三、torchinfo库的summary方法(推荐)

torchinfotorchsummary的升级版(原torchsummaryX),解决了多输入、嵌套模型、维度展示不清晰的问题,功能更全面,是当前 PyTorch 模型可视化的首选。

1. 安装与基本用法

pip install torchinfo
from torchinfo import summary # 复用上面的SimpleCNN模型 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = SimpleCNN().to(device) # 核心参数:model, input_size, batch_dim, device, col_width等 summary( model, input_size=(1, 3, 32, 32), # (batch_size, 通道, 高, 宽) batch_dim=0, # 批次维度的位置(默认0) device=device, # 模型设备 col_width=20, # 列宽 col_names=["input_size", "output_size", "num_params", "trainable"], # 显示列 row_settings=["var_names"] # 显示层变量名 )

2. 输出解读

========================================================================================== Layer (type (var_name)) Input Shape Output Shape Param # Trainable ========================================================================================== SimpleCNN (SimpleCNN) [1, 3, 32, 32] [1, 10] -- -- ├─Conv2d (conv1) [1, 3, 32, 32] [1, 16, 32, 32] 448 True ├─MaxPool2d (pool) [1, 16, 32, 32] [1, 16, 16, 16] -- -- ├─Conv2d (conv2) [1, 16, 16, 16] [1, 32, 16, 16] 4,640 True ├─MaxPool2d (pool) [1, 32, 16, 16] [1, 32, 8, 8] -- -- ├─Linear (fc1) [1, 2048] [1, 128] 262,272 True ├─Linear (fc2) [1, 128] [1, 10] 1,290 True ========================================================================================== Total params: 268,650 Trainable params: 268,650 Non-trainable params: 0 Total mult-adds (M): 2.15 ========================================================================================== Input size (MB): 0.01 Forward/backward pass size (MB): 0.29 Params size (MB): 1.02 Estimated Total Size (MB): 1.32 ==========================================================================================

四、推理的写法:评估模式

def evaluate_classification(model, dataloader, device): """ 分类模型评估:计算准确率、F1-score(宏平均)、混淆矩阵 """ # 1. 切换到评估模式(必须!) model.eval() # 2. 初始化指标容器 all_preds = [] all_labels = [] # 3. 关闭梯度计算(加速+省显存) with torch.no_grad(): for batch_idx, (x, y) in enumerate(dataloader): # 数据移到设备 x = x.to(device, dtype=torch.float32) y = y.to(device, dtype=torch.long) # 4. 推理(前向传播) outputs = model(x) # 输出:(batch_size, num_classes) preds = torch.argmax(outputs, dim=1) # 取概率最大的类别 # 5. 收集预测结果和真实标签(转回CPU便于计算指标) all_preds.extend(preds.cpu().numpy()) all_labels.extend(y.cpu().numpy()) # 可选:打印进度 if (batch_idx + 1) % 10 == 0: print(f"Batch [{batch_idx+1}/{len(dataloader)}] 完成") # 6. 计算评估指标 accuracy = accuracy_score(all_labels, all_preds) f1_macro = f1_score(all_labels, all_preds, average="macro") # 宏平均F1(适合类别均衡) f1_weighted = f1_score(all_labels, all_preds, average="weighted") # 加权F1(适合类别不均衡) # 7. 打印结果 print("="*50) print(f"分类模型评估结果:") print(f"准确率 (Accuracy): {accuracy:.4f}") print(f"宏平均F1-score: {f1_macro:.4f}") print(f"加权F1-score: {f1_weighted:.4f}") print("="*50) return { "accuracy": accuracy, "f1_macro": f1_macro, "f1_weighted": f1_weighted, "preds": all_preds, "labels": all_labels } # 执行评估 eval_results = evaluate_classification(model, test_loader, device)
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/3/27 12:35:27

定制开发实战:海外版外卖系统PHP全栈解决方案

在数字化转型的浪潮下,全球外卖市场规模预计将在2025年突破2000亿美元。与国内市场不同,海外外卖平台面临多语言支持、跨境支付、税务合规、文化差异等复杂挑战。作为拥有二十年开发经验的PHP全栈架构师,我将深入解析如何基于PHP技术栈构建高…

作者头像 李华
网站建设 2026/4/4 22:23:48

Linux I/O模型总结

Linux I/O模型 一、I/O 操作的两个核心阶段 在深入具体模型之前,我们必须明确一个前提:任何一次 Linux 下的 I/O 操作(以网络 socket 读取为例),都分为两个不可分割的阶段: 数据就绪阶段:内核等…

作者头像 李华
网站建设 2026/4/1 5:23:34

PSD-95抗体:如何为缺血性脑卒中治疗开启神经保护新纪元?

一、缺血性脑卒中治疗面临哪些临床挑战?缺血性脑卒中作为全球致残率最高的神经系统疾病,其治疗时间窗窄、神经损伤不可逆的特点一直是临床面临的重大挑战。目前标准治疗方案阿替普酶虽能通过溶栓恢复血流,但存在出血风险高、治疗时间窗短&…

作者头像 李华
网站建设 2026/3/28 19:06:43

OpenAI开源“Circuit‑Sparsity”模型,0.4 B 参数实现 99.9% 权重归零!

12 月 15 日,OpenAI 在官方博客上公布了最新的开源项目——Circuit‑Sparsity 模型。该模型仅拥有 0.4 B 参数,但高达 99.9% 的权重被强制置零,形成极度稀疏的 Transformer 结构。OpenAI 表示,此举旨在破解大语言模型&#xff08…

作者头像 李华
网站建设 2026/4/5 23:22:39

18、软件开发中的交叉引用与测试驱动开发实践

软件开发中的交叉引用与测试驱动开发实践 在软件开发过程中,文档编写和测试是确保软件质量和可维护性的重要环节。下面将介绍 Sphinx 的交叉引用功能,以及测试驱动开发(TDD)的相关内容。 1. Sphinx 交叉引用 Sphinx 提供了内联标记来设置交叉引用。例如,要创建一个指向…

作者头像 李华
网站建设 2026/4/4 8:14:30

AI眼镜热销卖爆:产能紧张与供应链竟然都快跟不上了!

近期,AI眼镜成为消费电子市场的热点。自今年上半年多家厂商相继发布新品后,AI眼镜在天猫、京东、抖音等平台的成交额出现爆发式增长,双十一期间更是实现全网销量第一的成绩。然而,热销的背后却暴露出产能不足、供应链紧张的结构性…

作者头像 李华