pytorch学习日记
一、pytorch初步认识
1.1 两个实用函数
dir():打开使被看见
help():说明书
1.2 pychrm和jupyter使用对比
1.3 dataset和dataloder
dataset:数据集,定义数据的组织方式和单个样本的获取逻辑。
单样本级别:每次只处理一个样本
数据定义:说明数据在哪里、如何读取
被动工作:等待被调用
class MyDataset(Dataset):def __init__(self):# 1. 初始化:加载数据路径、标签等元信息passdef __getitem__(self, index):# 2. 根据索引返回单个样本 (数据, 标签)return data, labeldef __len__(self):# 3. 返回数据集总大小return total_size
DataLoader:数据加载器,管理数据集的批量加载、 shuffling、并行处理等。
批量级别:处理多个样本组成的批次
流程管理:控制数据加载的流程和策略
主动工作:驱动训练循环
dataloader = DataLoader(dataset, batch_size=32, # 批量大小shuffle=True, # 是否打乱num_workers=4) # 并行进程数for batch_data, batch_labels in dataloader:# 自动获得批量数据,用于训练outputs = model(batch_data)
完整代码示例
import torch
from torch.utils.data import Dataset, DataLoader# 1. 定义 Dataset - 数据如何组织
class MyDataset(Dataset):def __init__(self, data, labels):self.data = dataself.labels = labelsdef __getitem__(self, index):return self.data[index], self.labels[index]def __len__(self):return len(self.data)# 2. 创建数据集实例
dataset = MyDataset(torch.randn(1000, 10), torch.randint(0, 2, (1000,)))# 3. 创建 DataLoader - 如何加载数据
dataloader = DataLoader(dataset, batch_size=32,shuffle=True,num_workers=2)# 4. 训练循环
for epoch in range(10):for batch_data, batch_labels in dataloader: # 自动批量加载# 训练代码...pass
1.4 Dataset类实战
修改文件名的代码
控制台测试
分析:img,lable = bees_dataset[1];
Dataset 是 PyTorch 中的一个抽象基类
MyData 类实际上继承了 Dataset 类,但构造方法 init 有三个参数:
self - 实例自身(Python自动传递)
root_dir - 根目录路径
label_dir - 标签目录名称
继承 Dataset 后,必须实现这两个核心方法:
- getitem 方法
def __getitem__(self, index):# 必须返回 (数据, 标签)return data, label
- len 方法
def __len__(self):# 必须返回数据集大小return length
二、TensorBoard的使用
TensorBoard 是 TensorFlow 提供的一个强大的可视化工具包,主要用于机器学习和深度学习实验的可视化分析。
2.1 SummaryWriter
SummaryWriter是 TensorBoard 的核心类,负责创建日志文件并将各种数据写入到 TensorBoard 可读的格式中。
创建 SummaryWriter
from torch.utils.tensorboard import SummaryWriter# 基本创建
writer = SummaryWriter()# 指定日志目录的创建
writer = SummaryWriter('runs/experiment_1')# 指定注释(会在目录名后追加)
writer = SummaryWriter(comment='_lr_0.01_batch_32')
2.2 SummaryWriter的两个关键函数
记录标量数据writer.add_scalar()
# 记录单个标量
writer.add_scalar('Loss/train', loss_value, epoch)# 记录多个标量(同一图表)
writer.add_scalars('Loss', {'train': train_loss,'val': val_loss
}, epoch)
pycharm中按着ctrl键,鼠标点到函数会变蓝,点击有文档解释。这里tag是图表的标题,scalar_value是纵轴,globalstep是横轴。
测试代码
tensorboard --logdir=logs是tensorboard启动命令
下面的命令是避免默认端口被占用而指定端口号
更改坐标轴后,如没有额外操作会在更新该日志的同时保留之前的日志,导致曲线重合
可手动删除旧日志
记录图像writer.add_image()
tag是地址,img_tensor是图像数据,对格式有限制
# 支持的格式:
img_tensor = torch.randn(3, 224, 224) # PyTorch Tensor
img_tensor = np.random.rand(3, 224, 224) # NumPy 数组
img_tensor = "path/to/image.jpg" # 图像路径(字符串)
global_step记录当前训练步骤,用于在 TensorBoard 中滑动查看不同步骤的图像。
dataformats = ‘CHW’ # 默认:通道-高度-宽度
dataformats = ‘HWC’ # 高度-宽度-通道
不同格式的运行示例(非默认的shape参数要加一个dataformats=‘HWC’)
# 方式1:CHW 格式 (默认)
img_chw = torch.randn(3, 224, 224) # [通道, 高, 宽]
writer.add_image('CHW格式图像', img_chw, global_step=0)# 方式2:HWC 格式
img_hwc = torch.randn(224, 224, 3) # [高, 宽, 通道]
writer.add_image('HWC格式图像', img_hwc, global_step=0, dataformats='HWC')
测试代码