# @浙大疏锦行
# -------- Imports and global configuration --------
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Use a CJK-capable font so the Chinese plot labels render correctly.
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # render the minus sign correctly with SimHei

# Prefer GPU when available; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备:{device}")

# Training-time augmentation pipeline.
# FIX: the original train pipeline omitted Normalize while the test pipeline
# applied it, so the model saw differently-scaled inputs at train vs. test
# time; both pipelines now share the standard CIFAR-10 mean/std.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Deterministic test-time pipeline (no augmentation).
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# CIFAR-10 datasets; the train split is downloaded to ./data on first use.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=test_transform)

batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# NOTE: per the author, CUDA is unavailable on their machine, so computation
# runs on the CPU (the device selection above still auto-detects).
class CNN(nn.Module):
    """Three conv blocks (Conv -> BatchNorm -> ReLU -> 2x2 max-pool) followed
    by a two-layer classifier head, for 10-class CIFAR-10 images (3x32x32).

    Spatial sizes: 32 -> 16 -> 8 -> 4, so the flattened feature vector fed to
    fc1 is 128 * 4 * 4 = 2048.
    """

    def __init__(self):
        super(CNN, self).__init__()
        # Block 1: 3 -> 32 channels, spatial 32 -> 16
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Block 2: 32 -> 64 channels, spatial 16 -> 8
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        # Block 3: 64 -> 128 channels, spatial 8 -> 4
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        # Classifier head: 2048 -> 512 -> 10 logits
        self.fc1 = nn.Linear(in_features=128 * 4 * 4, out_features=512)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(in_features=512, out_features=10)

    def forward(self, x):
        """Return (batch, 10) class logits for a (batch, 3, 32, 32) input."""
        # Block 1
        x = self.pool1(self.relu1(self.bn1(self.conv1(x))))
        # Block 2 — FIX: the original called the nonexistent self.relu
        # (AttributeError) and then jumped straight to pool3, skipping
        # conv3/bn3 entirely, so the 128*4*4 flatten below could never match.
        x = self.pool2(self.relu2(self.bn2(self.conv2(x))))
        # Block 3
        x = self.pool3(self.relu3(self.bn3(self.conv3(x))))
        x = x.view(-1, 128 * 4 * 4)  # flatten to (batch, 2048)
        x = self.relu3(self.fc1(x))  # ReLU modules are stateless; reuse is safe
        x = self.dropout(x)
        return self.fc2(x)


def plot_iter_losses(losses, indices):
    """Plot the per-iteration (per-batch) training loss curve."""
    plt.figure(figsize=(10, 4))
    plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')
    plt.xlabel('Iteration(Batch序号)')
    plt.ylabel('损失值')
    plt.title('每个 Iteration 的训练损失')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):
    """Plot per-epoch accuracy (left) and loss (right) for train and test."""
    epochs = range(1, len(train_acc) + 1)
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_acc, 'b-', label='训练准确率')
    plt.plot(epochs, test_acc, 'r-', label='测试准确率')
    plt.xlabel('Epoch')
    plt.ylabel('准确率 (%)')
    plt.title('训练和测试准确率')
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_loss, 'b-', label='训练损失')
    plt.plot(epochs, test_loss, 'r-', label='测试损失')
    plt.xlabel('Epoch')
    plt.ylabel('损失值')
    plt.title('训练和测试损失')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()


def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs):
    """Train `model` for `epochs` epochs, evaluating on `test_loader` after
    each epoch.

    Steps the (ReduceLROnPlateau) scheduler on the epoch test loss, plots the
    per-iteration loss and per-epoch metric curves, and returns the final
    epoch's test accuracy as a percentage.
    """
    all_iter_losses = []  # loss of every individual batch
    iter_indices = []     # global iteration number for the loss plot's x-axis
    train_acc_history, test_acc_history = [], []
    train_loss_history, test_loss_history = [], []

    for epoch in range(epochs):
        # FIX: re-enter train mode every epoch; the original called
        # model.train() once before the loop, so after the first epoch's
        # model.eval() all later epochs ran with dropout off and BN frozen.
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            iter_loss = loss.item()
            all_iter_losses.append(iter_loss)
            # FIX: include the batch offset; the original appended
            # epoch*len(train_loader) for every batch, collapsing each
            # epoch's points onto one x-value.
            iter_indices.append(epoch * len(train_loader) + batch_idx + 1)
            running_loss += iter_loss

            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            if (batch_idx + 1) % 100 == 0:
                print(f'Epoch:{epoch+1}/{epochs}| Batch:{batch_idx+1}/{len(train_loader)}'
                      f'| 单Batch损失:{iter_loss:.4f}| 累计平均损失:{running_loss/(batch_idx+1):.4f}')

        epoch_train_loss = running_loss / len(train_loader)
        epoch_train_acc = 100. * correct / total
        train_acc_history.append(epoch_train_acc)
        train_loss_history.append(epoch_train_loss)

        # ---- per-epoch evaluation on the test set ----
        model.eval()
        test_loss = 0
        correct_test = 0
        total_test = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                test_loss += criterion(output, target).item()
                _, predicted = output.max(1)
                total_test += target.size(0)
                correct_test += predicted.eq(target).sum().item()
        epoch_test_loss = test_loss / len(test_loader)
        epoch_test_acc = 100. * correct_test / total_test
        test_acc_history.append(epoch_test_acc)
        test_loss_history.append(epoch_test_loss)

        scheduler.step(epoch_test_loss)  # ReduceLROnPlateau keys off the test loss
        print(f'Epoch{epoch+1}/{epochs}完成 | 训练准确率:{epoch_train_acc:.2f}% | 测试准确率:{epoch_test_acc:.2f}%')

    plot_iter_losses(all_iter_losses, iter_indices)
    plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)
    return epoch_test_acc


# FIX: guard the actual training run so importing this module no longer
# triggers a full training session; running it as a script is unchanged.
if __name__ == "__main__":
    model = CNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Halve the LR after 3 epochs without test-loss improvement.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)

    epochs = 20
    print("开始使用CNN训练模型...")
    final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs)
    print(f"训练完成!最终测试准确率:{final_accuracy:.2f}%")