如何采用U-Net作为基础模型训练使用水体分割遥感图像数据集_使用深度学习模型来进行水体分割的遥感图像数据集 图像分割任务
文章目录
- 数据准备
- 模型定义
- 训练过程
- 模型优化
- 推理及可视化
水体分割遥感图像数据集
2841张卫星拍摄的水体图像集合,每张mask标签,其中白色代表水,黑色代表水以外的其他东西。
1
1
针对水体分割的遥感图像数据集,我们可以使用深度学习模型来进行图像分割任务。采用U-Net作为基础模型,适用于遥感图像的分割任务。以下是完整的流程,包括数据准备、模型定义、训练过程、评估和推理及可视化。
数据准备
首先需要定义一个自定义的数据集类来加载和预处理你的数据集。假设图像和mask标签分别存储在两个文件夹中,每张图像都有对应的mask。
importtorchfromtorch.utils.dataimportDataset,DataLoaderfromtorchvisionimporttransformsfromPILimportImageimportosclassWaterBodySegmentationDataset(Dataset):def__init__(self,image_dir,mask_dir,transform=None):self.image_dir=image_dir self.mask_dir=mask_dir self.transform=transform self.images=os.listdir(image_dir)def__len__(self):returnlen(self.images)def__getitem__(self,idx):img_path=os.path.join(self.image_dir,self.images[idx])mask_path=os.path.join(self.mask_dir,self.images[idx])# 假设图像和mask同名但位于不同目录image=Image.open(img_path).convert("RGB")mask=Image.open(mask_path).convert("L")# 确保mask是灰度图ifself.transform:image=self.transform(image)mask=self.transform(mask)returnimage,mask# 数据转换transform=transforms.Compose([transforms.Resize((400,400)),# 根据实际情况调整大小transforms.ToTensor(),])train_dataset=WaterBodySegmentationDataset(image_dir='path_to_train_images',mask_dir='path_to_train_masks',transform=transform)val_dataset=WaterBodySegmentationDataset(image_dir='path_to_val_images',mask_dir='path_to_val_masks',transform=transform)train_loader=DataLoader(train_dataset,batch_size=8,shuffle=True)val_loader=DataLoader(val_dataset,batch_size=8,shuffle=False)模型定义
接下来,我们定义一个U-Net模型结构用于图像分割任务。
importtorch.nnasnnimporttorch.nn.functionalasFclassUNet(nn.Module):def__init__(self):super(UNet,self).__init__()self.enc1=self.conv_block(3,64)self.enc2=self.conv_block(64,128)self.enc3=self.conv_block(128,256)self.enc4=self.conv_block(256,512)self.pool=nn.MaxPool2d(2)self.upconv3=nn.ConvTranspose2d(512,256,kernel_size=2,stride=2)self.dec3=self.conv_block(512,256)self.upconv2=nn.ConvTranspose2d(256,128,kernel_size=2,stride=2)self.dec2=self.conv_block(256,128)self.upconv1=nn.ConvTranspose2d(128,64,kernel_size=2,stride=2)self.dec1=self.conv_block(128,64)self.out_conv=nn.Conv2d(64,1,kernel_size=1)# 输出层,二分类问题defconv_block(self,in_channels,out_channels):returnnn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1),nn.ReLU(inplace=True),nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1),nn.ReLU(inplace=True))defforward(self,x):enc1=self.enc1(x)enc2=self.enc2(self.pool(enc1))enc3=self.enc3(self.pool(enc2))enc4=self.enc4(self.pool(enc3))dec3=self.upconv3(enc4)dec3=torch.cat((dec3,enc3),dim=1)dec3=self.dec3(dec3)dec2=self.upconv2(dec3)dec2=torch.cat((dec2,enc2),dim=1)dec2=self.dec2(dec2)dec1=self.upconv1(dec2)dec1=torch.cat((dec1,enc1),dim=1)dec1=self.dec1(dec1)returntorch.sigmoid(self.out_conv(dec1))# 使用sigmoid函数输出概率训练过程
定义训练循环:
model=UNet()criterion=nn.BCELoss()# 二元交叉熵损失函数optimizer=torch.optim.Adam(model.parameters(),lr=0.001)num_epochs=20forepochinrange(num_epochs):model.train()running_loss=0.0forimages,masksintrain_loader:optimizer.zero_grad()outputs=model(images)loss=criterion(outputs,masks)loss.backward()optimizer.step()running_loss+=loss.item()*images.size(0)print(f"Epoch{epoch+1}/{num_epochs}, Loss:{running_loss/len(train_loader.dataset)}")模型优化
可以考虑使用学习率调度器和早停策略来优化模型性能:
fromtorch.optim.lr_schedulerimportReduceLROnPlateau scheduler=ReduceLROnPlateau(optimizer,'min',patience=5)defvalidate(model,val_loader,criterion):model.eval()val_loss=0.0withtorch.no_grad():forimages,masksinval_loader:outputs=model(images)loss=criterion(outputs,masks)val_loss+=loss.item()*images.size(0)returnval_loss/len(val_loader.dataset)forepochinrange(num_epochs):# ... 训练过程 ...val_loss=validate(model,val_loader,criterion)scheduler.step(val_loss)print(f"Validation Loss:{val_loss}")推理及可视化
推理并可视化结果:
importmatplotlib.pyplotaspltdefvisualize_predictions(model,dataloader,num_images=5):model.eval()withtorch.no_grad():fori,(images,masks)inenumerate(dataloader):ifi>=num_images:breakoutputs=model(images)preds=(outputs>0.5).float()# 阈值为0.5fig,axarr=plt.subplots(1,3)axarr[0].imshow(images[0].permute(1,2,0).numpy())# 显示原始图像axarr[1].imshow(masks[0].squeeze().numpy(),cmap='gray')# 显示真实标签axarr[2].imshow(preds[0].squeeze().numpy(),cmap='gray')# 显示预测结果plt.show()visualize_predictions(model,val_loader)通过上述步骤,您可以有效地利用水体分割的遥感图像数据集进行水体检测任务。请根据实际情况调整代码中的细节-