【机器学习笔记Ⅰ】11 多项式回归
多项式回归(Polynomial Regression)详解
多项式回归是线性回归的扩展,通过引入特征的幂次项(如 (x^2, x^3))来拟合非线性关系。它保留了线性回归的简洁性,同时能捕捉更复杂的数据模式。
1. 核心思想
- 问题场景:当自变量 (x) 和因变量 (y) 之间存在非线性关系(如抛物线、周期性变化)时,简单线性回归((y = w_1 x + b))无法拟合。
- 解决方案:将特征升维,构造多项式特征,再用线性模型拟合。
例如:
[
y = w_1 x + w_2 x^2 + w_3 x^3 + b
]
虽然对 (x) 是非线性的,但对参数 (w) 仍是线性的,仍可用线性回归方法求解。
2. 数学模型
(1) 多项式方程
对于单特征 (x),(d) 次多项式回归方程:
[
y = w_0 + w_1 x + w_2 x^2 + \dots + w_d x^d
]
- (d):多项式阶数(需谨慎选择,过高会导致过拟合)。
(2) 多特征情况
若原始特征为 (x_1, x_2),二次多项式可扩展为:
[
y = w_0 + w_1 x_1 + w_2 x_2 + w_3 x_1^2 + w_4 x_2^2 + w_5 x_1 x_2
]
- 引入了交互项(如 (x_1 x_2))和平方项。
3. 实现步骤
(1) 特征变换
将原始特征 (x) 转换为多项式特征矩阵:
[
\text{若 } x = \begin{bmatrix} x_1 \ x_2 \end{bmatrix}, \text{二次多项式特征为 } \begin{bmatrix} 1, x_1, x_2, x_1^2, x_2^2, x_1 x_2 \end{bmatrix}
]
(2) 代码实现(Python)
import numpy as np
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression# 示例数据
X = np.array([[1], [2], [3], [4]]) # 单特征
y = np.array([2, 4, 9, 16]) # y ≈ x^2# 构造多项式特征(2阶)
poly = PolynomialFeatures(degree=2, include_bias=False)
X_poly = poly.fit_transform(X) # 转换为 [x, x^2]# 用线性回归拟合
model = LinearRegression()
model.fit(X_poly, y)# 预测
x_test = np.array([[5]])
x_test_poly = poly.transform(x_test)
print(model.predict(x_test_poly)) # 输出 ≈ 25 (5^2)
4. 关键问题
(1) 如何选择多项式阶数?
- 欠拟合(阶数太低):无法捕捉数据非线性(如用直线拟合抛物线)。
- 过拟合(阶数太高):模型过于复杂,拟合噪声(如下图)。
- 建议:
- 通过交叉验证选择最佳阶数。
- 观察训练集和验证集的误差曲线。
(2) 是否需要特征缩放?
- 需要!多项式特征的量纲差异极大(如 (x) 范围是 [0,1],则 (x^5) 范围是 [0,1e-5]),务必使用
StandardScaler
或MinMaxScaler
。
(3) 与非线性回归的区别
- 多项式回归:对特征非线性,对参数线性(仍用最小二乘法求解)。
- 非线性回归:参数也是非线性的(如 (y = e^{w x})),需数值优化(如梯度下降)。
5. 优缺点
优点 | 缺点 |
---|---|
简单高效,保留线性回归的计算优势。 | 高阶易过拟合,需正则化(如Lasso)。 |
可解释性强(系数反映特征重要性)。 | 特征维度爆炸(阶数高时)。 |
适合低维非线性数据。 | 对非多项式模式(如周期性)拟合差。 |
6. 应用场景
- 物理学:拟合物体运动轨迹(抛物线)。
- 经济学:描述增长趋势(如GDP的指数增长可用高阶多项式逼近)。
- 工业控制:传感器数据的非线性校准。
7. 进阶技巧
(1) 正则化(防止过拟合)
- 岭回归(Ridge):对系数 (w) 的L2惩罚。
- Lasso回归:对系数 (w) 的L1惩罚(可稀疏化特征)。
from sklearn.linear_model import Ridge
model = Ridge(alpha=0.1).fit(X_poly, y) # alpha是正则化强度
(2) 多项式核SVM
- 用核函数隐式计算高维特征,避免显式构造:
from sklearn.svm import SVR
model = SVR(kernel='poly', degree=3).fit(X, y)
8. 总结
- 多项式回归 = 特征升维 + 线性回归。
- 核心参数:阶数
degree
,需平衡拟合与泛化。 - 必做步骤:特征缩放、交叉验证、正则化(高阶时)。
通过合理使用多项式回归,可用线性方法解决复杂的非线性问题!