当前位置: 首页 > news >正文

使用Deeplabv3+进行遥感影像土地利用分类

文章目录

    • 环境配置
    • 数据集
    • 代码
      • Deeplabv3+

环境配置

可参考这个:环境配置

数据集

可参考这个:数据集

代码

  1. ASPP模块
    多尺度空洞卷积捕获不同范围的上下文信息
    适合遥感影像中不同大小的地物目标
    包含全局平均池化捕获全局上下文

  2. 编码器-解码器结构
    编码器: ResNet backbone提取多层次特征
    解码器: 融合高层语义信息和低层空间细节

Deeplabv3+

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import configclass ASPP(nn.Module):"""ASPP模块 - 多尺度空洞卷积"""def __init__(self, in_channels, out_channels=256):super(ASPP, self).__init__()# 1x1卷积self.conv_1x1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))# 3x3空洞卷积,rate=6self.conv_3x3_1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=6, dilation=6, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))# 3x3空洞卷积,rate=12self.conv_3x3_2 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=12, dilation=12, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))# 3x3空洞卷积,rate=18self.conv_3x3_3 = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=18, dilation=18, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))# 全局平均池化self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(in_channels, out_channels, 1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))# 输出卷积self.conv_out = nn.Sequential(nn.Conv2d(out_channels * 5, out_channels, 1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Dropout(0.5))def forward(self, x):x1 = self.conv_1x1(x)x2 = self.conv_3x3_1(x)x3 = self.conv_3x3_2(x)x4 = self.conv_3x3_3(x)# 全局平均池化并上采样到原始尺寸x5 = self.global_avg_pool(x)x5 = F.interpolate(x5, size=x.size()[2:], mode='bilinear', align_corners=True)# 拼接所有特征x = torch.cat([x1, x2, x3, x4, x5], dim=1)x = self.conv_out(x)return xclass DeepLabV3Plus(nn.Module):"""Deeplabv3+模型 - 适合遥感影像土地利用分类"""def __init__(self, n_channels=config.NUM_BANDS, n_classes=config.NUM_CLASSES, backbone='resnet50'):super(DeepLabV3Plus, self).__init__()self.n_channels = n_channelsself.n_classes = n_classes# 选择backboneif backbone == 'resnet50':self.backbone = models.resnet50(pretrained=True)low_level_channels = 256high_level_channels = 2048elif backbone == 'resnet101':self.backbone = models.resnet101(pretrained=True)low_level_channels = 256high_level_channels = 2048else:  # resnet34self.backbone = models.resnet34(pretrained=True)low_level_channels = 64high_level_channels = 512# 修改第一层卷积以适应多波段输入if n_channels != 3:self.backbone.conv1 = nn.Conv2d(n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)# 移除最后的全连接层和平均池化层self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])# ASPP模块self.aspp = ASPP(high_level_channels, 256)# 低层特征处理self.low_level_conv = nn.Sequential(nn.Conv2d(low_level_channels, 48, 1, bias=False),nn.BatchNorm2d(48),nn.ReLU(inplace=True))# 解码器self.decoder_conv = nn.Sequential(nn.Conv2d(256 + 48, 256, 3, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Conv2d(256, 256, 3, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Dropout(0.1))# 分类器self.classifier = nn.Conv2d(256, n_classes, 1)# 初始化权重self._init_weights()def _init_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)def forward(self, x):# 获取输入尺寸input_size = x.size()[2:]# Backbone特征提取features = self.backbone(x)# 低层特征(浅层特征,保留更多空间信息)low_level_features = Noneif hasattr(self.backbone, 'layer1'):# 对于标准ResNetlow_level_features = self.backbone.layer1(x)else:# 对于Sequential包装的backbone,需要手动获取中间特征# 这里简化处理,实际使用时可能需要根据具体backbone结构调整low_level_features = features# 高层特征通过ASPPhigh_level_features = self.aspp(features)# 上采样高层特征high_level_features = F.interpolate(high_level_features, scale_factor=4, mode='bilinear', align_corners=True)# 处理低层特征low_level_features = self.low_level_conv(low_level_features)# 拼接高低层特征x = torch.cat([high_level_features, low_level_features], dim=1)# 解码器卷积x = self.decoder_conv(x)# 分类x = self.classifier(x)# 上采样到原始输入尺寸x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)return xclass DeepLabV3PlusWithAuxiliary(DeepLabV3Plus):"""带辅助损失的Deeplabv3+,用于训练时提升性能"""def __init__(self, n_channels=config.NUM_BANDS, n_classes=config.NUM_CLASSES, backbone='resnet50'):super(DeepLabV3PlusWithAuxiliary, self).__init__(n_channels, n_classes, backbone)# 辅助分类器self.aux_classifier = nn.Sequential(nn.Conv2d(256, 256, 3, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Dropout(0.1),nn.Conv2d(256, n_classes, 1))def forward(self, x):input_size = x.size()[2:]features = self.backbone(x)# 低层特征low_level_features = self.backbone.layer1(x) if hasattr(self.backbone, 'layer1') else features# 高层特征通过ASPPhigh_level_features = self.aspp(features)# 辅助输出aux_output = self.aux_classifier(high_level_features)aux_output = F.interpolate(aux_output, size=input_size, mode='bilinear', align_corners=True)# 主分支high_level_features = F.interpolate(high_level_features, scale_factor=4, mode='bilinear', align_corners=True)low_level_features = self.low_level_conv(low_level_features)x = torch.cat([high_level_features, low_level_features], dim=1)x = self.decoder_conv(x)x = self.classifier(x)x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)if self.training:return x, aux_outputelse:return x# 便捷创建函数
def create_deeplabv3plus(model_type='standard', **kwargs):"""创建Deeplabv3+模型Args:model_type: 'standard' 或 'with_auxiliary'**kwargs: 模型参数"""if model_type == 'with_auxiliary':return DeepLabV3PlusWithAuxiliary(**kwargs)else:return DeepLabV3Plus(**kwargs)

训练时参考这个

# 创建模型
model = create_deeplabv3plus(n_channels=config.NUM_BANDS,n_classes=config.NUM_CLASSES,backbone='resnet50'
)# 训练时(如果使用辅助损失版本)
if model.training:output, aux_output = model(x)loss = main_loss(output, target) + 0.4 * aux_loss(aux_output, target)
else:output = model(x)

推理时模型自动返回最终输出(logits),与原有评估 / 预测流程兼容,无需修改eval.py和predict.py。

http://www.dtcms.com/a/545272.html

相关文章:

  • 深度学习之图像分割:从基础概念到核心技术全解析
  • Linux-unzip解压命令的安装与使用
  • 基于深度学习技术实现染色质开放区域的预测与分析系统源代码+数据库,采用Flask + Vue3 实现前后端分离的植物染色质可及性预测系统
  • 7.OpenStack管理(一)
  • Vscode | 突然无法正常连接远程服务器
  • Kubernetes 实战入门:核心资源操作指南
  • 写作网站保底和全勤的区别wordpress 心情评论插件
  • php做购物网站怎么样网站404页面做晚了
  • 电子商务智能建站查询价格的网站
  • 定制线束源头工厂解决方案品牌推荐-力可欣储能线束:为新能源汽车提供持久动力
  • Spring Boot中的JUC并发解析
  • k8s一站式学习
  • 7.1.4 大数据方法论与实践指南-数据服务接口
  • 网安面试题收集(6)
  • 建设网站需要多少钱济南兴田德润o地址济南网站建设加q479185700
  • 站长之家appwordpress添加版权
  • LeetCode每日一题——Pow(x, n)
  • 6.3.2.2 大数据方法论与实践指南-离线任务质量治理
  • 成都php网站制作程序员网站建设公司新报价
  • SODA v9.5.2 甜盐相机,自然美颜相机
  • 【小白笔记】判断一个正整数是否为质数(Prime Number)-循环语句中的else语句
  • 传奇网站一般怎么做的在国外做h网站怎么样
  • Next.js, Node.js, JavaScript, TypeScript 的关系
  • 做一个综合商城网站多少钱合肥seo关键词排名
  • 网站开发与管理对应的职业及岗位优质的seo网站排名优化软件
  • 新人如何学会安装与切换Rust版本:从工具链管理到生产实践
  • 公司网站制作源码wordpress 最快的版本
  • Rust:与JSON、TOML等格式的集成
  • 应用商城发布项目
  • 6.3.3.1 大数据方法论与实践指南-大数据质量度量指标体系