ResNet18垃圾分类应用:云端GPU 1小时搭建演示系统
引言
想象一下,你正在参加一场环保科技展会,周围都是各种高科技设备。突然,一个展台前排起了长队——原来是一台能够自动识别垃圾种类的AI设备,参观者只需把垃圾放在摄像头前,屏幕上就会立刻显示"可回收物"、"厨余垃圾"、"有害垃圾"或"其他垃圾"。更让人惊讶的是,这套系统从零搭建只用了1小时!
这就是我们今天要介绍的ResNet18垃圾分类应用。作为计算机视觉领域的经典模型,ResNet18特别适合这种轻量级的图像分类任务。借助云端GPU资源,即使你是AI新手,也能快速搭建出这样一套演示系统。本文将带你一步步实现这个项目,从环境准备到模型部署,最后生成一个可交互的演示界面。
1. 环境准备:10分钟搞定基础配置
首先我们需要准备运行环境。传统方式需要自己安装CUDA、PyTorch等工具,过程繁琐且容易出错。而使用CSDN星图镜像广场提供的预置镜像,可以一键获得已经配置好的环境。
- 登录CSDN星图平台,搜索"PyTorch 1.12 + CUDA 11.3"基础镜像
- 选择GPU实例(推荐T4或V100,显存8G以上)
- 点击"立即创建",等待约2分钟环境初始化完成
创建成功后,你会获得一个完整的Python 3.8环境,已经预装了: - PyTorch 1.12.1 - torchvision 0.13.1 - OpenCV 4.6.0 - Flask 2.2.2(用于构建Web界面)
验证环境是否正常,可以运行以下命令:
python -c "import torch; print(torch.__version__); print(torch.cuda.is_available())"如果输出类似以下内容,说明环境配置正确:
1.12.1+cu113 True2. 数据准备:垃圾分类数据集处理
垃圾分类任务需要特定的数据集。国内常用的垃圾分类数据集通常包含4大类: - 可回收物(如塑料瓶、纸张) - 厨余垃圾(如果皮、剩菜) - 有害垃圾(如电池、药品) - 其他垃圾(如纸巾、陶瓷)
我们可以使用公开的"Garbage Classification"数据集,它包含约2500张图片,已经按上述4类分类好。下载并解压数据集:
wget https://example.com/garbage_dataset.zip unzip garbage_dataset.zip数据集目录结构应该是这样的:
garbage_dataset/ ├── recyclable/ ├── kitchen/ ├── hazardous/ └── other/为了训练模型,我们需要将数据集划分为训练集和验证集。使用以下Python脚本完成划分:
import os import shutil from sklearn.model_selection import train_test_split dataset_path = 'garbage_dataset' output_path = 'garbage_split' classes = ['recyclable', 'kitchen', 'hazardous', 'other'] # 创建输出目录 os.makedirs(os.path.join(output_path, 'train'), exist_ok=True) os.makedirs(os.path.join(output_path, 'val'), exist_ok=True) for cls in classes: # 为每个类别创建子目录 os.makedirs(os.path.join(output_path, 'train', cls), exist_ok=True) os.makedirs(os.path.join(output_path, 'val', cls), exist_ok=True) # 获取当前类别所有图片 images = os.listdir(os.path.join(dataset_path, cls)) # 按8:2划分训练集和验证集 train_files, val_files = train_test_split(images, test_size=0.2, random_state=42) # 复制文件到对应目录 for f in train_files: shutil.copy(os.path.join(dataset_path, cls, f), os.path.join(output_path, 'train', cls, f)) for f in val_files: shutil.copy(os.path.join(dataset_path, cls, f), os.path.join(output_path, 'val', cls, f))3. 模型训练:30分钟完成迁移学习
ResNet18是一个18层深的卷积神经网络,在ImageNet数据集上预训练过,非常适合作为基础模型进行迁移学习。我们只需要替换最后的全连接层,就能适应我们的垃圾分类任务。
以下是完整的训练代码(保存为train.py):
import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, models, transforms from torch.utils.data import DataLoader # 数据增强和归一化 train_transforms = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) val_transforms = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) # 加载数据集 train_dataset = datasets.ImageFolder('garbage_split/train', train_transforms) val_dataset = datasets.ImageFolder('garbage_split/val', val_transforms) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) # 初始化模型 model = models.resnet18(pretrained=True) num_features = model.fc.in_features model.fc = nn.Linear(num_features, 4) # 4个输出类别 # 使用GPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = model.to(device) # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # 训练循环 num_epochs = 10 best_acc = 0.0 for epoch in range(num_epochs): # 训练阶段 model.train() running_loss = 0.0 for inputs, labels in train_loader: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() # 验证阶段 model.eval() correct = 0 total = 0 with torch.no_grad(): for inputs, labels in val_loader: inputs = inputs.to(device) labels = labels.to(device) outputs = model(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() val_acc = 100 * correct / total print(f'Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}, Val Acc: {val_acc:.2f}%') # 保存最佳模型 if val_acc > best_acc: best_acc = val_acc torch.save(model.state_dict(), 'best_model.pth') print(f'Training complete, best validation accuracy: {best_acc:.2f}%')运行训练脚本:
python train.py在T4 GPU上,10个epoch的训练大约需要20-30分钟。训练完成后,你会得到一个best_model.pth文件,这就是我们训练好的垃圾分类模型。
4. 快速部署:10分钟搭建交互式演示
为了让展会参观者能够体验我们的垃圾分类系统,我们需要创建一个简单的Web界面。使用Flask框架可以快速实现这个功能。
创建app.py文件:
from flask import Flask, render_template, request, jsonify import torch from torchvision import models, transforms from PIL import Image import io app = Flask(__name__) # 加载模型 model = models.resnet18(pretrained=False) num_features = model.fc.in_features model.fc = torch.nn.Linear(num_features, 4) model.load_state_dict(torch.load('best_model.pth')) model.eval() # 类别标签 class_names = ['可回收物', '厨余垃圾', '有害垃圾', '其他垃圾'] # 图像预处理 transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) @app.route('/') def index(): return render_template('index.html') @app.route('/classify', methods=['POST']) def classify(): if 'file' not in request.files: return jsonify({'error': 'No file uploaded'}) file = request.files['file'] if file.filename == '': return jsonify({'error': 'No file selected'}) try: # 读取并预处理图像 image_bytes = file.read() image = Image.open(io.BytesIO(image_bytes)) image = transform(image).unsqueeze(0) # 预测 with torch.no_grad(): output = model(image) _, predicted = torch.max(output, 1) class_id = predicted.item() return jsonify({ 'class_name': class_names[class_id], 'class_id': class_id }) except Exception as e: return jsonify({'error': str(e)}) if __name__ == '__main__': app.run(host='0.0.0.0', port=5000)创建templates/index.html文件:
<!DOCTYPE html> <html> <head> <title>垃圾分类AI演示系统</title> <style> body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; text-align: center; } .upload-area { border: 2px dashed #ccc; padding: 30px; margin: 20px 0; cursor: pointer; } #preview { max-width: 300px; max-height: 300px; margin: 20px auto; display: none; } .result { font-size: 24px; font-weight: bold; margin: 20px 0; padding: 15px; border-radius: 5px; } .recyclable { background-color: #4CAF50; color: white; } .kitchen { background-color: #FFC107; color: black; } .hazardous { background-color: #F44336; color: white; } .other { background-color: #9E9E9E; color: white; } </style> </head> <body> <h1>垃圾分类AI演示系统</h1> <p>上传垃圾图片,AI会自动识别其分类</p> <div class="upload-area" id="uploadArea"> <p>点击或拖拽图片到此处</p> <input type="file" id="fileInput" accept="image/*" style="display: none;"> </div> <img id="preview" alt="预览图"> <div class="result" id="result"></div> <script> const uploadArea = document.getElementById('uploadArea'); const fileInput = document.getElementById('fileInput'); const preview = document.getElementById('preview'); const resultDiv = document.getElementById('result'); uploadArea.addEventListener('click', () => fileInput.click()); uploadArea.addEventListener('dragover', (e) => { e.preventDefault(); uploadArea.style.borderColor = '#666'; }); uploadArea.addEventListener('dragleave', () => { uploadArea.style.borderColor = '#ccc'; }); uploadArea.addEventListener('drop', (e) => { e.preventDefault(); uploadArea.style.borderColor = '#ccc'; if (e.dataTransfer.files.length) { fileInput.files = e.dataTransfer.files; handleFileSelect(); } }); fileInput.addEventListener('change', handleFileSelect); function handleFileSelect() { const file = fileInput.files[0]; if (!file) return; const reader = new FileReader(); reader.onload = (e) => { preview.src = e.target.result; preview.style.display = 'block'; classifyImage(file); }; reader.readAsDataURL(file); } function classifyImage(file) { resultDiv.textContent = '识别中...'; resultDiv.className = 'result'; const formData = new FormData(); formData.append('file', file); fetch('/classify', { method: 'POST', body: formData }) .then(response => response.json()) .then(data => { if (data.error) { resultDiv.textContent = '错误: ' + data.error; return; } resultDiv.textContent = '分类结果: ' + data.class_name; resultDiv.className = 'result ' + ['recyclable', 'kitchen', 'hazardous', 'other'][data.class_id]; }) .catch(error => { resultDiv.textContent = '错误: ' + error.message; }); } </script> </body> </html>启动Web服务:
python app.py现在,访问 http://你的服务器IP:5000 就能看到垃圾分类演示系统了。你可以上传垃圾图片,系统会实时返回分类结果。
5. 常见问题与优化技巧
在实际部署过程中,你可能会遇到以下问题:
- 模型准确率不高
- 增加训练数据量,特别是样本较少的类别
- 调整数据增强策略,增加旋转、颜色抖动等
尝试更长的训练时间或更小的学习率
Web界面响应慢
- 使用gunicorn等WSGI服务器替代Flask开发服务器
- 启用模型的多线程推理
对输入图片进行压缩,减少传输和处理时间
特定类别识别错误
- 收集更多错误样本进行针对性训练
- 调整类别权重,解决数据不平衡问题
尝试不同的模型结构,如ResNet34或EfficientNet
部署到生产环境
- 使用Docker容器化应用
- 添加API密钥认证
- 实现批处理功能,提高吞吐量
总结
通过本文的指导,我们完成了从零开始搭建一个垃圾分类AI演示系统的全过程。让我们回顾一下关键要点:
- 云端GPU加速开发:使用预置镜像快速搭建PyTorch环境,省去繁琐的配置过程
- 迁移学习高效实用:基于ResNet18进行微调,短短30分钟训练就能获得不错的效果
- 演示系统快速搭建:用Flask构建的Web界面简洁直观,适合展会等演示场景
- 完整流程可复制:从数据准备到模型训练再到部署,每个步骤都有详细说明和代码
- 实际应用价值高:这套系统不仅适用于展会演示,稍加改造就能用于智能垃圾桶等实际场景
现在,你已经掌握了使用ResNet18构建图像分类应用的核心方法。不妨尝试用同样的思路解决其他分类问题,或者进一步优化这个垃圾分类系统的性能。
💡获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。