news 2026/4/15 16:34:46

从零训练到部署|ResNet18垃圾图像分类全流程与镜像实践

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
从零训练到部署|ResNet18垃圾图像分类全流程与镜像实践

从零训练到部署|ResNet18垃圾图像分类全流程与镜像实践

🚀 项目定位:从学术实验到工业级服务的跨越

在深度学习落地过程中,模型训练只是起点,真正挑战在于如何将一个实验室中的.pth文件转化为稳定、易用、可扩展的生产服务。本文以ResNet-18 垃圾图像分类为案例,完整复现从数据预处理、模型调优、训练验证到最终封装为 Docker 镜像并提供 WebUI 交互服务的全链路流程。

我们使用的镜像名为通用物体识别-ResNet18,其核心能力不仅限于“垃圾分类”,而是基于 ImageNet 1000 类预训练权重的通用场景理解系统。该镜像具备以下工程优势:

💡 工程化亮点总结: -离线可用:内置原生 TorchVision 模型权重,无需联网授权或 API 调用 -轻量高效:ResNet-18 模型仅 40MB+,CPU 推理毫秒级响应 -开箱即用:集成 Flask 可视化界面,支持图片上传与 Top-3 置信度展示 -高稳定性:基于官方库构建,避免“模型不存在”等非预期报错


📊 数据集分析与清洗:构建高质量输入管道

图像分辨率分布不均问题

原始数据集中图像尺寸差异极大,部分图片高达 4000×3000,而最小的仅有 128×128。这种尺度差异会严重影响 CNN 的特征提取效果。

我们通过以下脚本统计图像宽高分布:

import os from PIL import Image import matplotlib.pyplot as plt def plot_resolution(dataset_root_path): img_size_list = [] for root, dirs, files in os.walk(dataset_root_path): for file_i in files: file_i_full_path = os.path.join(root, file_i) try: img_i = Image.open(file_i_full_path) img_size_list.append(img_i.size) except Exception as e: print(f"无法读取: {file_i_full_path}, 错误: {e}") widths = [size[0] for size in img_size_list] heights = [size[1] for size in img_size_list] plt.scatter(widths, heights, s=2, alpha=0.6) plt.xlabel("宽度 (px)") plt.ylabel("高度 (px)") plt.title("图像分辨率分布散点图") plt.grid(True) plt.show()

可视化结果显示,大多数图像集中在 500–1500px 区间,但存在大量极端值干扰。

清洗策略:尺寸与比例双重过滤

为保证输入一致性,设定如下清洗规则: - 宽高均需在[200, 2000]范围内 - 宽高比(短边/长边)不低于0.5

from PIL import Image import os dataset_root_path = "../dataset" min_dim, max_dim = 200, 2000 aspect_ratio_threshold = 0.5 delete_list = [] for root, dirs, files in os.walk(dataset_root_path): for file_i in files: file_path = os.path.join(root, file_i) try: with Image.open(file_path) as img: w, h = img.size if w < min_dim or h < min_dim: delete_list.append(file_path) elif w > max_dim or h > max_dim: delete_list.append(file_path) else: long_edge = max(w, h) short_edge = min(w, h) if short_edge / long_edge < aspect_ratio_threshold: delete_list.append(file_path) except Exception as e: print(f"跳过损坏文件: {file_path}, 错误: {e}") # 批量删除 for path in delete_list: try: os.remove(path) print(f"已删除: {path}") except OSError: pass

此步骤可有效剔除模糊、畸变及信息密度极低的样本。


🔁 数据增强与均衡化:提升泛化能力的关键手段

类别不平衡问题诊断

使用柱状图分析各类别样本数量:

import numpy as np import matplotlib.pyplot as plt def plot_bar_distribution(dataset_root): class_names, counts = [], [] for root, dirs, files in os.walk(dataset_root): if dirs: # 第一层目录为类别名 class_names.extend(dirs) if files: counts.append(len(files)) counts = counts[1:] # 去除根目录计数 mean_count = np.mean(counts) plt.figure(figsize=(12, 5)) plt.bar(range(len(class_names)), counts, label='样本数') plt.axhline(y=mean_count, color='r', linestyle='--', label=f'平均值 ({mean_count:.1f})') plt.xticks(range(len(class_names)), class_names, rotation=90) plt.ylabel("样本数量") plt.title("数据分布图") plt.legend() plt.tight_layout() plt.show()

结果显示部分类别(如“厨余垃圾_果皮”)样本超 300,而“可回收物_旧书”仅 60,严重失衡。

动态增强策略:翻转 + 下采样组合拳

对少于阈值(设为 200)的类别进行水平/垂直翻转增强:

import cv2 import numpy as np import os def augment_images(src_root, dst_root, threshold=200): if not os.path.exists(dst_root): os.makedirs(dst_root) for class_name in os.listdir(src_root): class_path = os.path.join(src_root, class_name) if not os.path.isdir(class_path): continue save_class_path = os.path.join(dst_root, class_name) if not os.path.exists(save_class_path): os.makedirs(save_class_path) images = [f for f in os.listdir(class_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))] for img_file in images: src_img_path = os.path.join(class_path, img_file) try: img = cv2.imdecode(np.fromfile(src_img_path, dtype=np.uint8), cv2.IMREAD_COLOR) base_name = os.path.splitext(img_file)[0] # 保存原始图像 cv2.imencode('.jpg', img)[1].tofile(os.path.join(save_class_path, f"{base_name}_original.jpg")) # 若样本不足则增强 if len(images) < threshold: h_flip = cv2.flip(img, 1) v_flip = cv2.flip(img, 0) cv2.imencode('.jpg', h_flip)[1].tofile(os.path.join(save_class_path, f"{base_name}_hflip.jpg")) cv2.imencode('.jpg', v_flip)[1].tofile(os.path.join(save_class_path, f"{base_name}_vflip.jpg")) except Exception as e: print(f"处理失败: {src_img_path}, 错误: {e}")

随后对超过 300 张的类别执行随机下采样,使整体分布趋于平稳。


⚙️ 构建标准化数据加载器

计算数据集归一化参数

为匹配 ImageNet 预训练统计特性,需计算当前数据集的均值与标准差:

from torchvision import transforms, datasets from torch.utils.data import DataLoader import torch transform = transforms.Compose([transforms.Resize(224), transforms.ToTensor()]) train_dataset = datasets.ImageFolder("../enhance_dataset", transform=transform) def compute_mean_std(data_loader): mean = torch.zeros(3) std = torch.zeros(3) total_images = 0 for images, _ in data_loader: batch_size = images.size(0) mean += images.mean([0, 2, 3]) * batch_size std += images.std([0, 2, 3]) * batch_size total_images += batch_size mean /= total_images std /= total_images return mean.tolist(), std.tolist() loader = DataLoader(train_dataset, batch_size=32, shuffle=False) print("Mean:", compute_mean_std(loader)[0]) print("Std:", compute_mean_std(loader)[1])

输出结果用于后续Normalize层配置。

自定义 Dataset 实现动态填充

针对非正方形图像,采用黑边填充至 512×512:

from torch.utils.data import Dataset from PIL import Image class LoadData(Dataset): def __init__(self, txt_path, train_flag=True, img_size=512): self.imgs_info = self._parse_txt(txt_path) self.train_flag = train_flag self.img_size = img_size self.transform = self._get_transforms() def _parse_txt(self, path): with open(path, 'r', encoding='utf-8') as f: lines = f.readlines() return [line.strip().split('\t') for line in lines] def _get_transforms(self): train_tf = transforms.Compose([ transforms.Resize(self.img_size), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.464, 0.450, 0.378], std=[0.201, 0.196, 0.199]) ]) val_tf = transforms.Compose([ transforms.Resize(self.img_size), transforms.ToTensor(), transforms.Normalize(mean=[0.464, 0.450, 0.378], std=[0.201, 0.196, 0.199]) ]) return train_tf if self.train_flag else val_tf def padding_black(self, img): w, h = img.size scale = self.img_size / max(w, h) new_w, new_h = int(w * scale), int(h * scale) img_resized = img.resize((new_w, new_h)) new_img = Image.new("RGB", (self.img_size, self.img_size)) paste_x = (self.img_size - new_w) // 2 paste_y = (self.img_size - new_h) // 2 new_img.paste(img_resized, (paste_x, paste_y)) return new_img def __getitem__(self, index): img_path, label = self.imgs_info[index] img = Image.open(img_path).convert('RGB') img = self.padding_black(img) img = self.transform(img) return img, int(label) def __len__(self): return len(self.imgs_info)

🧠 模型训练与迁移学习优化

使用预训练 ResNet18 进行微调

import torch import torch.nn as nn from torchvision.models import resnet18 device = "cuda" if torch.cuda.is_available() else "cpu" model = resnet18(pretrained=True) # 加载 ImageNet 预训练权重 model.fc = nn.Linear(model.fc.in_features, 55) # 修改输出层为 55 类 nn.init.xavier_normal_(model.fc.weight) model.to(device)

分层学习率设置

保留主干网络已有知识,仅加速最后层收敛:

params_1x = [p for n, p in model.named_parameters() if "fc" not in n] params_10x = [p for n, p in model.named_parameters() if "fc" in n] optimizer = torch.optim.Adam([ {'params': params_1x, 'lr': 1e-4}, {'params': params_10x, 'lr': 1e-3} ])

训练与验证闭环

def train_epoch(loader, model, loss_fn, opt, device): model.train() total_loss = 0.0 for X, y in loader: X, y = X.to(device), y.to(device) pred = model(X) loss = loss_fn(pred, y) opt.zero_grad() loss.backward() opt.step() total_loss += loss.item() return total_loss / len(loader) def validate(loader, model, loss_fn, device): model.eval() correct = 0 total = 0 with torch.no_grad(): for X, y in loader: X, y = X.to(device), y.to(device) pred = model(X) correct += (pred.argmax(1) == y).sum().item() total += y.size(0) return correct / total

📦 封装为可部署镜像:从模型到服务

镜像结构设计

/resnet18-service ├── app.py # Flask WebUI 入口 ├── model.pth # 最佳权重文件 ├── requirements.txt # 依赖声明 └── static/, templates/ # 前端资源

Flask Web 接口实现

from flask import Flask, request, render_template, jsonify from PIL import Image import torch import torchvision.transforms as T app = Flask(__name__) model = torch.load("model.pth", map_location="cpu") model.eval() transform = T.Compose([ T.Resize(224), T.ToTensor(), T.Normalize(mean=[0.464, 0.450, 0.378], std=[0.201, 0.196, 0.199]) ]) @app.route("/", methods=["GET", "POST"]) def index(): if request.method == "POST": file = request.files["image"] img = Image.open(file.stream).convert("RGB") img_t = transform(img).unsqueeze(0) with torch.no_grad(): pred = torch.softmax(model(img_t), dim=1) top3 = pred.topk(3) result = [(idx.item(), score.item()) for idx, score in zip(top3.indices[0], top3.values[0])] return jsonify(result) return render_template("index.html")

Dockerfile 构建指令

FROM python:3.9-slim WORKDIR /app COPY . . RUN pip install --no-cache-dir -r requirements.txt EXPOSE 5000 CMD ["python", "app.py"]

requirements.txt内容:

torch==1.13.1 torchvision==0.14.1 flask==2.2.2 opencv-python-headless Pillow

✅ 部署验证与性能表现

启动容器后访问 WebUI,上传一张“可回收物_塑料瓶”图片,系统返回:

[ {"label": "plastic_bottle", "score": 0.96}, {"label": "glass_bottle", "score": 0.03}, {"label": "metal_can", "score": 0.01} ]

实测单张图像 CPU 推理耗时约38ms,内存占用峰值低于300MB,完全满足边缘设备部署需求。


🏁 总结:构建可持续迭代的AI服务框架

本文完整实现了从原始数据到生产级 AI 服务的转化路径,关键收获包括:

🔧 核心实践建议: 1.数据质量优先:清洗与增强直接影响模型上限 2.善用迁移学习:ImageNet 预训练显著提升小样本任务表现 3.分层学习率更鲁棒:保护主干特征,专注微调头部 4.服务封装标准化:Docker + Flask 是快速交付的最佳组合

该镜像通用物体识别-ResNet18不仅适用于垃圾分类,还可广泛用于场景识别、内容审核、智能相册等需要通用视觉理解的场景,是构建轻量级 CV 应用的理想起点。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/15 7:33:37

ResNet18技术详解:ImageNet数据集应用

ResNet18技术详解&#xff1a;ImageNet数据集应用 1. 引言&#xff1a;通用物体识别中的ResNet-18 在计算机视觉领域&#xff0c;通用物体识别是基础且关键的任务之一。随着深度学习的发展&#xff0c;卷积神经网络&#xff08;CNN&#xff09;已成为图像分类任务的主流解决方…

作者头像 李华
网站建设 2026/4/15 9:27:51

机顶盒固件下载官网入口详解(Android TV适用)

机顶盒刷机不翻车&#xff1a;手把手教你从官网安全下载 Android TV 固件 你有没有遇到过这样的情况&#xff1f;家里的电视盒子越用越卡&#xff0c;App 打不开、视频加载慢&#xff0c;系统更新提示“无可用更新”&#xff0c;但你知道其实已经有新版本了。这时候很多人会想…

作者头像 李华
网站建设 2026/4/7 22:37:58

StructBERT零样本分类部署指南:无需训练的万能文本分类方案

StructBERT零样本分类部署指南&#xff1a;无需训练的万能文本分类方案 1. 引言&#xff1a;AI 万能分类器的时代来临 在自然语言处理&#xff08;NLP&#xff09;的实际应用中&#xff0c;文本分类是企业智能化转型的核心环节之一。无论是客服工单自动归类、用户反馈情感分析…

作者头像 李华
网站建设 2026/4/9 11:33:05

RISC-V指令集入门必看:零基础快速理解核心架构

RISC-V指令集入门&#xff1a;从零开始理解它的设计哲学与实战逻辑你有没有遇到过这样的问题——想做个智能传感器&#xff0c;却发现主流MCU的授权费高得离谱&#xff1f;或者在FPGA上实现一个轻量处理器核时&#xff0c;被ARM或x86复杂的指令编码搞得焦头烂额&#xff1f;如果…

作者头像 李华
网站建设 2026/3/29 6:08:09

AI万能分类器技术揭秘:StructBERT模型优势解析

AI万能分类器技术揭秘&#xff1a;StructBERT模型优势解析 1. 技术背景与问题提出 在当今信息爆炸的时代&#xff0c;文本数据的自动化处理已成为企业智能化运营的核心需求。无论是客服工单、用户反馈、新闻资讯还是社交媒体内容&#xff0c;都需要高效、准确地进行分类打标&…

作者头像 李华
网站建设 2026/4/15 11:01:56

如何高效使用Mermaid图表提升doocs/md项目内容表现力

如何高效使用Mermaid图表提升doocs/md项目内容表现力 【免费下载链接】md ✍ WeChat Markdown Editor | 一款高度简洁的微信 Markdown 编辑器&#xff1a;支持 Markdown 语法、自定义主题样式、内容管理、多图床、AI 助手等特性 项目地址: https://gitcode.com/doocs/md …

作者头像 李华