回归任务损失函数对比曲线
回归任务损失函数曲线可视化对比
本节将可视化对比均方误差(MSE)、平均绝对误差(MAE)、Huber损失函数三种常见回归任务损失函数的曲线,帮助理解它们在不同误差区间的表现差异。
1. 导入所需库
我们需要用到 numpy
进行数值计算,matplotlib
进行绘图。
import numpy as np
import matplotlib.pyplot as plt
2. 定义损失函数(MSE、MAE、Huber)
分别实现均方误差(MSE)、平均绝对误差(MAE)和Huber损失的Python函数。
def mse_loss(x):"""均方误差"""return x ** 2def mae_loss(x):"""平均绝对误差"""return np.abs(x)def huber_loss(x, delta=1.0): #delta阈值,控制损失函数从二次到线性切换的位置,常用1.0"""Huber损失"""return np.where(np.abs(x) <= delta,0.5 * x ** 2,delta * (np.abs(x) - 0.5 * delta))
3. 生成误差数据
生成一组对称分布的误差(如-5到5),用于损失函数的输入。
# 生成误差区间
errors = np.linspace(-5, 5, 200)
4. 计算各损失函数的取值
对每个误差值,分别计算MSE、MAE和Huber损失的结果。
mse_values = mse_loss(errors)
mae_values = mae_loss(errors)
huber_values = huber_loss(errors, delta=1.0) #delta设置为1.0是为了与其他损失函数对比
5. 绘制损失函数对比曲线
使用matplotlib将三种损失函数的曲线绘制在同一张图上,便于直观对比。
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'Microsoft YaHei', 'PingFang SC'] # 支持中文
plt.rcParams['axes.unicode_minus'] = False # 正常显示负号
plt.figure(figsize=(8, 5))
plt.plot(errors, mse_values, label='MSE (L2)', color='blue')
plt.plot(errors, mae_values, label='MAE (L1)', color='green')
plt.plot(errors, huber_values, label='Huber (δ=1.0)', color='red')
plt.xlabel('误差 (error)')
plt.ylabel('损失值 (loss)')
plt.title('MSE、MAE、Huber损失函数曲线对比')
plt.legend()
plt.grid(True)
plt.show()
总结
- MSE 对大误差更敏感,曲线在远离0时增长更快。
- MAE 对所有误差线性增长,对异常值不敏感,但在0点不可导。
- Huber损失 在误差较小时与MSE一致,误差较大时与MAE一致,兼具二者优点,常用于鲁棒回归任务。