当前位置: 首页 > news >正文

神经网络入门—自定义神经网络续集

修改网络

神经网络入门—自定义网络-CSDN博客

修改数据集,y=x^2

# 生成一些示例数据
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]], dtype=torch.float32)
y_train = torch.tensor([[1.0], [4.0], [9.0], [16.0]], dtype=torch.float32)

将预测代码改为,可以接收用户输入并输出

# 加载模型
loaded_model = Net()
loaded_model.load_state_dict(torch.load('model.pth'))
loaded_model.eval()  # 将模型设置为评估模式
while True:
    # 输入新数据进行预测
    num=float(input())
    new_input = torch.tensor([[num]], dtype=torch.float32)
    with torch.no_grad():
        prediction = loaded_model(new_input)
        print(f"输入 {new_input.item()} 的预测结果: {prediction.item()}")

结果

分析

训练数据x为[1.0,2.0,3.0,4.0]

x为3.0和3.5时,测试数据与训练数据较为接近,模型能较为准确预测结果

x为5.0和10.0时,测试数据与训练数据有一定差别,模型预测结果比较不准确

x为-1时,模型预测为负数,实际应为正数,因为我们的训练集没有负数,所以模型没有学到这点

重新设计网络

增加-100-100数据集

# 生成 -100 到 100 范围内的 x
x_train = torch.arange(-100, 101, dtype=torch.float32).unsqueeze(1)
# 计算对应的 y,假设 y 是 x 的平方
y_train = x_train ** 2

Loss收敛慢,网络不能拟合实际函数

即时增加到3000次迭代仍然不能解决问题/(ㄒoㄒ)/~~

问题:

  1. 模型结构过于简单:当前模型仅包含两个全连接层,对于拟合 \(y = x^2\) 这样的非线性函数,可能表达能力不够。可以增加网络的深度和宽度,例如添加更多的隐藏层。
  2. 学习率不合适:学习率太大可能会使训练过程不稳定,太小则会导致收敛速度过慢。可以尝试使用自适应学习率的优化器,如 Adam。
  3. 训练轮数不足:可以适当增加训练轮数,让模型有更多的机会学习数据的特征。

增加网络层数

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # 增加网络的宽度和深度
        self.fc1 = nn.Linear(1, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 20)
        self.fc4 = nn.Linear(20, 20)
        self.fc5 = nn.Linear(20, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        x = F.relu(x)
        x = self.fc4(x)
        x = F.relu(x)
        x = self.fc5(x)
        return x

增加神经元个数

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # 增加网络的宽度和深度
        self.fc1 = nn.Linear(1, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3 = nn.Linear(200, 200)
        self.fc4 = nn.Linear(200, 200)
        self.fc5 = nn.Linear(200, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        x = F.relu(x)
        x = self.fc4(x)
        x = F.relu(x)
        x = self.fc5(x)
        return x

Loss波动,疑似出现过拟合

相关文章:

  • 2. 单词个数统计
  • WPS JS宏编程教程(从基础到进阶)-- 第六部分:JS集合与映射在 WPS 的应用
  • 关于使用@Slf4j后引入log,idea标红解决办法
  • Linux | I.MX6ULL外设功能验证(11)
  • FreeRTOS项目工程完善指南:STM32F103C8T6系列
  • 【结合vue源码,分析vue2及vue3的数据绑定实现原理】
  • 【力扣hot100题】(083)零钱兑换
  • Redis 持久化机制详解:RDB/AOF 过程、优缺点及配置。Redis持久化中的Fork与Copy-on-Write技术解析。
  • android studio 2022打开了v1 签名但是生成的apk没有v1签名问题
  • C# 组件的使用方法
  • Python proteinflow 库介绍
  • Java中List方法的使用详解
  • ​​大数据量统计优化方案(日/月/年统计场景)​
  • WORD 中批量将植物拉丁名替换为斜体
  • 淘酒屋(香港)控股助力汾阳白酒国际化:开启中国酒业新征程
  • wsl-docker环境下启动ES报错vm.max_map_count [65530] is too low
  • Easy-Trans 极简数据翻译框架深度实战指南
  • 数据中台、BI业务访谈(二):组织架构梳理的坑
  • 【正点原子】一键锁定IP:STM32MP135 开机就上网!
  • C++ 调试器类 Debugger 的设计与实现
  • 网站建设公司 提成/技能培训班有哪些
  • 泉州住房建设局网站/关键词推广效果分析
  • 数据库网站建设教程/seo公司 引擎
  • 做网站推广用自己维护吗/app推广引流
  • 专业网站建设公司排名/舆情监控系统
  • 对加强政务门户网站建设的意见/html做一个简单的网页