Land Use Classification of Remote Sensing Imagery with Deeplabv3+
Table of Contents
- Environment Setup
- Dataset
- Code
- Deeplabv3+
 
 
Environment Setup
You can refer to this post: Environment Setup
Dataset
You can refer to this post: Dataset
Code
- ASPP module
 Multi-scale dilated convolutions capture context over different ranges (a short receptive-field check follows this list)
 Well suited to ground objects of varying sizes in remote sensing imagery
 Includes a global average pooling branch to capture global context
- Encoder-decoder structure
 Encoder: a ResNet backbone extracts multi-level features
 Decoder: fuses high-level semantic information with low-level spatial detail
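As a quick aside on those dilation rates: a 3x3 kernel with dilation rate r covers an effective window of 3 + 2(r - 1) pixels, so the rates used in the ASPP below span progressively larger neighborhoods. A minimal check (plain arithmetic, not part of the model code):

# Effective receptive window of a single 3x3 convolution at the ASPP dilation rates.
for rate in (1, 6, 12, 18):
    effective = 3 + 2 * (rate - 1)      # 3, 13, 25, 37 pixels
    print(f"dilation rate {rate:>2}: effective kernel {effective}x{effective}")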
Deeplabv3+
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import config


class ASPP(nn.Module):
    """ASPP module - multi-scale dilated (atrous) convolutions."""

    def __init__(self, in_channels, out_channels=256):
        super(ASPP, self).__init__()
        # 1x1 convolution
        self.conv_1x1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        # 3x3 dilated convolution, rate=6
        self.conv_3x3_1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=6, dilation=6, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        # 3x3 dilated convolution, rate=12
        self.conv_3x3_2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=12, dilation=12, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        # 3x3 dilated convolution, rate=18
        self.conv_3x3_3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=18, dilation=18, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        # Global average pooling branch
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        # Output projection after concatenating the five branches
        self.conv_out = nn.Sequential(
            nn.Conv2d(out_channels * 5, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        x1 = self.conv_1x1(x)
        x2 = self.conv_3x3_1(x)
        x3 = self.conv_3x3_2(x)
        x4 = self.conv_3x3_3(x)
        # Global average pooling, then upsample back to the input resolution
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x.size()[2:], mode='bilinear', align_corners=True)
        # Concatenate all branches and fuse them
        x = torch.cat([x1, x2, x3, x4, x5], dim=1)
        x = self.conv_out(x)
        return x


class DeepLabV3Plus(nn.Module):
    """Deeplabv3+ model, suited to land use classification of remote sensing imagery."""

    def __init__(self, n_channels=config.NUM_BANDS, n_classes=config.NUM_CLASSES, backbone='resnet50'):
        super(DeepLabV3Plus, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Choose the backbone
        if backbone == 'resnet50':
            self.backbone = models.resnet50(pretrained=True)
            low_level_channels = 256
            high_level_channels = 2048
        elif backbone == 'resnet101':
            self.backbone = models.resnet101(pretrained=True)
            low_level_channels = 256
            high_level_channels = 2048
        else:  # resnet34
            self.backbone = models.resnet34(pretrained=True)
            low_level_channels = 64
            high_level_channels = 512

        # Replace the first convolution so multi-band (non-RGB) input is accepted
        if n_channels != 3:
            self.backbone.conv1 = nn.Conv2d(n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Drop the classification head; only the convolutional stages are used
        self.backbone.avgpool = nn.Identity()
        self.backbone.fc = nn.Identity()

        # ASPP module on top of the deepest backbone features
        self.aspp = ASPP(high_level_channels, 256)

        # Projection for the low-level features
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        # Decoder
        self.decoder_conv = nn.Sequential(
            nn.Conv2d(256 + 48, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1)
        )

        # Classifier
        self.classifier = nn.Conv2d(256, n_classes, 1)

        # Initialize weights of the newly added modules
        self._init_weights()

    def _init_weights(self):
        # Only the modules added on top of the backbone are re-initialized,
        # so the pretrained backbone weights are kept.
        for module in (self.aspp, self.low_level_conv, self.decoder_conv, self.classifier):
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _extract_features(self, x):
        """Run the backbone stage by stage and return (low-level, high-level) features."""
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        low_level_features = self.backbone.layer1(x)    # shallow features, more spatial detail
        x = self.backbone.layer2(low_level_features)
        x = self.backbone.layer3(x)
        high_level_features = self.backbone.layer4(x)   # deep features, more semantics
        return low_level_features, high_level_features

    def forward(self, x):
        # Remember the input spatial size
        input_size = x.size()[2:]

        # Backbone feature extraction
        low_level_features, features = self._extract_features(x)

        # High-level features go through ASPP
        high_level_features = self.aspp(features)

        # Project the low-level features
        low_level_features = self.low_level_conv(low_level_features)

        # Upsample the high-level features to the low-level feature resolution
        high_level_features = F.interpolate(high_level_features, size=low_level_features.size()[2:],
                                            mode='bilinear', align_corners=True)

        # Concatenate high- and low-level features
        x = torch.cat([high_level_features, low_level_features], dim=1)

        # Decoder convolutions
        x = self.decoder_conv(x)

        # Per-pixel classification
        x = self.classifier(x)

        # Upsample to the original input size
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)
        return x


class DeepLabV3PlusWithAuxiliary(DeepLabV3Plus):
    """Deeplabv3+ with an auxiliary loss head, used during training to improve performance."""

    def __init__(self, n_channels=config.NUM_BANDS, n_classes=config.NUM_CLASSES, backbone='resnet50'):
        super(DeepLabV3PlusWithAuxiliary, self).__init__(n_channels, n_classes, backbone)
        # Auxiliary classifier on the ASPP output
        self.aux_classifier = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(256, n_classes, 1)
        )

    def forward(self, x):
        input_size = x.size()[2:]

        # Backbone feature extraction
        low_level_features, features = self._extract_features(x)

        # High-level features go through ASPP
        high_level_features = self.aspp(features)

        # Auxiliary output
        aux_output = self.aux_classifier(high_level_features)
        aux_output = F.interpolate(aux_output, size=input_size, mode='bilinear', align_corners=True)

        # Main branch
        low_level_features = self.low_level_conv(low_level_features)
        high_level_features = F.interpolate(high_level_features, size=low_level_features.size()[2:],
                                            mode='bilinear', align_corners=True)
        x = torch.cat([high_level_features, low_level_features], dim=1)
        x = self.decoder_conv(x)
        x = self.classifier(x)
        x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)

        if self.training:
            return x, aux_output
        else:
            return x


# Convenience factory
def create_deeplabv3plus(model_type='standard', **kwargs):
    """Create a Deeplabv3+ model.

    Args:
        model_type: 'standard' or 'with_auxiliary'
        **kwargs: model parameters
    """
    if model_type == 'with_auxiliary':
        return DeepLabV3PlusWithAuxiliary(**kwargs)
    else:
        return DeepLabV3Plus(**kwargs)
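As a rough smoke test of the model definition above, a dummy forward pass can confirm that the output matches the input resolution. This is only a sketch: the 4-band, 6-class, 256x256 values are placeholders for whatever your config defines, and the first run will try to download ImageNet weights for the backbone.

# Shape check with a dummy patch (band/class counts are placeholder assumptions).
model = DeepLabV3Plus(n_channels=4, n_classes=6, backbone='resnet50')
model.eval()
with torch.no_grad():
    dummy = torch.randn(1, 4, 256, 256)
    out = model(dummy)
print(out.shape)  # expected: torch.Size([1, 6, 256, 256])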
For training, refer to the following:
# Create the model
model = create_deeplabv3plus(
    n_channels=config.NUM_BANDS,
    n_classes=config.NUM_CLASSES,
    backbone='resnet50'
)

# During training (if using the auxiliary-loss variant, pass model_type='with_auxiliary' above)
if model.training:
    output, aux_output = model(x)
    loss = main_loss(output, target) + 0.4 * aux_loss(aux_output, target)
else:
    output = model(x)
At inference time the model automatically returns only the final output (logits), which stays compatible with the existing evaluation / prediction pipeline, so eval.py and predict.py need no changes.
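For reference, a minimal sketch of turning those logits into a per-pixel land-use map at prediction time, assuming image is a [B, C, H, W] tensor coming from your existing data pipeline:

model.eval()
with torch.no_grad():
    logits = model(image)               # [B, n_classes, H, W]
    pred = torch.argmax(logits, dim=1)  # [B, H, W] land-use class index per pixel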
