Pytorch工具箱2
utils.data.Dataset
用于定义数据集,
__getitem__
方法一次只能获取一个样本。
utils.data.DataLoader
语法结构:
data.DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,
)
torchvision.ImageFolder
相关参数:
dataset
:加载的数据集。batch_size
:批大小。shuffle
:是否将数据打乱。sampler
:样本抽样。num_workers
:使用多进程加载的进程数,0代表不使用多进程。collate_fn
:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可。pin_memory
:是否将数据保存在锁页内存(pin memory区),其中的数据转到GPU会快一些。drop_last
:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃。
torchvision.transforms
提供了对PIL Image对象和Tensor对象的常用操作。
对PIL Image的常见操作:
Scale/Resize
:调整尺寸,长宽比保持不变。CenterCrop、RandomCrop、RandomSizedCrop
:裁剪图像。Pad
:填充。ToTensor
:把一个取值范围是[0,255]的PIL.Image转换成Tensor。RandomHorizontalFlip
:图像随机水平翻转。RandomVerticalFlip
:图像随机垂直翻转。ColorJitter
:修改亮度、对比度和饱和度。
可以读取不同目录下的图像数据。
对Tensor的常见操作:
Normalize
:标准化,即减均值,除以标准差。ToPILImage
:将Tensor转为PIL Image。
组合操作:如果要对数据集进行多个操作,可通过
Compose
将这些操作像管道一样拼接起来,类似于nn.Sequential
。可视化工具TensorBoard
使用步骤:
导入tensorboard,实例化SummaryWriter类:
from torch.utils.tensorboard import SummaryWriter # 实例化SummaryWriter,并指明日志存放路径。在当前目录没有logs目录将自动创建。 writer = SummaryWriter(log_dir='logs') # 调用实例 writer.add_xxx() # 关闭writer writer.close()
调用相应的API接口:
接口一般格式为:
add_xxx(tag-name, object, iteration-number)
,其中xxx
指的是各种可视化方法,如:Scalar
:用于可视化单一数值,例如损失值、准确率等随训练过程的变化。Image
:用于可视化图像数据。Figure
:用于可视化图形或复杂的图表。Histogram
:用于可视化数据的分布。Audio
:用于可视化音频数据。Text
:用于可视化文本数据,例如模型生成的文本或训练日志。Graph
:用于可视化计算图结构。ONNX Graph
:用于可视化ONNX模型的计算图结构。ONNX是开放神经网络交换格式。Embedding
:用于可视化高维数据的低维表示。例如,t-SNE或PCA降维后的词向量或特征。PR Curve
:用于可视化精确度-召回率曲线。Video Summaries
:用于可视化视频数据。
启动tensorboard服务:
cd到logs目录所在的同级目录,在命令行输入如下命令,
logdir
等式右边可以是相对路径或绝对路径:tensorboard --logdir=logs --port 6006 # 如果是windows环境,要注意路径解析,如 tensorboard --logdir=r'D:\myboard\test\logs' --port 6006
Web展示:
在浏览器输入:
http://服务器IP或名称:6006
(如果是本机,服务器名称可以使用localhost)