当前位置: 首页 > news >正文

Pytorch xpu环境配置 Pytorch使用Intel集成显卡

1、硬件集显要为Intel ARC并安装正确驱动

2、安装Intel oneAPI Base Toolkit (https://www.intel.cn/content/www/cn/zh/developer/tools/oneapi/base-toolkit-download.html)安装后大约20G左右,注意安装路径

3、安装Visual Studio Build Tools (Microsoft C++ 生成工具 - Visual Studio)

安装时所有选项默认就行,安装如下组件就行

4、安装xpu版Pytorch 安装后大约6G左右

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu 
# 使用--target=d:\python\lib修改安装路径

5、测试

每次打开CMD窗口都要执行一次setvars.bat文件(oneAPI安装路径\oneAPI\setvars.bat)然后再执行python文件,注意只能在CMD窗口中执行,不能使用PowerShell

import torch

print(torch.xpu.is_available())

一个简单的模型训练例子 

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# 使用cpu时删除所有.to(xpu)和.to(cpu)

plt.rcParams['font.family'] = 'SimHei'
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
xpu = torch.device('xpu') # 使用CPU时可以删除这句
 
# 1. 定义一个简单的神经网络模型
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(2, 2)  # 输入层到隐藏层
        self.fc2 = nn.Linear(2, 1)  # 隐藏层到输出层
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))  # ReLU 激活函数
        x = self.fc2(x)
        return x
 
# 2. 创建模型实例
model = SimpleNN()
model.to(xpu) # 使用CPU时可以删除这句
 
# 3. 定义损失函数和优化器
criterion = nn.MSELoss()  # 均方误差损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam 优化器
 
# 4. 假设我们有训练数据 X 和 Y
X = torch.randn(10, 2, requires_grad=True).to(xpu)  # 10 个样本,2 个特征
Y = torch.randn(10, 1).to(xpu)  # 10 个目标值
print(X,Y)
# 5. 训练循环
losses = []
for epoch in range(500):  # 训练 500 轮
    model.train()  # 设置模型为训练模式
    optimizer.zero_grad()  # 清空之前的梯度
    output = model(X)  # 前向传播
    loss = criterion(output, Y) # 计算损失
    losses.append(loss.item())
    loss.backward()  # 反向传播
    optimizer.step()  # 更新参数
 
 
# 可视化预测结果与实际目标值对比
y_pred_final = model(X).detach().to("cpu").numpy()  # 最终预测值
y_actual = Y.to("cpu").numpy()  # 实际值
 
plt.figure(figsize=(8, 5))
plt.plot(range(1, 11), y_actual, 'o-', label='实际值', color='blue')
plt.plot(range(1, 11), y_pred_final, 'x--', label='预测值', color='red')
plt.xlabel('Sample Index')
plt.ylabel('Value')
plt.title('Actual vs Predicted Values')
plt.legend()
plt.grid()
plt.show()

相关文章:

  • 单粒子翻转对FPGA的影响及解决方法
  • windows下安装pipx
  • 【JAVA架构师成长之路】【JVM实战】第2集:生产环境内存飙高排查实战
  • 视频输入设备-V4L2的开发流程简述
  • 交叉编译openssl及curl
  • 【Mac】MacOS系统下常用的开发环境配置2025版
  • 【论文阅读】多模态——LSeg
  • 使用 Elasticsearch 进行集成测试初始化​​数据时的注意事项
  • 9. Flink的性能优化
  • 训练 FLUX LoRA模型安装与部署
  • 高频 SQL 50 题(基础版)| 高级字符串函数 / 正则表达式 / 子句:1667. 修复表中的名字、1527. 患某种疾病的患者、196. 删除重复的电子邮箱、176. 第二高的薪水、...
  • 【UI自动化实现思路第二章】OCR 图片文字识别方法
  • NO2.C++语言基础|C++和Java|常量|重载重写重定义|构造函数|强制转换|指针和引用|野指针和悬空指针|const修饰指针|函数指针(C++)
  • 算法提升第一章:基础算法总结
  • 【JAVA架构师成长之路】【JVM实战】第1集:生产环境CPU飙高排查实战
  • DeepSeek本地调用,集成到自己的平台中,做二次集成
  • 2025-03-06 学习记录--C/C++-C 库函数 - strcat()、strncpy()
  • 【每日学点HarmonyOS Next知识】Web上传文件、监听上下左右区域连续点击、折叠悬停、字符串相关、播放沙盒视频
  • 微服务架构下的 Node.js
  • [项目]基于FreeRTOS的STM32四轴飞行器: 四.LED控制
  • 网站做境外第三方支付/外贸平台有哪些?
  • 有没有做外贸免费网站/百度竞价是什么意思?
  • 做国际黄金的网站/青岛app开发公司
  • 企业建立网站账户如何做/seo技术分享免费咨询
  • 微网站如何做微信支付宝支付宝/学seo网络推广
  • 深圳专业企业网站制作/百度广告商