当前位置: 首页 > news > 正文

基础分割模型U-Net

数据集(Carvana):https://www.kaggle.com/competitions/carvana-image-masking-challenge/data

import collections
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> BatchNorm -> ReLU stages.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the second convolution.
        mid_channels: number of channels between the two convolutions;
            falls back to ``out_channels`` when not given.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            # kernel 3 with padding 1 keeps the spatial size unchanged.
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder step: 2x2 max-pooling followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        """Halve the spatial size of ``x`` and convolve it."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder step: upsample, pad to the skip tensor's size, concat, convolve.

    Args:
        in_channels: channels of the concatenated (skip + upsampled) tensor.
        out_channels: channels produced by the block.
        bilinear: use bilinear ``nn.Upsample`` when True, otherwise a
            learnable ``nn.ConvTranspose2d`` that also halves the channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with the skip tensor ``x2`` and fuse both.

        Args:
            x1: feature map coming from the deeper layer.
            x2: skip-connection feature map from the encoder.
        """
        x1 = self.up(x1)
        # Dim 2 is height and dim 3 is width for the 4-D tensors used here.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        # F.pad takes pads starting from the LAST dimension:
        # (left, right, top, bottom). The original passed the height
        # difference first, which broke on non-square odd-sized inputs.
        x1 = F.pad(
            x1,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping features to per-class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to ``out_channels`` channels."""
        return self.conv(x)


class UNet(nn.Module):
    """U-Net with four down-sampling and four up-sampling stages.

    Args:
        n_channels: number of channels of the input image.
        n_classes: number of output channels (logits per pixel).
        bilinear: forwarded to every ``Up`` block; also halves the
            bottleneck width when True.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return raw (un-activated) logits for ``x``."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


class CarvanaDataset(Dataset):
    """Image / mask pairs read from ``<base_dir>/train`` and ``<base_dir>/train_masks``.

    Args:
        base_dir: directory containing the ``train`` and ``train_masks`` folders.
        idx_list: indices into the sorted image listing that belong to this split.
        mode: ``"train"`` returns ``(image, mask)``; any other value returns
            only the image.
        transform: callable applied to the image (and to the mask in train
            mode). In train mode it must return tensors, because the mask is
            binarised with tensor indexing afterwards.
    """

    def __init__(self, base_dir, idx_list, mode="train", transform=None):
        super().__init__()
        self.base_dir = base_dir
        self.idx_list = idx_list
        # os.path.join works with or without a trailing separator on base_dir
        # (plain string concatenation produced "./carvanatrain").
        # Sorting makes index -> file mapping deterministic across runs.
        self.images = sorted(os.listdir(os.path.join(base_dir, "train")))
        self.masks = sorted(os.listdir(os.path.join(base_dir, "train_masks")))
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.idx_list)

    def __getitem__(self, idx):
        image_file = self.images[self.idx_list[idx]]
        # The mask shares the image's stem and carries a "_mask.gif" suffix.
        mask_file = os.path.splitext(image_file)[0] + "_mask.gif"
        image = PIL.Image.open(os.path.join(self.base_dir, "train", image_file))

        if self.mode == "train":
            mask = PIL.Image.open(os.path.join(self.base_dir, "train_masks", mask_file))
            if self.transform is not None:
                image = self.transform(image)
                mask = self.transform(mask)
            # Any non-zero pixel (including interpolated edges) counts as foreground.
            mask[mask != 0] = 1.0
            return image, mask.float()

        if self.transform is not None:
            image = self.transform(image)
        return image


def dice_coeff(pred, target):
    """Evaluation metric: Dice coefficient computed over the whole batch.

    Args:
        pred: predicted probabilities or binary mask, first dim is the batch.
        target: ground-truth mask with the same number of elements as ``pred``.

    Returns:
        Scalar tensor ``(2 * |pred * target| + eps) / (|pred| + |target| + eps)``.
    """
    eps = 1e-4  # keeps the ratio defined when both masks are empty
    num = pred.size(0)
    # reshape (not view) so non-contiguous inputs are accepted as well.
    m1 = pred.reshape(num, -1)
    m2 = target.reshape(num, -1)
    intersection = (m1 * m2).sum()
    return (2.0 * intersection + eps) / (m1.sum() + m2.sum() + eps)


class DiceLoss(nn.Module):
    """Dice loss (``1 - dice``) on sigmoid-activated logits.

    Segmentation models commonly use the Dice coefficient as a loss; this is
    the matching custom loss module. ``weight`` and ``size_average`` are
    accepted for signature compatibility and are not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - dice`` between ``sigmoid(inputs)`` and ``targets``.

        Args:
            inputs: raw logits.
            targets: ground-truth mask with the same number of elements.
            smooth: additive smoothing term for numerator and denominator.
        """
        inputs = torch.sigmoid(inputs)
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice


def main():
    """Train a U-Net on Carvana, then demo layer replacement, saving and freezing."""
    import copy

    batch_size = 16
    num_works = 4
    epochs = 10
    lr = 1e-3
    img_size = 256
    weight_decay = 1e-8
    interval = 50
    # Fall back to the CPU instead of crashing on machines without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    base_dir = "./carvana"

    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ])

    train_idxs, val_idxs = train_test_split(
        range(len(os.listdir(os.path.join(base_dir, "train_masks")))),
        test_size=0.3,
    )
    train_data = CarvanaDataset(base_dir, train_idxs, transform=transform)
    val_data = CarvanaDataset(base_dir, val_idxs, transform=transform)
    train_loader = DataLoader(
        train_data, batch_size=batch_size, num_workers=num_works, shuffle=True,
    )
    val_loader = DataLoader(
        val_data, batch_size=batch_size, num_workers=num_works, shuffle=False,
    )

    # Dump one sample to disk as a quick sanity check of the input pipeline.
    image, mask = next(iter(train_loader))
    plt.imsave("tmp_check.jpg", image[0][0].numpy())
    plt.imsave("tmp_mask.jpg", mask[0][0].numpy(), cmap="gray")

    model = UNet(3, 1)
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    # criterion = DiceLoss()  # use the custom loss instead
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Scheduler: multiply the learning rate by 0.8 after every epoch.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)

    def current_lr():
        # Read the live value; the `lr` constant is only the initial rate.
        return optimizer.param_groups[0]["lr"]

    def train(epoch, epochs, interval):
        """Run one training epoch and return the mean per-sample loss."""
        model.train()
        train_loss = 0
        seen = 0
        for i, (data, mask) in enumerate(train_loader):
            data, mask = data.to(device), mask.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, mask)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * data.size(0)
            seen += data.size(0)
            if (i + 1) % interval == 0:
                print(
                    "loader({}/{}): lr:{:.7f} \ttrain_loss: {:.4f}".format(
                        i + 1, len(train_loader), current_lr(), train_loss / seen,
                    ),
                )
        train_loss = train_loss / len(train_loader.dataset)
        # end="" so that val() can append its metrics to the same line.
        print(
            "Epoch({}/{}): lr:{:.7f} \ttrain_loss: {:.4f}".format(
                epoch, epochs, current_lr(), train_loss,
            ),
            end="",
        )
        return train_loss

    def val(epoch):
        """Evaluate on the validation split and return the mean Dice score."""
        model.eval()
        val_loss = 0
        dice_score = 0
        with torch.no_grad():
            for data, mask in val_loader:
                data, mask = data.to(device), mask.to(device)
                output = model(data)
                loss = criterion(output, mask)
                val_loss += loss.item() * data.size(0)
                # Weight each batch's score by its size so the final division
                # by the dataset length yields a per-sample average.
                batch_dice = dice_coeff(torch.sigmoid(output).cpu(), mask.cpu())
                dice_score += batch_dice.item() * data.size(0)
        val_loss = val_loss / len(val_loader.dataset)
        dice_score = dice_score / len(val_loader.dataset)
        print(" \tval_loss: {:.4f} \tdice_score: {:.4f}".format(val_loss, dice_score))
        return dice_score

    best_dice = 0
    for epoch in range(1, epochs + 1):
        train(epoch, epochs, interval)
        dice_score = val(epoch)
        scheduler.step()  # decay the learning rate
        if dice_score > best_dice:
            best_dice = dice_score
            torch.save(model, "UNet_best.pth")
        # Refresh the latest checkpoint every epoch so an interrupted run keeps it.
        torch.save(model, "UNet_last.pth")
    print("best dice score:", best_dice)

    # Replace a layer of the model.
    model1 = copy.deepcopy(model)
    x = torch.rand(1, 3, 224, 224, device=device)
    out = model(x)
    print(out.shape)
    # The new head must live on the same device as the rest of the network.
    model1.outc = OutConv(64, 5).to(device)
    out1 = model1(x)
    print(out1.shape)

    # Save the whole model object.
    torch.save(model, "UNet.pth")
    # Save only the weights; this form also suits multi-GPU setups.
    torch.save(model.state_dict(), "UNet2.pth")

    # Freeze the last layer so it receives no gradient updates (fine-tuning).
    model.outc.conv.weight.requires_grad = False
    model.outc.conv.bias.requires_grad = False
    for layer, param in model.named_parameters():
        print(layer, "\t", param.requires_grad)


if __name__ == "__main__":
    main()

文章转载自:

http://eq9e9V57.jzpxj.cn
http://jWMoed80.jzpxj.cn
http://pQIb4MI3.jzpxj.cn
http://syE8BUzB.jzpxj.cn
http://EPWd3kWX.jzpxj.cn
http://q11KiONV.jzpxj.cn
http://pd0EDXzI.jzpxj.cn
http://mWv3eye0.jzpxj.cn
http://KXKdObyT.jzpxj.cn
http://CFJm5P81.jzpxj.cn
http://ROhGylZK.jzpxj.cn
http://Jltvbwl5.jzpxj.cn
http://c0kO7vmb.jzpxj.cn
http://mf3tSfaf.jzpxj.cn
http://d1SPCIjX.jzpxj.cn
http://iHm32Pq6.jzpxj.cn
http://NSBg5fUd.jzpxj.cn
http://Eh69i13N.jzpxj.cn
http://7kPZGBOC.jzpxj.cn
http://QERfLU6V.jzpxj.cn
http://K1C0sYaC.jzpxj.cn
http://0Vrd6zqn.jzpxj.cn
http://NPMnoou2.jzpxj.cn
http://F1tVpcvV.jzpxj.cn
http://ivLoUDzM.jzpxj.cn
http://cHMy59vv.jzpxj.cn
http://1htVUxiC.jzpxj.cn
http://3EWR5iti.jzpxj.cn
http://VS1FoJxp.jzpxj.cn
http://HpVXjMkS.jzpxj.cn
http://www.dtcms.com/a/386841.html

相关文章:

  • LeetCode:8.无重复字符的最长字串
  • 卷积神经网络搭建实战(一)——torch云端的MNIST手写数字识别(全解一)
  • 实验四 Cache 3种不同的地址映射机制(仿真)
  • 北航计算机保研机试题+解答
  • Python Flask快速入门
  • AirPodsDesktop,一个AirPods 桌面助手
  • Java 调用 C++ 动态库(DLL)完整实践:有图像有实体处理场景
  • 教育行业智慧文档平台:构建安全合规、高效协同的教学研究与资源共享解决方案
  • 网编day7(网络词典)(部分)
  • CodeBuddy AI 深度体验:模型怎么选不踩坑?
  • MQ高级.
  • 46.Mysql基础及案例
  • 贪心算法应用:文件合并问题详解
  • 什么是“孤块”?
  • 神卓N600 公网盒子公网访问群晖NAS绿联飞牛
  • 浅谈背包DP(C++实现,配合lc经典习题讲解)
  • 虚拟化嵌套支持在云服务器容器化Hyper-V环境的配置标准
  • 修改el-checkbox默认颜色
  • ROS接口信息整理
  • 【C++11】lambda匿名函数、包装器、新的类功能
  • 【Linux系统】深入理解线程,互斥及其原理
  • 1. C++ 中的 C
  • 探讨基于国产化架构的非结构化数据管理平台建设路径与实践
  • C++11移动语义
  • 代码随想录第14天| 翻转、对称与深度
  • 算法改进篇 | 改进 YOLOv12 的水面垃圾检测方法
  • 一个我自己研发的支持k-th路径查询的数据结构-owl tree
  • 首款“MODA”游戏《秘境战盟》将在Steam 新品节中开放公开试玩
  • ε-δ语言(Epsilon–Delta 语言)
  • QCA9882 Module with IPQ4019 Mainboard High-Performance Mesh Solution