PyTorch 数据处理工具箱
一、数据处理工具箱基本介绍
咱们在做深度学习时,得先把数据处理好才能用,PyTorch 专门给了一套 “工具箱” 来干这个,主要包括utils.data
、torchvision
这些部分,各自有不同的用处,一起配合就能把数据弄得符合模型训练的要求。
二、核心数据处理工具:utils.data
这部分主要是帮咱们管理和批量处理数据,有两个关键工具:
1. Dataset:逐个拿数据
- 它就像一个 “数据收纳盒”,能把咱们的数据集(比如数据和对应的标签)整理好。
- 要想用它,得自己建一个类继承它,然后写三个重要的 “小功能”:
__init__
:初始化,把数据和标签存到这个 “收纳盒” 里,比如 PPT 里就存了一个 5 组 2 维数据的集合,还有对应的 5 个标签。__getitem__
:按索引拿数据,每次只能拿一个样本,还会把 numpy 格式的数据转换成 PyTorch 能认的 Tensor 格式。比如要拿第 2 个数据,调用它就能得到对应的 Tensor 数据和标签。__len__
:告诉咱们这个 “收纳盒” 里一共有多少个样本,比如 PPT 里的数据集就有 5 个样本。
2. DataLoader:批量拿数据
- Dataset 每次只能拿一个数据,训练模型时咱们通常需要一次拿一批数据(叫 “batch”),DataLoader 就专门干这个,能把 Dataset 里的数据批量取出来。
- 它有很多实用的设置(参数),比如:
batch_size
:每次拿多少个样本,比如设成 2,就一次拿 2 个数据。shuffle
:要不要把数据打乱,True 就是打乱,这样训练模型效果更好。num_workers
:用多少个 “小助手”(进程)来加载数据,0 就是不用额外的 “小助手”。drop_last
:如果数据总数不是每次拿的批量数的整数倍,要不要把多出来的那几个数据扔掉,True 就是扔掉。
- 用它的时候,把 Dataset 对象传进去,再设好参数,就能循环拿批量数据了。不过它不是 “迭代器”,得用
iter
命令转一下才能像迭代器那样用。
三、图像数据专用处理:torchvision
如果咱们处理的是图像数据,torchvision
这个工具就特别好用,它包含了处理图像的各种功能,主要有transforms
和ImageFolder
。
1. transforms:给图像 “做美容”
- 它能对图像做各种处理,比如调整大小、裁剪、翻转,还能把图像转换成模型能认的格式,主要分两类操作:
- 对 PIL Image(一种常见的图像格式)的操作:比如
Resize
调整图像尺寸还不改变长宽比例;RandomCrop
随机裁剪图像;RandomHorizontalFlip
让图像随机左右翻转,这些操作能让图像数据更多样,帮模型更好地学习。 - 对 Tensor(PyTorch 常用数据格式)的操作:比如
Normalize
给数据做 “标准化”,让数据更稳定,方便模型训练;ToPILImage
能把 Tensor 格式转回到 PIL Image 格式,方便看图像。
- 对 PIL Image(一种常见的图像格式)的操作:比如
- 如果想一次做多个操作,就用
Compose
把这些操作串起来,像流水线一样,图像依次经过每个操作处理。比如 PPT 里就把 “中心裁剪”“随机裁剪”“转 Tensor”“标准化” 串起来用。
2. ImageFolder:读取不同文件夹的图像
- 平时咱们可能会把不同类的图像放在不同文件夹里,比如 “猫” 的图像放一个文件夹,“狗” 的放另一个,ImageFolder 能直接读取这些文件夹里的图像,还会自动给不同类的图像贴好标签(比如第一个文件夹的图像标签是 0,第二个是 1)。
- 用的时候,只要告诉它图像文件夹的路径,再加上之前说的
transforms
操作,就能把图像数据加载好,之后再用DataLoader
批量处理,就能给模型训练用了。PPT 里还举了例子,加载完数据后,还能显示图像、保存图像。
四、可视化工具:TensorBoard
训练模型的时候,咱们想看看模型结构对不对、损失值有没有下降、特征图长什么样,就可以用 TensorBoard,它能把这些东西直观地展示出来。
1. 用 TensorBoard 的步骤
- 第一步:导入工具,建一个 “记录者”(SummaryWriter),告诉它日志存在哪个文件夹里,文件夹不存在的话会自动建。
- 第二步:调用各种 “记录方法”(add_xxx),把要展示的东西记下来。比如
add_scalar
记损失值、准确率这种单一数字;add_image
记图像;add_graph
记模型结构,每个方法都要设标签(方便区分)、要记录的内容和迭代次数(比如训练到第几次)。 - 第三步:启动服务,在命令行里输入指定命令,找到日志文件夹的位置,还能设端口号(比如 6006)。
- 第四步:在浏览器里输入 “http://localhost:6006”(localhost表示本机,6006 是刚才设的端口号),就能看到展示的内容了。
2. 常见的可视化功能
- 可视化神经网络:用
add_graph
把模型传进去,就能看到模型的结构,比如有多少个卷积层、全连接层,各层之间怎么连接的,PPT 里就展示了一个包含卷积层、 dropout 层、全连接层的模型结构。 - 可视化损失值:训练时用
add_scalar
把每次迭代的损失值记下来,TensorBoard 会画出损失值随迭代次数变化的曲线,能清楚看到损失值是在下降还是没变化,判断模型训练得好不好。PPT 里的曲线就显示损失值随着训练一步步降低了。 - 可视化特征图:特征图是图像经过模型各层处理后得到的结果,用相应的代码把各层的特征图取出来,再用
add_image
记下来,就能在 TensorBoard 里看到特征图的样子,帮咱们理解模型是怎么 “看” 图像的。PPT 里就展示了卷积层处理后的特征图。