当前位置: 首页 > news >正文

MMRotate ReDet ReFPN 报错 `assert input.type == self.in_type`

在跑实验时,使用 configs/redet/redet_re50_refpn_1x_dota_le90.py,结果报错:

Traceback (most recent call last):File "H:/Workspace/DeepLearning/mmrotate/tools/train.py", line 196, in <module>main()File "H:/Workspace/DeepLearning/mmrotate/tools/train.py", line 183, in maintrain_detector(File "h:\workspace\deeplearning\mmrotate\mmrotate\apis\train.py", line 145, in train_detectorrunner.run(data_loaders, cfg.workflow)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 136, in runepoch_runner(data_loaders[i], **kwargs)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 53, in trainself.run_iter(data_batch, train_mode=True, **kwargs)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 31, in run_iteroutputs = self.model.train_step(data_batch, self.optimizer,File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\parallel\data_parallel.py", line 77, in train_stepreturn self.module.train_step(*inputs[0], **kwargs[0])File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmdet\models\detectors\base.py", line 248, in train_steplosses = self(**data)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_implreturn forward_call(*input, **kwargs)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\fp16_utils.py", line 119, in new_funcreturn old_func(*args, **kwargs)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmdet\models\detectors\base.py", line 172, in forwardreturn self.forward_train(img, img_metas, **kwargs)File "h:\workspace\deeplearning\mmrotate\mmrotate\models\detectors\two_stage.py", line 127, in forward_trainx = self.extract_feat(img)File "h:\workspace\deeplearning\mmrotate\mmrotate\models\detectors\two_stage.py", line 69, in extract_featx = self.neck(x)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_implreturn forward_call(*input, **kwargs)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\fp16_utils.py", line 119, in new_funcreturn old_func(*args, **kwargs)File "h:\workspace\deeplearning\mmrotate\mmrotate\models\necks\re_fpn.py", line 298, in forwardlaterals = [File "h:\workspace\deeplearning\mmrotate\mmrotate\models\necks\re_fpn.py", line 299, in <listcomp>self.lateral_convs[i](inputs[i + self.start_level])File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_implreturn forward_call(*input, **kwargs)File "h:\workspace\deeplearning\mmrotate\mmrotate\models\necks\re_fpn.py", line 148, in forwardx = self.conv(x)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1194, in _call_implreturn forward_call(*input, **kwargs)File "D:\Environments\Anaconda3\envs\openmmlab\lib\site-packages\e2cnn\nn\modules\r2_conv\r2convolution.py", line 326, in forwardassert input.type == self.in_type
AssertionError

按照以下提示修改 mmrotate/models/necks/re_fpn.py 三处地方,其余地方不变。

# 1. 引入 build_enn_divide_feature 函数
from ..utils import (build_enn_divide_feature,build_enn_feature, build_enn_norm_layer, ennConv,ennInterpolate, ennMaxPool, ennReLU
)class ConvModule(enn.EquivariantModule):def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias='auto',conv_cfg=None,norm_cfg=None,activation='relu',inplace=False,order=('conv', 'norm', 'act')):super(ConvModule, self).__init__()assert conv_cfg is None or isinstance(conv_cfg, dict)assert norm_cfg is None or isinstance(norm_cfg, dict)# 2. 用 build_enn_divide_feature 替换 build_enn_featureself.in_type = build_enn_divide_feature(in_channels)self.out_type = build_enn_divide_feature(out_channels)# 后续保持不变...def forward(self, x, activate=True, norm=True):"""Forward function of ConvModule."""# 3. 如果传入的是普通 Tensor,则封装为 GeometricTensorif isinstance(x, torch.Tensor):x = enn.GeometricTensor(x, self.in_type)for layer in self.order:if layer == 'conv':x = self.conv(x)elif layer == 'norm' and norm and self.with_norm:x = self.norm(x)elif layer == 'act' and activate and self.with_activatation:x = self.activate(x)return x
http://www.dtcms.com/a/297207.html

相关文章:

  • Franky — 边缘计算智能语音助手 / Edge‑Computing Smart Voice Assistant
  • 04-netty基础-Reactor三种模型
  • docker compose xtify-music-web
  • 华为OpenStack架构学习9篇 连载—— 02 OpenStack界面管理【附全文阅读】
  • VR 三维重建:重塑建筑工程全生命周期的数字化革命
  • [NLP]多电源域设计的仿真验证方法
  • Redis 5.0.14安装教程
  • Android 10.0 sts CtsSecurityBulletinHostTestCases的相关异常分析
  • 关于自定义域和 GitHub Pages(Windows)
  • OpenCV(04)梯度处理,边缘检测,绘制轮廓,凸包特征检测,轮廓特征查找
  • [python][flask]Flask-Login 使用详解
  • uniapp小程序上传图片并压缩
  • 吊汤:厨房的鲜味密码
  • 若依框架 ---一套快速开发平台
  • STM32-中断配置教程(寄存器版)
  • 【应急响应】进程隐藏技术与检测方式(二)
  • Gin 框架的中间件机制
  • 三种深度学习模型(GRU、CNN-GRU、贝叶斯优化的CNN-GRU/BO-CNN-GRU)对北半球光伏数据进行时间序列预测
  • win11 使用adb 获取安卓系统日志
  • ESP32学习笔记_Peripherals(4)——MCPWM基础使用
  • C++ : list的模拟
  • Kafka——多线程开发消费者实例
  • 使用OpenCV做个图片校正工具
  • 技术演进中的开发沉思-45 DELPHI VCL系列:6种方法
  • 关于新学C++编程Visual Studio 2022开始,使用Cmake工具构建Opencv和SDK在VS里编译项目开发简介笔记
  • RocketMQ常见问题梳理
  • 三、Spark 运行环境部署:全面掌握四种核心模式
  • 【内网穿透】使用FRP实现内网与公网Linux/Ubuntu服务器穿透项目部署多项目穿透方案
  • vue使用xlsx库导出excel
  • 编程语言Java——核心技术篇(三)异常处理详解