特征工程四-2:使用GridSearchCV 进行超参数网格搜索(Hyperparameter Tuning)的用途
1. GridSearchCV
的作用
GridSearchCV
(网格搜索交叉验证)用于:
- 自动搜索 给定参数范围内的最佳超参数组合。
- 交叉验证评估 每个参数组合的性能,避免过拟合。
- 返回最佳模型,可直接用于预测或分析。
2. 代码逐行解析
(1) 创建 GridSearchCV
对象
grid = GridSearchCV(model, # 要优化的模型(如 RandomForest、SVM 等)params, # 待搜索的参数网格(字典或列表格式)error_score=0. # 如果某组参数拟合报错,将该组合得分设为 0
)
-
model
:已经定义的模型实例(如model = RandomForestClassifier()
)。 -
params
:参数网格,格式示例:params = {'n_estimators': [50, 100], # 决策树数量'max_depth': [5, 10] # 树的最大深度 }
-
error_score=0.
:
当某组参数导致模型拟合失败(如不兼容参数)时,将该参数组合的验证得分设为0
,避免程序中断。
(2) 执行网格搜索
grid.fit(X, y) # 用数据 X 和标签 y 拟合模型
- 对
params
中的所有参数组合进行尝试,并通过交叉验证(默认 5 折)评估性能。 - 最终确定 最佳参数组合,并重新训练模型(用最佳参数在整个数据集上训练)。
3. 关键输出
完成 fit
后,可通过以下属性获取结果:
-
最佳参数:
print(grid.best_params_) # 输出示例:{'max_depth': 10, 'n_estimators': 100}
-
最佳模型的交叉验证得分:
print(grid.best_score_)
-
最佳模型实例(可直接用于预测):
best_model = grid.best_estimator_ best_model.predict(X_test)
4. 注意事项
-
参数网格设计:
-
范围过大可能导致计算耗时,建议先用粗网格筛选,再细化。
-
示例:
params = {'C': [0.1, 1, 10], 'kernel': ['linear', 'rbf']} # SVM 参数
-
-
交叉验证控制:
- 可通过
cv
参数调整折数(如cv=10
)。 - 使用
scoring
指定评估指标(如scoring='accuracy'
)。
- 可通过
-
替代方案:
- 如果参数空间较大,可用
RandomizedSearchCV
(随机搜索,更快)。
- 如果参数空间较大,可用
5. 完整示例
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV# 定义模型和参数网格# 实例化机器学习模型
rf = RandomForestClassifier()
lr = LogisticRegression()
knn = KNeighborsClassifier()
dt = DecisionTreeClassifier()# 逻辑回归
lr_params/ = {'C': [1e-1, 1e0, 1e1, 1e2], 'penalty': ['l1', 'l2']}
# KNN
knn_params = {'n_neighbors': [1, 3, 5, 7]}
# 决策树
dt_params = {'max_depth': [None, 1, 3, 5, 7]}
# 随机森林
rf_params = {'n_estimators': [10, 50, 100], 'max_depth': [None, 1, 3, 5, 7]}
# model=rf/lr/knn/dt,params=lr_params/knn_params/dt_params/rf_params# 网格搜索
grid = GridSearchCV(model, params, error_score=0.)
grid.fit(X_train, y_train)# 输出最佳参数
print("Best parameters:", grid.best_params_)
总结
- 用途:自动化超参数优化,提升模型性能。
- 核心参数:
model
(模型)、params
(参数网格)、error_score
(容错处理)。 - 输出:通过
best_params_
、best_score_
等获取最佳结果。
适用于任何 Scikit-learn 兼容模型(分类、回归等)。