使用 np.zeros_like(label) 保存预测概率时发现数据类型不匹配导致的隐式类型转换
🍉 CSDN 叶庭云:https://yetingyun.blog.csdn.net/
下面这段代码中 predictions
数组在赋值后全为 0 的根本原因是数据类型不匹配导致的隐式类型转换,保存概率值时会被截断为 0 或 1。具体分析如下:
import numpy as npsamples = 100
y = np.random.randint(0, 2, size=samples)
y_pred = np.random.randint(0, 2, size=samples)
y_pred_proba = np.random.rand(samples).astype(np.float32)
print(y, len(y))
print(y_pred, len(y_pred))
print(y_pred_proba, len(y_pred_proba))print("=" * 100)
predictions = np.zeros_like(y_pred)
print(predictions, len(predictions))
predictions[:] = y_pred_proba
print(predictions, len(predictions))
1. 错误原因分析
-
初始化数据类型错误:
-
y_pred
是np.random.randint(0, 2)
生成的,默认数据类型为int
。 -
np.zeros_like(y_pred)
会继承y_pred
的int
类型,因此predictions
是整型数组。
-
-
浮点数到整型的强制转换:
-
y_pred_proba
是np.random.rand
生成的浮点数组(float32
)。 -
当执行
predictions[:] = y_pred_proba
时,右侧的浮点数会被强制转换为左侧的整型,导致小数部分被截断。例如,0.95 -> 0
,0.19 -> 0
。 -
最终,所有
y_pred_proba
中的浮点值都会变成 0,导致predictions
全为 0。
-
2. 修正方法
核心思路:确保 predictions
的数据类型与 y_pred_proba
兼容(即浮点类型)。
2.1 高效且精确的修正代码
# 修正后的关键行:显式指定浮点类型
predictions = np.zeros_like(y_pred, dtype=np.float32) # 强制为浮点类型
predictions[:] = y_pred_proba
2.2 修正原理
-
显式指定数据类型:
-
np.zeros_like(y_pred, dtype=np.float32)
会创建一个与y_pred
形状相同但数据类型为float32
的数组。 -
此时
predictions
可以正确存储浮点数值,避免类型转换。
-
-
赋值操作保留精度:
- 右侧的
y_pred_proba
(float32
)可以直接赋值给左侧的浮点数组,无精度损失。
- 右侧的
3. 完整修正代码
import numpy as npsamples = 100
y = np.random.randint(0, 2, size=samples)
y_pred = np.random.randint(0, 2, size=samples)
y_pred_proba = np.random.rand(samples).astype(np.float32)
print(y, len(y))
print(y_pred, len(y_pred))
print(y_pred_proba, len(y_pred_proba))print("=" * 100)
predictions = np.zeros_like(y_pred)
print(predictions, len(predictions))
predictions[:] = y_pred_proba
print(predictions, len(predictions))
predictions = np.zeros_like(y_pred, dtype=np.float32)
predictions[:] = y_pred_proba
print(predictions, len(predictions))
4. 其他可行方案
- 直接使用浮点初始化:
predictions = np.zeros(samples, dtype=np.float32)
- 复用
y_pred_proba
的数据类型:
predictions = np.zeros_like(y_pred_proba)
5. 总结
-
根本原因:整型数组无法存储浮点数值,这会引起隐式类型转换。
-
修正关键:确保目标数组的数据类型与源数据相匹配(浮点数类型:np.float32、np.float64)。