当前位置: 首页 > news >正文

PyTorch数据处理工具箱(可视化工具)

可视化工具

Tensorboard是Google TensorFlow的可视化工具,它可以记录训练数据、评估数据、
网络结构、图像等,并且可以在web上展示,对于观察神经网络训练的过程非常有帮助。
PyTorch可以采用tensorboard_logger、visdom等可视化工具,但这些方法比较复杂或不够
友好。为解决这一问题,人们推出了可用于PyTorch可视化的新的更强大的工具——
tensorboardX。

tensorboardX简介

tensorboardX功能很强大,支持scalar、image、figure、histogram、audio、text、
graph、onnx_graph、embedding、pr_curve and videosummaries等可视化方式。

安装也比较方便,先安装tensorflow(CPU或GPU版),然后安装tensorboardX,在命
令行运行以下命令即可。

pip install tensorboardX

使用tensorboardX的一般步骤如下所示。
1)导入tensorboardX,实例化SummaryWriter类,指明记录日志路径等信息。

from tensorboardX import SummaryWriter
#实例化SummaryWriter,并指明日志存放路径。在当前目录没有logs目录将自动创建。
writer = SummaryWriter(log_dir='logs')
#调用实例
writer.add_xxx()
#关闭writer
writer.close()

【说明】
①如果是Windows环境,log_dir注意路径解析,如:

writer = SummaryWriter(log_dir=r'D:\myboard\test\logs')

②SummaryWriter的格式为:

SummaryWriter(log_dir=None, comment='', **kwargs)
#其中comment在文件命名加上comment后缀

③如果不写log_dir,系统将在当前目录创建一个runs的目录。
2)调用相应的API接口,接口一般格式为:

add_xxx(tag-name, object, iteration-number)
#即add_xxx(标签,记录的对象,迭代次数)

3)启动tensorboard服务:
cd到logs目录所在的同级目录,在命令行输入如下命令,logdir等式右边可以是相对路
径或绝对路径。

tensorboard --logdir=logs --port 6006
#如果是Windows环境,要注意路径解析,如
#tensorboard --logdir=r'D:\myboard\test\logs' --port 6006

4)web展示。
在浏览器输入:

http://服务器IP或名称:6006 #如果是本机,服务器名称可以使用localhost

便可看到logs目录保存的各种图形,图4-4为示例图。

image
鼠标在图形上移动,还可以看到对应位置具体数据。
有关tensorboardX的更多内容,大家可参考其官
网:https://github.com/lanpa/tensorboardX。

用tensorboardX可视化神经网络

4.4.1节我们介绍了tensorboardX的主要内容,为帮助大家更好地理解,本节我们将介
绍几个实例。实例内容涉及如何使用tensorboardX可视化神经网络模型、可视化损失值、
图像等。

(1)导入需要的模块

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tensorboardX import SummaryWriter

(2)构建神经网络

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
self.bn = nn.BatchNorm2d(20)
def forward(self, x):
x = F.max_pool2d(self.conv1(x), 2)
x = F.relu(x) + F.relu(-x)
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = self.bn(x)
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
x = F.softmax(x, dim=1)
return x

(3)把模型保存为graph

#定义输入
input = torch.rand(32, 1, 28, 28)
#实例化神经网络
model = Net()
#将model保存为graph
with SummaryWriter(log_dir='logs',comment='Net') as w:
w.add_graph(model, (input, ))

完整代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tensorboardX import SummaryWriterclass Net(nn.Module):def __init__(self):super(Net,self).__init__()self.conv1=nn.Conv2d(1,10,kernel_size=5)self.conv2=nn.Conv2d(10,20,kernel_size=5)self.conv2_drop=nn.Dropout2d()self.fc1=nn.Linear(320,50)self.fc2=nn.Linear(50,10)self.bn=nn.BatchNorm2d(20)def forward(self,x):x=F.max_pool2d(self.conv1(x),2)x=F.relu(x)+F.relu(-x)x=F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)),2))x=self.bn(x)x=x.view(-1,320)x=F.relu(self.fc1(x))x=F.dropout(x,training=self.training)x=self.fc2(x)x=F.softmax(x,dim=1)return x#定义输入
input=torch.rand(32,1,28,28)
#实例化神经网络
model=Net()
#将model保存为graph
with SummaryWriter(log_dir='logs',comment='Net') as w:w.add_graph(model,(input,))

打开浏览器,结果如图4-5所示。

tensorboardx可视化计算图

用tensorboardX可视化损失值

可视化损失值,需要使用add_scalar函数,这里利用一层全连接神经网络,训练一元
二次函数的参数。

dtype = torch.FloatTensor
writer = SummaryWriter(log_dir='logs',comment='Linear')
np.random.seed(100)
x_train = np.linspace(-1, 1, 100).reshape(100,1)
y_train = 3*np.power(x_train, 2) +2+ 0.2*np.random.rand(x_train.size).reshape(100,1)
model = nn.Linear(input_size, output_size)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
for epoch in range(num_epoches):
inputs = torch.from_numpy(x_train).type(dtype)
targets = torch.from_numpy(y_train).type(dtype)
output = model(inputs)
loss = criterion(output, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 保存loss的数据与epoch数值
writer.add_scalar('训练损失值', loss, epoch)

用tensorboardX可视化特征图

利用tensorboardX对特征图进行可视化,不同卷积层的特征图的抽取程度是不一样
的。
x从cifair10数据集获取,具体请参考第6章pytorch-06-02.ipynb。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
import os# 永久解决方案:设置环境变量在程序最开始
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.pool1 = nn.MaxPool2d(2)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.pool2 = nn.MaxPool2d(2)self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, 10)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = self.fc2(x)return xdef visualize_feature_maps():# 初始化net = Net()x = torch.randn(2, 1, 28, 28)  # 使用更合理的输入范围x = (x - x.min()) / (x.max() - x.min())  # 归一化到[0,1]# TensorBoard设置writer = SummaryWriter(log_dir='logs/feature_maps_clean')# 可视化输入writer.add_images('input', x, dataformats='NCHW')# 注册hook函数def conv_hook(module, inp, out):out = out.detach()[:4]  # 只取前4个样本for i in range(min(10, out.size(1))):  # 每个层最多显示10个通道writer.add_images(f'{module.__class__.__name__}_channel_{i}',out[:, i:i + 1],dataformats='NCHW')hooks = []for name, layer in net.named_modules():if isinstance(layer, nn.Conv2d):hooks.append(layer.register_forward_hook(conv_hook))# 前向传播with torch.no_grad():net.eval()output = net(x)# 清理hookfor hook in hooks:hook.remove()writer.close()print("可视化完成!请运行 tensorboard --logdir=logs 查看")if __name__ == '__main__':visualize_feature_maps()
http://www.dtcms.com/a/343439.html

相关文章:

  • 嵌入式学习---(网络编程)
  • burpsuite2022.11激活步骤【超详细】
  • [系统架构设计师]通信系统架构设计理论与实践(十七)
  • anaconda+python+pycharm+mysql
  • 项目1总结其三(图片上传功能)
  • 站长导航网站,网址导航网站大全,网址导航网站合集,网址导航网址目录,网址导航网站推荐,欢迎提交收录
  • ICMP 协议分析
  • 从零开发Java坦克大战Ⅱ (下)-- 从单机到联机(完整架构功能实现)
  • PostgreSQL15——管理表空间
  • 基于Matlab的饮料满瓶检测图像处理
  • 宝塔面板深度解析:从快速部署到高效运维的全流程指南
  • 联想电脑使用U盘装机时,开机按F12时无法显示USB设备启动方式
  • 【python】python测试用例模板
  • 智能制造——解读46页大型集团企业MOM系统解决方案【附全文阅读】
  • 同为科技(TOWE)桌面PDU产品系列全方位解读
  • springboot 启动后get请求任意接口地址会跳到登录页
  • Vue.js 中使用 Highcharts 构建响应式图表 - 综合指南
  • unity中实现机械臂自主运动
  • almalinux9.6系统:k8s可选组件安装(2)
  • 部署Qwen2.5-VL-7B-Instruct-GPTQ-Int3
  • 数据结构 -- 链表--双向链表的特点、操作函数
  • EEA架构介绍
  • CH347 USB转JTAG芯片 SVF下载程序
  • pandas扩展:apply自定义函数、分组进阶(五大核心)、透视表
  • C6.0:晶体管放大器的原理与应用(基极偏置篇)
  • 单词记忆-轻松记忆10个实用英语单词(13)
  • 【openGauss】1分钟掌握:openGauss活动会话CPU占用率获取
  • Java获取被nginx代理的emqx客户端真实ip
  • STM32F030/070芯片解密及应用
  • DAY 23|动态规划1