当前位置: 首页 > news >正文

【机器学习深度学习】交互式线性回归 demo

目录

一、环境准备

二、Demo 功能

三、完整交互 demo 代码

3.1 执行代码

3.2 示例交互演示

3.3 运行结果

3.4 运行线性图


使用 PyTorch 构建交互式线性回归模型:输入数据、拟合直线、图像可视化并实现实时预测,助你深入理解机器学习从数据到模型的全过程。

一、环境准备

需要你本地能跑 Python + matplotlib + PyTorch,无需其他安装。


二、Demo 功能

  • 你输入一些样本点(比如面积和价格)

  • 模型会学出一条最合适的线

  • 自动训练并画图展示

  • 你输入新数据,它来预测输出!


三、完整交互 demo 代码

PyTorch 的基本训练流程:

输入 → 线性模型建模 → 拟合 → 可视化 → 预测

3.1 执行代码

import torch
import torch.nn as nn
import matplotlib.pyplot as plt# 1. 让你输入数据(面积 → 房价)
print("请输入一些训练数据(输入特征和输出标签),格式如:50 100")
print("输入完后输入空行结束")X_list = []
y_list = []while True:line = input("请输入一组 (x y): ")if not line.strip():breaktry:x, y = map(float, line.strip().split())X_list.append([x])y_list.append([y])except:print("格式错误,请输入两个数字")X = torch.tensor(X_list, dtype=torch.float32)
y = torch.tensor(y_list, dtype=torch.float32)# 2. 定义模型
model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 3. 训练模型
epochs = 1000
for epoch in range(epochs):optimizer.zero_grad()outputs = model(X)loss = criterion(outputs, y)loss.backward()optimizer.step()# 4. 显示结果
print("\n训练完成!")
w = model.weight.item()
b = model.bias.item()
print(f"模型学到的函数为: y = {w:.2f} * x + {b:.2f}")# 5. 可视化
with torch.no_grad():predicted = model(X).numpy()plt.scatter(X.numpy(), y.numpy(), label='原始数据')
plt.plot(X.numpy(), predicted, 'r-', label='拟合直线')
plt.title(f"y = {w:.2f}x + {b:.2f}")
plt.legend()
plt.grid(True)
plt.show()# 6. 输入新数据做预测
while True:new_x = input("\n输入新的 x(或输入 'q' 退出预测):")if new_x.lower() == 'q':breaktry:x_value = torch.tensor([[float(new_x)]])y_pred = model(x_value).item()print(f"预测值为:y = {y_pred:.2f}")except:print("请输入有效数字")

3.2 示例交互演示

请输入一组 (x y): 50 100
请输入一组 (x y): 60 120
请输入一组 (x y): 70 140
请输入一组 (x y): 80 160
请输入一组 (x y): 

你会看到:

  • 模型学到函数 y = 2.00 * x + 0.00

  • 会弹出一张图,直线穿过这些点

  • 然后你可以输入 90,它预测 180


3.3 运行结果

请输入一些训练数据(输入特征和输出标签),格式如:50 100
输入完后输入空行结束
请输入一组 (x y): 1 20
请输入一组 (x y): 2 40
请输入一组 (x y): 3 60
请输入一组 (x y): 4 80
请输入一组 (x y): 训练完成!
模型学到的函数为: y = 19.88 * x + 0.36输入新的 x(或输入 'q' 退出预测):80
预测值为:y = 1590.59输入新的 x(或输入 'q' 退出预测):5
预测值为:y = 99.75输入新的 x(或输入 'q' 退出预测):q

3.4 运行线性图

相关文章:

  • day48-硬件学习之GPT定时器、UART及I2C
  • 【开源工具】Windows一键配置防火墙阻止策略(禁止应用联网)| 附完整Python源码
  • 事件循环(Event Loop)机制对比:Node.js vs 浏览器​
  • ethers.js express vue2 定时任务每天凌晨2点监听合约地址数据同步到Mysql整理
  • 【CMake基础入门教程】第六课:构建静态库 / 动态库 与安装规则(install)
  • MySQL至KES迁移最佳实践
  • 用 Spark 优化亿级用户画像计算:Delta Lake 增量更新策略详解
  • vue3 json 转 实体
  • 2.1、STM32 CAN外设简介
  • Vue3 中 Axios 深度整合指南:从基础到高级实践引言总结
  • MR30分布式IO:产线改造省时 70%
  • 22. 括号生成
  • AI编程工具深度对比:腾讯云代码助手CodeBuddy、Cursor与通义灵码
  • ubuntu20.04如何给appImage创建快捷方式
  • EXILIUM×亚矩云手机:重构Web3虚拟生存法则,开启多端跨链元宇宙自由征途
  • 【JeecgBoot AIGC】打造智能AI应用
  • 51c~嵌入式~PLC~三菱~合集1
  • 记dwz(JUI)前端框架使用之--服务端响应提示框
  • 如何在x86_64 Linux上部署Android Cuttlefish模拟器运行环境
  • Spring Cloud Feign 整合 Sentinel 实现服务降级与熔断保护