当前位置: 首页 > news >正文

YOLOV8添加ASPP改进

1.v8本身的模块存放在nn文件夹下。

2.在nn文件夹下新建一个ASPP.py文件,将新添加的模块写入其中。

3.把下面这段代码复制进去


import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

# without BN version
class ASPP(nn.Module):
    def __init__(self, in_channel=512, out_channel=256):
        super(ASPP, self).__init__()
        self.mean = nn.AdaptiveAvgPool2d((1, 1))  # (1,1)means ouput_dim
        self.conv = nn.Conv2d(in_channel,out_channel, 1, 1)
        self.atrous_block1 = nn.Conv2d(in_channel, out_channel, 1, 1)
        self.atrous_block6 = nn.Conv2d(in_channel, out_channel, 3, 1, padding=6, dilation=6)
        self.atrous_block12 = nn.Conv2d(in_channel, out_channel, 3, 1, padding=12, dilation=12)
        self.atrous_block18 = nn.Conv2d(in_channel, out_channel, 3, 1, padding=18, dilation=18)
        self.conv_1x1_output = nn.Conv2d(out_channel * 5, out_channel, 1, 1)

    def forward(self, x):
        size = x.shape[2:]

        image_features = self.mean(x)
        image_features = self.conv(image_features)
        image_features = F.upsample(image_features, size=size, mode='bilinear')

        atrous_block1 = self.atrous_block1(x)
        atrous_block6 = self.atrous_block6(x)
        atrous_block12 = self.atrous_block12(x)
        atrous_block18 = self.atrous_block18(x)

        net = self.conv_1x1_output(torch.cat([image_features, atrous_block1, atrous_block6,
                                              atrous_block12, atrous_block18], dim=1))
        return net

import torch
import torch.nn.functional as F
import torch.nn as nn

class ASPP(nn.Module):
    def __init__(self, in_channel=512, out_channel=256):
        super(ASPP, self).__init__()
        self.mean = nn.AdaptiveAvgPool2d((1, 1))  # (1,1)means ouput_dim
        self.conv = nn.Conv2d(in_channel, out_channel, 1, 1)
        self.atrous_block1 = nn.Conv2d(in_channel, out_channel, 1, 1)
        self.atrous_block6 = nn.Conv2d(in_channel, out_channel, 3, 1, padding=6, dilation=6)
        self.atrous_block12 = nn.Conv2d(in_channel, out_channel, 3, 1, padding=12, dilation=12)
        self.atrous_block18 = nn.Conv2d(in_channel, out_channel, 3, 1, padding=18, dilation=18)
        self.conv_1x1_output = nn.Conv2d(out_channel * 5, out_channel, 1, 1)

    def forward(self, x):
        size = x.shape[2:]

        image_features = self.mean(x)
        image_features = self.conv(image_features)
        image_features = F.upsample(image_features, size=size, mode='bilinear')

        atrous_block1 = self.atrous_block1(x)
        atrous_block6 = self.atrous_block6(x)
        atrous_block12 = self.atrous_block12(x)
        atrous_block18 = self.atrous_block18(x)

        net = self.conv_1x1_output(torch.cat([image_features, atrous_block1, atrous_block6,
                                              atrous_block12, atrous_block18], dim=1))
        return net


if __name__ == '__main__':
    x = torch.randn(1, 256, 16, 16)
    model = ASPP(256, 256)
    print(model(x).shape)

SPPFCSPC模块使用以下代码 

import torch
import torch.nn.functional as F
import torch.nn as nn

####### SPPFCSPC #####
class SPPFCSPC(nn.Module):

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=5):
        super(SPPFCSPC, self).__init__()
        c_ = int(2 * c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 3, 1)
        self.cv4 = Conv(c_, c_, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
        self.cv5 = Conv(4 * c_, c_, 1, 1)
        self.cv6 = Conv(c_, c_, 3, 1)
        self.cv7 = Conv(2 * c_, c2, 1, 1)

    def forward(self, x):
        x1 = self.cv4(self.cv3(self.cv1(x)))
        x2 = self.m(x1)
        x3 = self.m(x2)
        y1 = self.cv6(self.cv5(torch.cat((x1, x2, x3, self.m(x3)), 1)))
        y2 = self.cv2(x)
        return self.cv7(torch.cat((y1, y2), dim=1))
    ####### end of SPPFCSPC #####

if __name__ == '__main__':
    x = torch.randn(1, 256, 16, 16)
    model = SPPFCSPC(256, 256)
    print(model(x).shape)

4.引入配置好的模块环境。

5.复制一份yolov8.yaml文件,修改模型参数

6. yolov8-ASPP.yaml文件

6.修改配置代码

相关文章:

  • Pyhon第五章01:函数的定义和练习
  • Qt 控件概述 QWdiget 1.1
  • 运维面试题(四)
  • C++|范围for
  • OpenCV基础知识
  • 分类操作-06.根据id删除分类
  • JS基础部分
  • 奇安信二面
  • 北京大学第六弹:《DeepSeek应用场景中需要关注的十个安全问题和防范措施》
  • 【论文阅读】Adversarial Patch Attacks on Monocular Depth Estimation Networks
  • 硬件地址反序?用位操作为LED灯序“纠偏”。反转二进制数即可解决
  • TCP/IP协议中三次握手(Three-way Handshake)与四次挥手(Four-way Wave)
  • 2025年跨网文件交换系统推荐:安全的内外网文件传输系统Top10
  • 01-1 音视频知识学习(音频)
  • 【Java代码审计 | 第十四篇】MVC模型、项目结构、依赖管理及配置文件概念详解
  • 九、Prometheus 监控windows(外部)主机
  • How To Change Windows VPS Password
  • 【k8s001】K8s架构浅析
  • 网页制作16-Javascipt时间特效の设置D-DAY倒计时
  • 基于KL-ISODATA算法的电力负荷数据场景聚类matlab仿真
  • 哈马斯:愿与以色列达成为期5年的停火协议
  • 甘肃公布校园食品安全专项整治案例,有食堂涉腐败变质食物
  • 对谈|“对工作说不”是不接地气吗?
  • 美国证实加拿大及墨西哥汽车零部件免关税
  • 扬州市中医院“药膳面包”走红,内含党参、黄芪等中药材
  • 大学2025丨对话深大人工智能学院负责人李坚强:产学研生态比“造天才”更重要