【DL学习笔记】yaml、json、随机种子、浮点精度、amp
文章目录
- argparse模块——命令行参数解析
- yaml文件的语法和读写
- 基本语法
- 数据类型
- 对象
- 数组
- 常量
- yaml 文件读取
- yaml 文件写入
- json文件的语法和读写
- 读取 JSON 文件
- 写入 JSON 文件
- json 读取/存储 list
- json 读取/存储 string
- 随机种子
- 浮点精度Float32,Float16,BFloat16
- 浮点数的表示方法
- 浮点数的值计算公式
- 具体三种浮点数举例
- 1、float32
- 2、float16
- 3、bfloat16
- 指数偏置 Bias
- 精度 与 数值范围
- 1、数部分(Exponent)控制数据范围
- 2、尾数部分(Mantissa)控制精度
- 数据类型对比表
- 浮点数类型应用场景与硬件支持
- 应用场景
- 硬件支持
- 混合精度训练 torch.cuda.autocast()
- 混合精度原理
- autocast 工作流程
- 代码示例
- 数值下溢 / 溢出相关
- torch.cuda.amp.autocast() 精度相关
argparse模块——命令行参数解析
经常在训练或者推理脚本中,用argparse模块进行命令行参数解析。比如YOLO的胡环境中的Anaconda\envs\yolo_env\Lib\argparse.py
如果不用命令行传入参数,需要在训练时将超参数写入到train.py中,需要修改时,必须在文件中修改代码,而且再运行时传参需要传很多参数。用argparse传参可以仅仅只传一个parse对象。
argparse 模块是 Python 标准库(安装好python就可以import argprase
)中提供的一个 命令行解析模块 ,它可以让使用者以类似 Unix/Linux 命令参数的方式输入参数(在终端以命令行的方式指定参数),argparse 会自动将命令行指定的参数解析为 Python 变量,从而让使用者更加快捷的处理参数。
使用的三个步骤:
- 创建 ArgumentParser 对象
- 添加参数
- 解析命令行参数
- 通过
args.input
属性调用参数
import argparse# 创建 ArgumentParser 对象
parser = argparse.ArgumentParser(description="description")# 添加参数
parser.add_argument("-i", "--input", help="Input file")# 解析命令行参数
args = parser.parse_args()# 使用参数
print("Input file:", args.input)
yaml文件的语法和读写
YAML :YAML Ain’t Markup Language",即 YAML是一种非标记语言。
基本语法
- 大小写敏感
- 使用缩进表示层级关系
- 缩进不允许使用 tab,只允许空格
- 缩进的空格数不重要,只要相同层级的元素左对齐即可
- ‘#’ 表示注释
数据类型
YAML 支持以下几种数据类型:
- 对象:键值对的集合,又称为映射(mapping)/ 哈希(hashes) / 字典(dictionary)
- 数组:一组按次序排列的值,又称为序列(sequence) / 列表(list)
- 纯量(scalars):单个的、不可再分的值
对象
- 键值对 使用冒号结构表示:
key: value
。 注意,冒号后面一定要加一个空格,否则 YAML 解析器会报错 - 键值对 可以使用内联形式(flow style)的写法
key: {child-key: value, child-key2: value2}
- 也可以使用缩进形式(block style)的写法
key: child-key: value # 冒号后面要加一个空格child-key2: value2 # 冒号后面要加一个空格
数组
以 -
开头的行在 YAML 中表示列表元素,如下表示一个列表,等价于:["A", "B", "C"]
- A
- B
- C
当列表作为某个键的值时,有两种写法:
- 内联形式(flow style)的写法 :
key: [value1, value2, value3]
- 缩进形式(block style):
key:- value1- value2- value3
上面两种写法在解析后效果完全一致,都等价于:
{"key": ["value1", "value2", "value3"]}
举个例子
# 缩进形式
companies:-id: 1name: company1price: 200W-id: 2name: company2price: 500W # 内联形式
companies: [{id: 1,name: company1,price: 200W},{id: 2,name: company2,price: 500W}]
- companies 是一个 列表(数组)
- 每个列表元素是一个 字典(对象)
- 每个对象包含三个字段:id、name、price
解析后等价于 Python 中的结构:
{"companies": [{"id": 1, "name": "company1", "price": "200W"},{"id": 2, "name": "company2", "price": "500W"}]
}
常量
纯量是最基本的,不可再分的值,包括:
- 字符串
- 布尔值
- 整数
- 浮点数
- Null
- 时间
- 日期
boolean: - TRUE # true,True都可以- FALSE # false,False都可以
float:- 3.14- 6.8523015e+5 # 可以使用科学计数法
int:- 123- 0b1010_0111_0100_1010_1110 # 二进制表示
null:nodeName: 'node'parent: ~ # 使用~表示null
string:- 'Hello world' # 可以使用双引号或者单引号包裹特殊字符- myname # 不含特殊字符时可省略引号- newlinenewline2 # 字符串可以拆成多行,每一行会被转化成一个空格
date:- 2018-02-17 # 日期必须使用ISO 8601格式,即yyyy-MM-dd
datetime: - 2018-02-17T15:02:31+08:00 # 时间使用ISO 8601格式,时间和日期之间使用T连接,最后使用+代表时区
yaml 文件读取
首先安装 yaml 包 : pip install pyyaml
假设有一个 example.yaml 文件,内容如下 :
name:- John- Tom
age:- 30- 25
我们可以通过以下代码,读取 example.yaml 文件中的内容
import yaml# 读取 YAML 文件
with open("example.yaml", "r") as file:data = yaml.safe_load(file)# 打印读取的数据
print(data)
yaml 文件写入
import yamldata = {'name':['John', 'Tom'], 'age':[30, 25]}with open('./example.yaml ', "w") as f:yaml.safe_dump({k: v for k, v in data.items()}, f)
json文件的语法和读写
简介
- JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式
- 通常用于保存配置、结构化数据,便于网络通信或本地存储
- JSON 文件以 .json 为后缀
- json 是 Python 的内置模块,随 Python 一起安装,不需要
pip install json
读取 JSON 文件
假设,我们有如下一个 json 文件,文件名为 example.json,内容如下 :
{"name": "Alice","age": 25,"skills": ["Python", "Machine Learning", "Data Analysis"],"is_student": false
}
- 我们可以使用 load 方法,将文件中的内容读取出来
- 因为这个示例json文件中的数据格式为 字典,所以,读取出来的内容也是 字典
read_json.py:
import json# 读取 json 文件
with open('example.json', 'r', encoding='utf-8') as f:data = json.load(f)print(data)
写入 JSON 文件
- 使用 dump 方法,将内容写入到 json 文件中
- dump 方法中的参数 indent :用于指定缩进空格数量,美化格式,让生成的 JSON 文件更易读
import json# 要写入的数据
data = {"name": "Bob","age": 30,"skills": ["C++", "Deep Learning"],"is_student": True
}# 写入 json 文件
with open('output.json', 'w', encoding='utf-8') as f:json.dump(data, f, indent=4) # indent 用于指定缩进空格数量,美化输出,让生成的 JSON 文件更易读
Json 文件除了配合 字典 使用,还可以与其他的数据类型配合使用,比如 :列表、字符串,等等
json 读取/存储 list
import jsondata = ["apple", "banana", "cherry"]# 写入 json 文件
with open('fruits.json', 'w', encoding='utf-8') as f:json.dump(data, f, indent=2)# 读取 json 文件
with open('fruits.json', 'r', encoding='utf-8') as f:read_data = json.load(f)print(read_data)
json 读取/存储 string
dump 方法中的参数 ensure_ascii 用于控制是否将非 ASCII 字符(如中文,或者标点符号)转义成 \uXXXX。
- True :默认值,表示会转义。
- False:保留字符原样
比如下面例子,如果指定 ensure_ascii=True,则存储到 json 文件中的内容为 :“Hello\uff0cWorld!”
import jsondata = "Hello,World!"with open('hello_world.json', 'w', encoding='utf-8') as f:json.dump(data, f, ensure_ascii=False)with open('hello_world.json', 'r', encoding='utf-8') as f:read_data = json.load(f)print(read_data)
随机种子
在算法实验中,我们经常需要设置随机数,我们不希望这些随机数会造成实验结果不能复现。
但其实这些随机数,只是 “伪随机数”,计算机中没有真正的随机。只要我们通过设置 相同的 随机种子,生成的随机数/随机序列 会是相同的,这有助于我们调试和复现实验结果。
我们常用的库,都有提供给我们设置随机种子的方法,比如:
import random
random.seed(121)import numpy
numpy.random.seed(121)import torch
torch.manual_seed(121)
torch.cuda.manual_seed(121)
比如说,对于 pytorch 的 torch.manual_seed()
方法,我们设置 随机种子为 121
你可以尝试 多次运行以下代码,你会发现,每一次生成的随机数,都是一样的
不同的版本或者硬件平台,生成的随机数可能不同
import torch# 设置随机种子
torch.manual_seed(121)# 生成一个随机张量
a = torch.randn(2, 3)
print(a)# tensor([[1.4521, 0.0504, 0.4962],
# [0.3959, 0.2918, 1.0004]])
浮点精度Float32,Float16,BFloat16
浮点数的表示方法
-
Float32:单精度浮点数格式
- 使用 32位(4字节)来表示一个浮点数
- 遵循 IEEE 754标准,32位 包括 1 位符号位、8 位指数部分和 23 位尾数部分
-
Float16:半精度浮点数格式
- 使用 16位(2字节)来表示一个浮点数
- 遵循 IEEE 754标准,16位 包括 1 位符号位、5 位指数部分和 10 位尾数部分
-
BFloat16:Brain Floating Point 16-bit
- 使用 16位(2字节)来表示一个浮点数
- 不遵循 IEEE 754标准,16位 包括 1 位符号位、8 位指数部分 和 7 位尾数部分
数据类型位数及结构对比表:
数据类型 | 一共位数 | 符号位数(Sign) | 指数位数(Exponent) | 尾数位数(Mantissa) |
---|---|---|---|---|
Float32 | 32 | 1 | 8 | 23 |
Float16 | 16 | 1 | 5 | 10 |
BFloat16 | 16 | 1 | 8 | 7 |
各类型结构示意:符号位 | 指数部分 | 尾数部分
float32 结构为 :
S | EEEEEEEE | MMMMMMMMMMMMMMMMMMMMMMMMfloat16 结构为 :
S | EEEEE | MMMMMMMMMMbfloat16 结构为 :
S | EEEEEEEE | MMMMMMMMMMMM
浮点数的值计算公式
浮点数的值通过以下公式计算:
value=(−1)S×(1+Mantissa)×2(Exponent−Bias)value = (-1)^S \times (1 + \text{Mantissa}) \times 2^{(\text{Exponent} - \text{Bias})} value=(−1)S×(1+Mantissa)×2(Exponent−Bias)
参数说明:
S
:符号位,表示数字的正负:0 为正,1 为负Mantissa
:尾数部分Exponent
:指数部分,表示 2 的幂次Bias
:指数的偏置float32
的 Bias 为 127float16
的 Bias 为 15bfloat16
的 Bias 为 127
具体三种浮点数举例
1、float32
float32:32位浮点数,通常称为单精度浮点数
举例: 我们用 float32
类型表示 -4.5
-4.5
为负数,所以 符号位S = 1
- 4 的二进制表示是
100
,0.5 的二进制表示是0.1
→→→ 4.5 的二进制表示是100.1
- 根据 IEEE 754 标准,需要将二进制表示
100.1
规范化为1.xxxx
的形式 →→→100.1 = 1.001 × 2²
- 尾部部分:是
1.001
去掉整数 1 的部分,即小数点后面的001
,将其补齐为 23位,为00100000000000000000000
- 指数部分是 2, 加上偏置127 ,得到指数 129,二进制为:
10000001
- 尾部部分:是
所以,-4.5
在 float32
中的表示是: 1 | 10000001 | 00100000000000000000000
其中:
S = 1
Exponent = 10000001
Mantissa = 00100000000000000000000
2、float16
float16
是 16 位浮点数,通常称为 半精度浮点数
float16
的浮点数表示方法 与 float32
的表示方法类似,只不过在表示上有较小的精度和范围。
举例: 我们用 float16
类型表示 -4.5
-4.5
为负数,所以 符号位S = 1
- 4 的二进制表示是
100
,0.5 的二进制表示是0.1
→→→ 4.5 的二进制表示是100.1
- 根据 IEEE 754 标准,需要将二进制表示 规范化为
1.xxxx
的形式 →→→100.1 = 1.001 × 2²
- 尾部部分:是
1.001
去掉整数 1 的部分,即小数点后面的001
,将其补齐为 10位,为0010000000
- 指数部分是 2, 加上偏置15 ,得到指数 17,二进制为:
10001
- 尾部部分:是
所以,-4.5
在 float16
中的表示是: 1 | 10001 | 0010000000
其中:
S = 1
Exponent = 10001
Mantissa = 0010000000
3、bfloat16
BFloat16 主要用于 深度学习训练阶段,特别是在需要保持与 Float32 类似的 数值范围 的同时,减少内存占用(精度比较低)
举例: 我们用 float16
类型表示 -4.5
-4.5
为负数,所以 符号位S = 1
- 4 的二进制表示是
100
,0.5 的二进制表示是0.1
→→→ 4.5 的二进制表示是100.1
- 根据 IEEE 754 标准,需要将二进制表示 规范化为
1.xxxx
的形式 →→→100.1 = 1.001 × 2²
- 尾部部分:是
1.001
去掉整数 1 的部分,即小数点后面的001
,将其补齐为 7位,为0010000
- 指数部分是 2, 加上偏置127 ,得到指数 129,二进制为:
10000001
- 尾部部分:是
所以,-4.5
在 bfloat16
中的表示是: 1 | 10000001 | 0010000
其中:
S = 1
Exponent = 10000001
Mantissa = 0010000
指数偏置 Bias
指数的偏置(Bias)主要是用来控制指数部分的正负。
举例: 我们用 Float32
类型表示 -0.5
-0.5
为负数,所以 符号位S = 1
0.5
的二进制表示是0.1
- 根据 IEEE 754 标准,需要将二进制表示 规范化为
1.xxxx
的形式 →→→0.1 = 1.0 × 2⁻¹
- 尾部部分:是
1.0
去掉整数 1 的部分,即小数点后面的0
,将其补齐为 23 位,为00000000000000000000000
- 指数部分是
-1
, 加上偏置127 ,得到实际的指数 126,二进制为:01111110
- 尾部部分:是
由上面例子可知:
- 实际指数 是
-1
,它表示的是浮点数的真实指数。 - 实际指数
-1
加上 指数偏置 127 后,得到存储的指数 126,它的二进制形式01111110
是存储在浮点数表示中的值。
所以,指数偏置的作用:是将原本可能为负的指数值(-1
),转换为 无符号整数(126
),以便于计算机在表示时能够方便地使用无符号整数表示 指数部分。
精度 与 数值范围
- 指数部分位数,控制数据范围
- 尾数部分位数,控制可表示的精度
1、数部分(Exponent)控制数据范围
举例说明:
- 将
10
转换为二进制:1.010 × 2³
, 不加偏置,指数部分为3
,二进制表示为11
,需要 2 位二进制来表示指数部分 - 将
10000
转换为二进制:1.001110010000 × 2¹³
,不加偏置,指数部分为13
,二进制表示为1101
,需要4位二进制来表示指数部分
结论:指数部分位数越多,可以表示的 数据范围越大
2、尾数部分(Mantissa)控制精度
举例 1:(1.55
的有效位为 3位,1.5
的有效位为 2位)
1.55
的二进制表示近似为:1.1000110011
(循环小数), 尾数部分为1000110011
,需要 10 位二进制来表示尾数部分- 如果减小
1.55
精度为1.5
,1.5
的二进制表示是:1.1
, 尾数部分为1
, 需要 1位二进制来表示尾数部分
举例 2:(10
的有效位为 2位,10000
的有效位为 5位)
- 将
10
转换为二进制:1.010 × 2³
, 尾数部分为010
,需要3位二进制来表示尾数部分 - 将
10000
转换为二进制 :1.001110010000 × 2¹³
,尾数部分为001110010000
, 需要12位二进制来表示尾数部分
结论:尾数部分位数越多,可以表示的 精度越高
数据类型对比表
数据类型 | 数值范围 | 精度 |
---|---|---|
Float32 | −3.4028235×1038∼1.1754944×10−38-3.4028235 \times 10^{38} \sim 1.1754944 \times 10^{-38}−3.4028235×1038∼1.1754944×10−38 1.1754944×10−38∼3.4028235×10381.1754944 \times 10^{-38} \sim 3.4028235 \times 10^{38}1.1754944×10−38∼3.4028235×1038 | 6~9 位有效数字 |
Float16 | −65504∼−6.1035156×10−5-65504 \sim -6.1035156 \times 10^{-5}−65504∼−6.1035156×10−5 6.1035156×10−5∼655046.1035156 \times 10^{-5} \sim 655046.1035156×10−5∼65504 | 3~4 位有效数字 |
BFloat16 | −3.4028235×1038∼1.1754944×10−38-3.4028235 \times 10^{38} \sim 1.1754944 \times 10^{-38}−3.4028235×1038∼1.1754944×10−38 1.1754944×10−38∼3.4028235×10381.1754944 \times 10^{-38} \sim 3.4028235 \times 10^{38}1.1754944×10−38∼3.4028235×1038 | 2~3 位有效数字 |
有效数字:指的是在浮点数表示中,能够精确表达的数字。即,从最左边非零数字到最后一位数字的所有数字
123.45
有 5 个有效数字:1, 2, 3, 4, 5
0.00456
有 3 个有效数字:4, 5, 6
400
有 3 个有效数字:4, 0, 0
浮点数类型应用场景与硬件支持
应用场景
- Float32:通用型,适用于机器学习训练+推理,精度要求高的任务(如科学计算)。
- Float16:主打推理阶段,精度低但可借硬件加速,提升计算效率。
- BFloat16:聚焦深度学习训练,平衡数值范围(类似 Float32 范围)与内存,减少占用、加速计算。
硬件支持
- Float32:大多数现代 CPU、GPU 普遍良好支持。
- Float16:NVIDIA Volta 及后续架构(如 V100、A100、T4 等 GPU)提供硬件加速。
- BFloat16:某些硬件(特别是Google 的TPU、Intel 部分处理器)有硬件支持,加速深度学习训练推理。
核心是不同浮点数类型在 AI 任务(训练/推理)里的分工,以及对应硬件适配情况 。
混合精度训练 torch.cuda.autocast()
torch.cuda.amp.autocast()
是 PyTorch 中实现混合精度训练的技术,核心作用是 在保持数值精度的同时,提升训练速度、减少显存占用。
混合精度原理
通常DL使用的精度为32位单精度浮点数。混合精度指结合不同精度的数值计算加速训练(关键计算用 32 位避免精度问题,其他用 16 位加速)
- 16 位浮点优势:内存占用减半、计算速度更快;
- 16 位浮点风险:精度低,易出现数值下溢/溢出,影响训练结果;
torch.cuda.amp.autocast()
能够自动将16位浮点数转换为32位浮点数进行数值计算,并在必要时将结果转换回16位浮点数。
这种自动转换可以帮助避免数值下溢或溢出的问题,并在保持数值精度的同时提高计算速度和减少显存占用。
autocast 工作流程
使用 torch.cuda.autocast()
需配合以下步骤:
- 模型与数据上 GPU:确保计算在 GPU 环境执行(CPU 环境无效果);
- 包装前向传播:用
torch.cuda.autocast()
上下文管理器,包裹模型前向传递和损失计算,自动处理精度转换; - 梯度缩放:通过
torch.cuda.GradScaler
对象,将反向传播的梯度缩放回 16 位,避免梯度下溢; - 更新参数:执行梯度更新,完成训练迭代。
代码示例
# 混合精度训练,CPU环境下无作用
with amp.autocast(enabled=scaler is not None):pred = model(imgs)
混合精度训练通过 torch.cuda.amp.autocast()
实现低精度加速 + 高精度保障,在深度学习训练中,可有效提升效率、降低显存压力,适配 GPU 硬件加速场景。
数值下溢 / 溢出相关
问:什么叫数值下溢 / 溢出
答:数值下溢和数值溢出是指在数值计算中,某个数值的绝对值太小或太大而无法被表示的情况。
数值下溢是指计算结果的绝对值小于机器能够表示的最小数,通常是由于计算结果太接近于零而导致的。在计算机中,数值通常以有限的二进制位表示,如果数值太小,那么在表示过程中会丢失精度,最终结果可能与真实结果相差很大,这种现象称为数值下溢。
数值溢出则是相反的情况,它是指计算结果的绝对值超过了机器能够表示的最大数,通常是由于计算结果太大而导致的。当计算机尝试用有限的二进制位表示超出其表示范围的数值时,会发生数值溢出,这会导致计算结果不准确或者无法表示。
在数值计算中,数值下溢和数值溢出都可能导致计算结果的不准确性,从而影响模型的训练效果。因此,处理这些数值问题是数值计算中的一个重要问题。常用的解决方法包括使用高精度算法、使用浮点数的科学计数法表示、使用数值截断等方法。
torch.cuda.amp.autocast() 精度相关
问:使用 torch.cuda.amp.autocast() 将数据 从32位(单精度) 转换为 16位(半精度),会导致精度丢失嘛?
答:使用 torch.cuda.amp.autocast() 将数据从32位(单精度)转换为16位(半精度)会导致精度损失。由于16位浮点数只能表示更少的有效位数,因此它们的精度不如32位浮点数。在混合精度训练中,为了平衡精度和性能,通常会将网络的前向传播和反向传播过程中的参数和梯度计算使用半精度浮点数来加速计算。这种方法可以在一定程度上降低计算精度要求,但会带来一定的精度损失。
尽管存在精度损失,使用半精度浮点数的优点在于它们可以显著降低计算时间和显存消耗,从而使模型可以在更大的批量下进行训练,提高训练效率。此外,在实际应用中,对于某些任务,半精度精度的计算误差对于结果的影响可能不是很大,因此,半精度计算可以在保证结果准确性的前提下,大幅度提高模型的训练速度和效率。