机器学习09-正规方程
机器学习笔记:正规方程(Normal Equation)
概述
正规方程是线性回归中求解参数的一种解析方法。它基于最小化损失函数(如最小二乘法)来直接计算出参数的最优值。在机器学习中,这种方法尤其适用于特征数量不多且数据集不是非常大的情况。
预备知识
- 线性回归:模型形式通常表示为 (y = Xw + b),其中 (y) 是输出,(X) 是输入特征矩阵,(w) 是权重向量,(b) 是偏置项。
- 损失函数(最小二乘法):(J(w) = \frac{1}{2m} \sum_{i=1}^{m} (h(x^{(i)}) - y{(i)})2),其中 (m) 是样本数量,(h(x^{(i)})) 是模型预测值。
正规方程推导
目标是找到使损失函数 (J(w)) 最小化的参数 (w)。通过设置 (J(w)) 对 (w) 的导数为0,可以得到正规方程:
[ w = (XTX){-1}X^Ty ]
这里,(X) 是设计矩阵(包含特征和偏置项),(y) 是目标值向量。
步骤
- 构建设计矩阵 (X):如果特征数为 (n),样本数为 (m),则 (X) 的维度为 (m \times (n+1))(加一是因为偏置项)。
- 计算 (X^TX):这是一个 ((n+1) \times (n+1)) 的矩阵。
- 计算 (X^Ty):这是一个 ((n+1) \times 1) 的向量。
- 求解 (w):通过计算 ((XTX){-1}X^Ty) 来得到权重向量 (w)。
特点
- 优点:
- 直接计算,不需要迭代。
- 当特征数量较少时,计算量较小。
- 缺点:
- 当特征数量很多时,计算 (X^TX) 的逆矩阵非常耗时。
- 对于大规模数据集,内存消耗大。
- 对数据的规模和特征数量敏感。
应用场景
正规方程适用于特征数量不多、数据集规模适中的情况。在特征数量较多或者数据集较大时,通常推荐使用梯度下降等迭代方法。
实现示例(Python)
import numpy as np# 假设 X 是特征矩阵,y 是目标值向量
X = np.array([[1, 2], [1, 3], [1, 4], [1, 5]])
y = np.array([2, 3, 5, 7])# 添加偏置项
X_b = np.c_[np.ones((X.shape[0], 1)), X]# 计算参数 w
w_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)print("参数 w:", w_best)
总结
正规方程是解决线性回归问题的一种有效方法,尤其适用于特征数量较少的情况。然而,在处理大规模数据集或特征数量众多的情况时,其计算和存储的开销可能变得不可接受,此时应考虑使用更高效的优化算法,如梯度下降。