当前位置: 首页 > news >正文

【DL学习笔记】yaml、json、随机种子、浮点精度、amp

文章目录

  • argparse模块——命令行参数解析
  • yaml文件的语法和读写
    • 基本语法
    • 数据类型
      • 对象
      • 数组
      • 常量
    • yaml 文件读取
    • yaml 文件写入
  • json文件的语法和读写
    • 读取 JSON 文件
    • 写入 JSON 文件
    • json 读取/存储 list
    • json 读取/存储 string
  • 随机种子
  • 浮点精度Float32,Float16,BFloat16
    • 浮点数的表示方法
    • 浮点数的值计算公式
    • 具体三种浮点数举例
      • 1、float32
      • 2、float16
      • 3、bfloat16
    • 指数偏置 Bias
    • 精度 与 数值范围
      • 1、数部分(Exponent)控制数据范围
      • 2、尾数部分(Mantissa)控制精度
      • 数据类型对比表
    • 浮点数类型应用场景与硬件支持
      • 应用场景
      • 硬件支持
  • 混合精度训练 torch.cuda.autocast()
    • 混合精度原理
    • autocast 工作流程
    • 代码示例
    • 数值下溢 / 溢出相关
    • torch.cuda.amp.autocast() 精度相关

argparse模块——命令行参数解析

经常在训练或者推理脚本中,用argparse模块进行命令行参数解析。比如YOLO的胡环境中的Anaconda\envs\yolo_env\Lib\argparse.py

如果不用命令行传入参数,需要在训练时将超参数写入到train.py中,需要修改时,必须在文件中修改代码,而且再运行时传参需要传很多参数。用argparse传参可以仅仅只传一个parse对象。

argparse 模块是 Python 标准库(安装好python就可以import argprase)中提供的一个 命令行解析模块 ,它可以让使用者以类似 Unix/Linux 命令参数的方式输入参数(在终端以命令行的方式指定参数),argparse 会自动将命令行指定的参数解析为 Python 变量,从而让使用者更加快捷的处理参数。

使用的三个步骤:

  • 创建 ArgumentParser 对象
  • 添加参数
  • 解析命令行参数
  • 通过args.input属性调用参数
import argparse# 创建 ArgumentParser 对象
parser = argparse.ArgumentParser(description="description")# 添加参数
parser.add_argument("-i", "--input", help="Input file")# 解析命令行参数
args = parser.parse_args()# 使用参数
print("Input file:", args.input)

yaml文件的语法和读写

YAML :YAML Ain’t Markup Language",即 YAML是一种非标记语言。

基本语法

  • 大小写敏感
  • 使用缩进表示层级关系
  • 缩进不允许使用 tab,只允许空格
  • 缩进的空格数不重要,只要相同层级的元素左对齐即可
  • ‘#’ 表示注释

数据类型

YAML 支持以下几种数据类型:

  • 对象:键值对的集合,又称为映射(mapping)/ 哈希(hashes) / 字典(dictionary)
  • 数组:一组按次序排列的值,又称为序列(sequence) / 列表(list)
  • 纯量(scalars):单个的、不可再分的值

对象

  • 键值对 使用冒号结构表示: key: value 。 注意,冒号后面一定要加一个空格,否则 YAML 解析器会报错
  • 键值对 可以使用内联形式(flow style)的写法
key: {child-key: value, child-key2: value2}
  • 也可以使用缩进形式(block style)的写法
key: child-key: value     # 冒号后面要加一个空格child-key2: value2    # 冒号后面要加一个空格

数组

- 开头的行在 YAML 中表示列表元素,如下表示一个列表,等价于:["A", "B", "C"]

- A
- B
- C

当列表作为某个键的值时,有两种写法:

  • 内联形式(flow style)的写法 :
    key: [value1, value2, value3]
  • 缩进形式(block style):
key:- value1- value2- value3

上面两种写法在解析后效果完全一致,都等价于:

{"key": ["value1", "value2", "value3"]}

举个例子

# 缩进形式
companies:-id: 1name: company1price: 200W-id: 2name: company2price: 500W    # 内联形式
companies: [{id: 1,name: company1,price: 200W},{id: 2,name: company2,price: 500W}]
  • companies 是一个 列表(数组)
  • 每个列表元素是一个 字典(对象)
  • 每个对象包含三个字段:id、name、price

解析后等价于 Python 中的结构:

{"companies": [{"id": 1, "name": "company1", "price": "200W"},{"id": 2, "name": "company2", "price": "500W"}]
}

常量

纯量是最基本的,不可再分的值,包括:

  • 字符串
  • 布尔值
  • 整数
  • 浮点数
  • Null
  • 时间
  • 日期
boolean: - TRUE  # true,True都可以- FALSE  # false,False都可以
float:- 3.14- 6.8523015e+5  # 可以使用科学计数法
int:- 123- 0b1010_0111_0100_1010_1110    # 二进制表示
null:nodeName: 'node'parent: ~  # 使用~表示null
string:- 'Hello world'  # 可以使用双引号或者单引号包裹特殊字符- myname   # 不含特殊字符时可省略引号- newlinenewline2    # 字符串可以拆成多行,每一行会被转化成一个空格
date:- 2018-02-17    # 日期必须使用ISO 8601格式,即yyyy-MM-dd
datetime: -  2018-02-17T15:02:31+08:00    # 时间使用ISO 8601格式,时间和日期之间使用T连接,最后使用+代表时区

yaml 文件读取

首先安装 yaml 包 : pip install pyyaml

假设有一个 example.yaml 文件,内容如下 :

name:- John- Tom
age:- 30- 25

我们可以通过以下代码,读取 example.yaml 文件中的内容

import yaml# 读取 YAML 文件
with open("example.yaml", "r") as file:data = yaml.safe_load(file)# 打印读取的数据
print(data)

在这里插入图片描述

yaml 文件写入

import yamldata = {'name':['John', 'Tom'], 'age':[30, 25]}with open('./example.yaml ', "w") as f:yaml.safe_dump({k: v for k, v in data.items()}, f)

json文件的语法和读写

简介

  • JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式
  • 通常用于保存配置、结构化数据,便于网络通信或本地存储
  • JSON 文件以 .json 为后缀
  • json 是 Python 的内置模块,随 Python 一起安装,不需要 pip install json

读取 JSON 文件

假设,我们有如下一个 json 文件,文件名为 example.json,内容如下 :

{"name": "Alice","age": 25,"skills": ["Python", "Machine Learning", "Data Analysis"],"is_student": false
} 
  • 我们可以使用 load 方法,将文件中的内容读取出来
  • 因为这个示例json文件中的数据格式为 字典,所以,读取出来的内容也是 字典

read_json.py:

import json# 读取 json 文件
with open('example.json', 'r', encoding='utf-8') as f:data = json.load(f)print(data)

写入 JSON 文件

  • 使用 dump 方法,将内容写入到 json 文件中
  • dump 方法中的参数 indent :用于指定缩进空格数量,美化格式,让生成的 JSON 文件更易读
import json# 要写入的数据
data = {"name": "Bob","age": 30,"skills": ["C++", "Deep Learning"],"is_student": True
}# 写入 json 文件
with open('output.json', 'w', encoding='utf-8') as f:json.dump(data, f, indent=4)  # indent 用于指定缩进空格数量,美化输出,让生成的 JSON 文件更易读

Json 文件除了配合 字典 使用,还可以与其他的数据类型配合使用,比如 :列表、字符串,等等

json 读取/存储 list

import jsondata = ["apple", "banana", "cherry"]# 写入 json 文件
with open('fruits.json', 'w', encoding='utf-8') as f:json.dump(data, f, indent=2)# 读取 json 文件
with open('fruits.json', 'r', encoding='utf-8') as f:read_data = json.load(f)print(read_data)

json 读取/存储 string

dump 方法中的参数 ensure_ascii 用于控制是否将非 ASCII 字符(如中文,或者标点符号)转义成 \uXXXX。

  • True :默认值,表示会转义。
  • False:保留字符原样
    比如下面例子,如果指定 ensure_ascii=True,则存储到 json 文件中的内容为 :“Hello\uff0cWorld!”
import jsondata = "Hello,World!"with open('hello_world.json', 'w', encoding='utf-8') as f:json.dump(data, f, ensure_ascii=False)with open('hello_world.json', 'r', encoding='utf-8') as f:read_data = json.load(f)print(read_data)

随机种子

在算法实验中,我们经常需要设置随机数,我们不希望这些随机数会造成实验结果不能复现。

但其实这些随机数,只是 “伪随机数”,计算机中没有真正的随机。只要我们通过设置 相同的 随机种子,生成的随机数/随机序列 会是相同的,这有助于我们调试和复现实验结果。

我们常用的库,都有提供给我们设置随机种子的方法,比如:

import random
random.seed(121)import numpy
numpy.random.seed(121)import torch
torch.manual_seed(121)
torch.cuda.manual_seed(121)

比如说,对于 pytorch 的 torch.manual_seed() 方法,我们设置 随机种子为 121

你可以尝试 多次运行以下代码,你会发现,每一次生成的随机数,都是一样的

不同的版本或者硬件平台,生成的随机数可能不同

import torch# 设置随机种子
torch.manual_seed(121)# 生成一个随机张量
a = torch.randn(2, 3)
print(a)# tensor([[1.4521, 0.0504, 0.4962],
#         [0.3959, 0.2918, 1.0004]])

浮点精度Float32,Float16,BFloat16

浮点数的表示方法

  • Float32:单精度浮点数格式

    • 使用 32位(4字节)来表示一个浮点数
    • 遵循 IEEE 754标准,32位 包括 1 位符号位、8 位指数部分和 23 位尾数部分
  • Float16:半精度浮点数格式

    • 使用 16位(2字节)来表示一个浮点数
    • 遵循 IEEE 754标准,16位 包括 1 位符号位、5 位指数部分和 10 位尾数部分
  • BFloat16:Brain Floating Point 16-bit

    • 使用 16位(2字节)来表示一个浮点数
    • 不遵循 IEEE 754标准,16位 包括 1 位符号位、8 位指数部分 和 7 位尾数部分

数据类型位数及结构对比表:

数据类型一共位数符号位数(Sign)指数位数(Exponent)尾数位数(Mantissa)
Float32321823
Float16161510
BFloat1616187

各类型结构示意:符号位 | 指数部分 | 尾数部分

float32 结构为 :
S | EEEEEEEE | MMMMMMMMMMMMMMMMMMMMMMMMfloat16 结构为 :
S | EEEEE | MMMMMMMMMMbfloat16 结构为 :
S | EEEEEEEE | MMMMMMMMMMMM

浮点数的值计算公式

浮点数的值通过以下公式计算:
value=(−1)S×(1+Mantissa)×2(Exponent−Bias)value = (-1)^S \times (1 + \text{Mantissa}) \times 2^{(\text{Exponent} - \text{Bias})} value=(1)S×(1+Mantissa)×2(ExponentBias)

参数说明

  • S:符号位,表示数字的正负:0 为正,1 为负
  • Mantissa:尾数部分
  • Exponent:指数部分,表示 2 的幂次
  • Bias:指数的偏置
    • float32 的 Bias 为 127
    • float16 的 Bias 为 15
    • bfloat16 的 Bias 为 127

具体三种浮点数举例

1、float32

float32:32位浮点数,通常称为单精度浮点数

举例: 我们用 float32 类型表示 -4.5

  1. -4.5 为负数,所以 符号位 S = 1
  2. 4 的二进制表示是 100,0.5 的二进制表示是 0.1 →→→ 4.5 的二进制表示是 100.1
  3. 根据 IEEE 754 标准,需要将二进制表示 100.1 规范化为 1.xxxx 的形式 →→→ 100.1 = 1.001 × 2²
    • 尾部部分:是 1.001 去掉整数 1 的部分,即小数点后面的 001,将其补齐为 23位,为 00100000000000000000000
    • 指数部分是 2, 加上偏置127 ,得到指数 129,二进制为: 10000001

所以,-4.5float32 中的表示是: 1 | 10000001 | 00100000000000000000000
其中:

  • S = 1
  • Exponent = 10000001
  • Mantissa = 00100000000000000000000

2、float16

float16 是 16 位浮点数,通常称为 半精度浮点数

float16 的浮点数表示方法 与 float32 的表示方法类似,只不过在表示上有较小的精度和范围。

举例: 我们用 float16 类型表示 -4.5

  1. -4.5 为负数,所以 符号位 S = 1
  2. 4 的二进制表示是 100,0.5 的二进制表示是 0.1 →→→ 4.5 的二进制表示是 100.1
  3. 根据 IEEE 754 标准,需要将二进制表示 规范化为 1.xxxx 的形式 →→→ 100.1 = 1.001 × 2²
    • 尾部部分:是 1.001 去掉整数 1 的部分,即小数点后面的 001,将其补齐为 10位,为 0010000000
    • 指数部分是 2, 加上偏置15 ,得到指数 17,二进制为: 10001

所以,-4.5float16 中的表示是: 1 | 10001 | 0010000000
其中:

  • S = 1
  • Exponent = 10001
  • Mantissa = 0010000000

3、bfloat16

BFloat16 主要用于 深度学习训练阶段,特别是在需要保持与 Float32 类似的 数值范围 的同时,减少内存占用(精度比较低)

举例: 我们用 float16 类型表示 -4.5

  1. -4.5 为负数,所以 符号位 S = 1
  2. 4 的二进制表示是 100,0.5 的二进制表示是 0.1 →→→ 4.5 的二进制表示是 100.1
  3. 根据 IEEE 754 标准,需要将二进制表示 规范化为 1.xxxx 的形式 →→→ 100.1 = 1.001 × 2²
    • 尾部部分:是 1.001 去掉整数 1 的部分,即小数点后面的 001,将其补齐为 7位,为 0010000
    • 指数部分是 2, 加上偏置127 ,得到指数 129,二进制为: 10000001

所以,-4.5bfloat16 中的表示是: 1 | 10000001 | 0010000
其中:

  • S = 1
  • Exponent = 10000001
  • Mantissa = 0010000

指数偏置 Bias

指数的偏置(Bias)主要是用来控制指数部分的正负。

举例: 我们用 Float32 类型表示 -0.5

  1. -0.5 为负数,所以 符号位 S = 1
  2. 0.5 的二进制表示是 0.1
  3. 根据 IEEE 754 标准,需要将二进制表示 规范化为 1.xxxx 的形式 →→→ 0.1 = 1.0 × 2⁻¹
    • 尾部部分:是 1.0 去掉整数 1 的部分,即小数点后面的 0,将其补齐为 23 位,为 00000000000000000000000
    • 指数部分是 -1, 加上偏置127 ,得到实际的指数 126,二进制为:01111110

由上面例子可知:

  • 实际指数 是 -1,它表示的是浮点数的真实指数。
  • 实际指数 -1 加上 指数偏置 127 后,得到存储的指数 126,它的二进制形式 01111110 是存储在浮点数表示中的值。

所以,指数偏置的作用:是将原本可能为负的指数值(-1),转换为 无符号整数(126),以便于计算机在表示时能够方便地使用无符号整数表示 指数部分。

精度 与 数值范围

  • 指数部分位数,控制数据范围
  • 尾数部分位数,控制可表示的精度

1、数部分(Exponent)控制数据范围

举例说明:

  • 10 转换为二进制:1.010 × 2³, 不加偏置,指数部分为 3,二进制表示为 11,需要 2 位二进制来表示指数部分
  • 10000 转换为二进制:1.001110010000 × 2¹³,不加偏置,指数部分为 13,二进制表示为 1101,需要4位二进制来表示指数部分

结论:指数部分位数越多,可以表示的 数据范围越大

2、尾数部分(Mantissa)控制精度

举例 1:(1.55 的有效位为 3位,1.5 的有效位为 2位)

  • 1.55 的二进制表示近似为:1.1000110011(循环小数), 尾数部分为 1000110011,需要 10 位二进制来表示尾数部分
  • 如果减小 1.55 精度为 1.51.5 的二进制表示是:1.1, 尾数部分为 1, 需要 1位二进制来表示尾数部分

举例 2:(10 的有效位为 2位,10000 的有效位为 5位)

  • 10 转换为二进制:1.010 × 2³, 尾数部分为 010,需要3位二进制来表示尾数部分
  • 10000 转换为二进制 :1.001110010000 × 2¹³,尾数部分为 001110010000, 需要12位二进制来表示尾数部分

结论:尾数部分位数越多,可以表示的 精度越高

数据类型对比表

数据类型数值范围精度
Float32−3.4028235×1038∼1.1754944×10−38-3.4028235 \times 10^{38} \sim 1.1754944 \times 10^{-38}3.4028235×10381.1754944×1038
1.1754944×10−38∼3.4028235×10381.1754944 \times 10^{-38} \sim 3.4028235 \times 10^{38}1.1754944×10383.4028235×1038
6~9 位有效数字
Float16−65504∼−6.1035156×10−5-65504 \sim -6.1035156 \times 10^{-5}655046.1035156×105
6.1035156×10−5∼655046.1035156 \times 10^{-5} \sim 655046.1035156×10565504
3~4 位有效数字
BFloat16−3.4028235×1038∼1.1754944×10−38-3.4028235 \times 10^{38} \sim 1.1754944 \times 10^{-38}3.4028235×10381.1754944×1038
1.1754944×10−38∼3.4028235×10381.1754944 \times 10^{-38} \sim 3.4028235 \times 10^{38}1.1754944×10383.4028235×1038
2~3 位有效数字

有效数字:指的是在浮点数表示中,能够精确表达的数字。即,从最左边非零数字到最后一位数字的所有数字

  • 123.45 有 5 个有效数字:1, 2, 3, 4, 5
  • 0.00456 有 3 个有效数字:4, 5, 6
  • 400 有 3 个有效数字:4, 0, 0

浮点数类型应用场景与硬件支持

应用场景

  • Float32:通用型,适用于机器学习训练+推理,精度要求高的任务(如科学计算)。
  • Float16:主打推理阶段,精度低但可借硬件加速,提升计算效率。
  • BFloat16:聚焦深度学习训练,平衡数值范围(类似 Float32 范围)与内存,减少占用、加速计算。

硬件支持

  • Float32:大多数现代 CPU、GPU 普遍良好支持。
  • Float16:NVIDIA Volta 及后续架构(如 V100、A100、T4 等 GPU)提供硬件加速。
  • BFloat16:某些硬件(特别是Google 的TPU、Intel 部分处理器)有硬件支持,加速深度学习训练推理。

核心是不同浮点数类型在 AI 任务(训练/推理)里的分工,以及对应硬件适配情况 。

混合精度训练 torch.cuda.autocast()

torch.cuda.amp.autocast() 是 PyTorch 中实现混合精度训练的技术,核心作用是 在保持数值精度的同时,提升训练速度、减少显存占用

混合精度原理

通常DL使用的精度为32位单精度浮点数。混合精度指结合不同精度的数值计算加速训练(关键计算用 32 位避免精度问题,其他用 16 位加速)

  • 16 位浮点优势:内存占用减半、计算速度更快;
  • 16 位浮点风险:精度低,易出现数值下溢/溢出,影响训练结果;

torch.cuda.amp.autocast()能够自动将16位浮点数转换为32位浮点数进行数值计算,并在必要时将结果转换回16位浮点数。
这种自动转换可以帮助避免数值下溢或溢出的问题,并在保持数值精度的同时提高计算速度和减少显存占用。

autocast 工作流程

使用 torch.cuda.autocast() 需配合以下步骤:

  1. 模型与数据上 GPU:确保计算在 GPU 环境执行(CPU 环境无效果);
  2. 包装前向传播:用 torch.cuda.autocast() 上下文管理器,包裹模型前向传递和损失计算,自动处理精度转换;
  3. 梯度缩放:通过 torch.cuda.GradScaler 对象,将反向传播的梯度缩放回 16 位,避免梯度下溢;
  4. 更新参数:执行梯度更新,完成训练迭代。

代码示例

# 混合精度训练,CPU环境下无作用
with amp.autocast(enabled=scaler is not None):pred = model(imgs)

混合精度训练通过 torch.cuda.amp.autocast() 实现低精度加速 + 高精度保障,在深度学习训练中,可有效提升效率、降低显存压力,适配 GPU 硬件加速场景。

数值下溢 / 溢出相关

问:什么叫数值下溢 / 溢出
答:数值下溢和数值溢出是指在数值计算中,某个数值的绝对值太小或太大而无法被表示的情况。

数值下溢是指计算结果的绝对值小于机器能够表示的最小数,通常是由于计算结果太接近于零而导致的。在计算机中,数值通常以有限的二进制位表示,如果数值太小,那么在表示过程中会丢失精度,最终结果可能与真实结果相差很大,这种现象称为数值下溢。

数值溢出则是相反的情况,它是指计算结果的绝对值超过了机器能够表示的最大数,通常是由于计算结果太大而导致的。当计算机尝试用有限的二进制位表示超出其表示范围的数值时,会发生数值溢出,这会导致计算结果不准确或者无法表示。

在数值计算中,数值下溢和数值溢出都可能导致计算结果的不准确性,从而影响模型的训练效果。因此,处理这些数值问题是数值计算中的一个重要问题。常用的解决方法包括使用高精度算法、使用浮点数的科学计数法表示、使用数值截断等方法。

torch.cuda.amp.autocast() 精度相关

问:使用 torch.cuda.amp.autocast() 将数据 从32位(单精度) 转换为 16位(半精度),会导致精度丢失嘛?
答:使用 torch.cuda.amp.autocast() 将数据从32位(单精度)转换为16位(半精度)会导致精度损失。由于16位浮点数只能表示更少的有效位数,因此它们的精度不如32位浮点数。在混合精度训练中,为了平衡精度和性能,通常会将网络的前向传播和反向传播过程中的参数和梯度计算使用半精度浮点数来加速计算。这种方法可以在一定程度上降低计算精度要求,但会带来一定的精度损失。

尽管存在精度损失,使用半精度浮点数的优点在于它们可以显著降低计算时间和显存消耗,从而使模型可以在更大的批量下进行训练,提高训练效率。此外,在实际应用中,对于某些任务,半精度精度的计算误差对于结果的影响可能不是很大,因此,半精度计算可以在保证结果准确性的前提下,大幅度提高模型的训练速度和效率。

http://www.dtcms.com/a/310868.html

相关文章:

  • hcip---ospf知识点总结及实验配置
  • 学习嵌入式第十八天
  • rag学习-以项目为基础快速启动掌握rag
  • 深入 Go 底层原理(十):defer 的实现与性能开销
  • Vue3+ts自定义指令
  • 深入 Go 底层原理(二):Channel 的实现剖析
  • 基于结构熵权-云模型的铸铁浴缸生产工艺安全评价
  • 打靶日记-RCE-labs(续)
  • linux eval命令的使用方法介绍
  • php完整处理word中表单数据的方法
  • 【软考中级网络工程师】知识点之级联
  • PHP面向对象编程与数据库操作完全指南-上
  • ctfshow_源码压缩包泄露
  • Arduino IDE离线安装ESP8266板管理工具
  • 网络安全基础知识【6】
  • Linux初步认识与指令与权限
  • 机器学习sklearn:聚类
  • 读书:李光耀回忆录-我一生的挑战-新加坡双语之路
  • 【物联网】基于树莓派的物联网开发【21】——MQTT获取树莓派传感器数据广播实战
  • Python So Easy 大虫小呓三部曲 - 高阶篇
  • html5+css3+canvas长文转长图工具支持换行
  • 国产嵌入式调试器之光? RT-Trace 初体验!
  • C++之vector类的代码及其逻辑详解 (中)
  • 电力系统分析学习笔记
  • 谷歌Chrome浏览器安装插件
  • 论文笔记:Bundle Recommendation and Generation with Graph Neural Networks
  • 设计Mock华为昇腾GPU的MindSpore和CANN的库的流程与实现
  • STM32——启动过程浅析
  • 个人电脑部署私有化大语言模型LLM
  • python+pyside6的简易画板