当前位置: 首页 > news >正文

Scikit-learn 简单介绍入门和常用API汇总

Scikit-learn 简单介绍和入门示例

1. 概述

Scikit-learn(简称 sklearn)是 Python 生态中最流行的 机器学习库,主要用于传统 ML 任务。它基于 NumPy、SciPy 和 Matplotlib,提供了统一的 API,涵盖 数据预处理、特征工程、模型训练、评估与调优

定位:

  • 适合:中小规模数据,原型验证,教学与科研
  • 不适合:海量数据分布式计算(推荐 Spark ML、Dask ML)

2. 设计思想

Scikit-learn 遵循 模块化、统一接口、组合化 的原则。
主要接口规范:

  • fit(X, y=None):训练模型 / 学习参数
  • predict(X):预测
  • transform(X):数据变换(特征工程、降维)
  • fit_transform(X):训练并变换(常用于预处理)
  • score(X, y):评估模型
  • get_params() / set_params():超参数管理

统一接口的好处:无论是 SVM、决策树还是 PCA,调用方式都基本相同。


3. 模块全景

3.1 数据预处理

  • 标准化与归一化StandardScaler, MinMaxScaler
  • 特征选择SelectKBest, RFE
  • 降维PCA, TruncatedSVD, TSNE

3.2 监督学习

  • 分类

    • 线性:LogisticRegression, SGDClassifier
    • 树模型:DecisionTreeClassifier, RandomForestClassifier, GradientBoostingClassifier
    • SVM:SVC
    • 朴素贝叶斯:GaussianNB
  • 回归

    • 线性:LinearRegression, Ridge, Lasso
    • 树模型:DecisionTreeRegressor, RandomForestRegressor
    • 支持向量回归:SVR

3.3 非监督学习

  • 聚类KMeans, DBSCAN, AgglomerativeClustering, GaussianMixture
  • 降维PCA, NMF

3.4 模型选择与评估

  • 交叉验证cross_val_score, KFold

  • 调参GridSearchCV, RandomizedSearchCV

  • 指标

    • 分类:准确率、精确率、召回率、F1、ROC-AUC
    • 回归:均方误差 (MSE)、R²

3.5 工程工具

  • Pipeline:将预处理和模型打包成流水线
  • Joblib:模型保存与加载
  • 并行计算:内置多核并行支持

4. 工作流程

一个典型的 Scikit-learn 项目通常包括:

  1. 数据准备:加载、清洗、划分(train_test_split
  2. 预处理:标准化、特征选择、降维
  3. 建模:选择分类/回归/聚类算法
  4. 评估:使用交叉验证与指标函数
  5. 调参GridSearchCVRandomizedSearchCV
  6. 部署:模型持久化 (joblib.dump/load)

5. 简单示例

鸢尾花分类 为例:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score# 数据加载
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 构建流水线:标准化 + SVM 分类器
pipeline = Pipeline([("scaler", StandardScaler()),("clf", SVC(kernel="linear"))
])# 训练与预测
pipeline.fit(X_train, y_train)
y_pred = pipeline.predict(X_test)# 评估
print("Accuracy:", accuracy_score(y_test, y_pred))

6. 优势与不足

优势

  • 统一 API,学习成本低
  • 覆盖大多数传统 ML 算法
  • 性能不错,底层部分用 Cython/C++ 优化
  • 文档和社区生态完善

不足

  • 不支持大规模分布式数据
  • 不包含深度学习模型
  • 在线学习能力有限(部分算法支持 partial_fit

7. 适用场景

  • 原型验证:快速测试不同算法
  • 教学与科研:直观 API 适合学习
  • 工程应用:中小规模数据的分类/回归/聚类
  • 特征工程 + 调参:配合 Pandas/Numpy 使用

8.入门示例yanz演示

阶段一:入门(快速上手)

目标:掌握 Scikit-learn 的基本 API,能完成简单的分类/回归任务。

学习要点
  • 熟悉 fit / predict / transform 接口
  • 使用 train_test_split 划分数据
  • 调用常见模型:LinearRegression, LogisticRegression, SVC, KNeighborsClassifier
  • 使用 accuracy_scoremean_squared_error 等指标
综合示例:鸢尾花分类
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score# 数据集
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 模型
clf = LogisticRegression(max_iter=200)
clf.fit(X_train, y_train)# 预测与评估
y_pred = clf.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))

阶段二:提升(实战技巧)

目标:掌握 Pipeline、特征工程、模型调参,能在真实数据集上完成较复杂的任务。

学习要点
  • Pipeline:整合预处理与建模
  • 特征工程StandardScaler, OneHotEncoder, PCA
  • 交叉验证cross_val_score
  • 超参数调优GridSearchCV, RandomizedSearchCV
  • 模型选择:比较不同模型性能
综合示例:房价预测(加州房价数据集)
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import Pipeline
from sklearn.metrics import mean_squared_error# 数据集
X, y = fetch_california_housing(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 流水线:标准化 + 随机森林
pipe = Pipeline([("scaler", StandardScaler()),("rf", RandomForestRegressor(random_state=42))
])# 超参数搜索
param_grid = {"rf__n_estimators": [50, 100],"rf__max_depth": [10, 20, None]
}
grid = GridSearchCV(pipe, param_grid, cv=3, scoring="neg_mean_squared_error")
grid.fit(X_train, y_train)# 预测与评估
y_pred = grid.predict(X_test)
print("Best params:", grid.best_params_)
print("MSE:", mean_squared_error(y_test, y_pred))

阶段三:高级(综合应用)

目标:能 系统性构建机器学习项目,包括数据预处理、特征选择、模型集成、结果可视化与解释。

学习要点
  • 特征选择SelectKBest, RFE
  • 集成学习VotingClassifier, StackingClassifier
  • 概率估计与模型校准CalibratedClassifierCV
  • 模型解释permutation_importance, PartialDependenceDisplay
  • 保存与部署joblib.dump, joblib.load
综合示例:信用卡客户违约预测(分类任务)
import joblib
import pandas as pd
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report# 假设已加载信用卡客户数据 (X: 特征, y: 是否违约)
data = pd.read_csv("credit_card.csv")
X = data.drop("default", axis=1)
y = data["default"]X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)# 流水线:预处理 + 特征选择 + 集成学习
pipe = Pipeline([("scaler", StandardScaler()),("select", SelectKBest(score_func=f_classif, k=10)),("clf", VotingClassifier(estimators=[("rf", RandomForestClassifier(random_state=42)),("gb", GradientBoostingClassifier(random_state=42))], voting="soft"))
])# 随机搜索调参
param_dist = {"clf__rf__n_estimators": [100, 200],"clf__gb__learning_rate": [0.05, 0.1]
}
search = RandomizedSearchCV(pipe, param_dist, cv=3, scoring="f1", n_iter=4, random_state=42)
search.fit(X_train, y_train)# 预测与评估
y_pred = search.predict(X_test)
print("Best params:", search.best_params_)
print(classification_report(y_test, y_pred))# 模型保存
joblib.dump(search.best_estimator_, "credit_model.pkl")

总结

  • 入门:掌握 API + 简单模型(Logistic/SVC/LinearRegression)
  • 提升:学会 Pipeline、特征工程、调参(GridSearchCV/RandomizedSearchCV)
  • 高级:综合应用,能做完整 ML 项目(特征选择 + 集成学习 + 模型解释 + 部署)

Scikit-learn 常用 API 汇总

1️、数据集工具

from sklearn import datasets
from sklearn.model_selection import train_test_split
  • datasets.load_iris():鸢尾花分类
  • datasets.load_digits():手写数字识别
  • datasets.fetch_california_housing():加州房价数据
  • datasets.make_classification():生成分类数据
  • datasets.make_regression():生成回归数据
  • train_test_split(X, y, test_size=0.2, random_state=42):划分训练/测试集

2️、数据预处理

from sklearn.preprocessing import StandardScaler, MinMaxScaler, Normalizer
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
  • 标准化StandardScaler().fit_transform(X)
  • 归一化MinMaxScaler().fit_transform(X)
  • 正则化Normalizer().fit_transform(X)
  • 独热编码OneHotEncoder().fit_transform(X)
  • 标签编码LabelEncoder().fit_transform(y)

3️、特征工程

from sklearn.feature_selection import SelectKBest, f_classif, RFE
from sklearn.decomposition import PCA
  • 特征选择SelectKBest(score_func=f_classif, k=10).fit_transform(X, y)
  • 递归特征消除RFE(estimator, n_features_to_select=5).fit_transform(X, y)
  • 主成分分析PCA(n_components=2).fit_transform(X)

4️、常用模型

分类

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
  • 逻辑回归:LogisticRegression()
  • 支持向量机:SVC(kernel="linear")
  • 决策树:DecisionTreeClassifier()
  • 随机森林:RandomForestClassifier(n_estimators=100)
  • 梯度提升:GradientBoostingClassifier()
  • 朴素贝叶斯:GaussianNB()
  • 最近邻:KNeighborsClassifier(n_neighbors=5)

回归

from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor
  • 线性回归:LinearRegression()
  • 岭回归:Ridge(alpha=1.0)
  • Lasso 回归:Lasso(alpha=0.1)
  • 支持向量回归:SVR(kernel="rbf")
  • 随机森林回归:RandomForestRegressor(n_estimators=100)

聚类 / 非监督学习

from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
from sklearn.mixture import GaussianMixture
  • KMeans 聚类:KMeans(n_clusters=3)
  • DBSCAN:DBSCAN(eps=0.5, min_samples=5)
  • 层次聚类:AgglomerativeClustering(n_clusters=3)
  • 高斯混合模型:GaussianMixture(n_components=3)

5️、模型评估

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import mean_squared_error, r2_score
  • 分类指标

    • accuracy_score(y_true, y_pred)
    • precision_score(y_true, y_pred)
    • recall_score(y_true, y_pred)
    • f1_score(y_true, y_pred)
    • roc_auc_score(y_true, y_prob)
    • classification_report(y_true, y_pred)
  • 回归指标

    • mean_squared_error(y_true, y_pred)
    • r2_score(y_true, y_pred)
  • 混淆矩阵confusion_matrix(y_true, y_pred)


6️、模型选择与调参

from sklearn.model_selection import cross_val_score, KFold, StratifiedKFold
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
  • 交叉验证:cross_val_score(model, X, y, cv=5)

  • K 折:KFold(n_splits=5)

  • 分层 K 折:StratifiedKFold(n_splits=5)

  • 网格搜索:

    GridSearchCV(estimator, param_grid, cv=3, scoring="accuracy")
    
  • 随机搜索:

    RandomizedSearchCV(estimator, param_distributions, cv=3, n_iter=10)
    

7️、工程工具

from sklearn.pipeline import Pipeline
import joblib
  • 流水线

    pipe = Pipeline([("scaler", StandardScaler()),("clf", SVC())
    ])
    
  • 模型保存joblib.dump(model, "model.pkl")

  • 模型加载model = joblib.load("model.pkl")



文章转载自:

http://JAaCzzWT.nnqrb.cn
http://1MgDFZJ0.nnqrb.cn
http://COYF8aMC.nnqrb.cn
http://X2oE3NLu.nnqrb.cn
http://WRAMN6na.nnqrb.cn
http://loG94emy.nnqrb.cn
http://SCDO4uCa.nnqrb.cn
http://ONpQwcd2.nnqrb.cn
http://VgB7pHB1.nnqrb.cn
http://3HKCQ2h0.nnqrb.cn
http://CAF10xcH.nnqrb.cn
http://u8pjDIxc.nnqrb.cn
http://7uoYHjUF.nnqrb.cn
http://tH8lxI64.nnqrb.cn
http://gfJ7IDUL.nnqrb.cn
http://CntZUpUy.nnqrb.cn
http://d75sSLmp.nnqrb.cn
http://XWmvOKi1.nnqrb.cn
http://2NfTZgO0.nnqrb.cn
http://HSHTC5bT.nnqrb.cn
http://S8UAaTd0.nnqrb.cn
http://tSRksK67.nnqrb.cn
http://RqPtvjax.nnqrb.cn
http://4cOt5MNH.nnqrb.cn
http://6920vaSJ.nnqrb.cn
http://ZeahBG0F.nnqrb.cn
http://ZTTKZ6oe.nnqrb.cn
http://jvz4rUz0.nnqrb.cn
http://FDDpVmEc.nnqrb.cn
http://9TzjVd9F.nnqrb.cn
http://www.dtcms.com/a/384405.html

相关文章:

  • [Dify] 用多个工具节点构建多轮 API 调用任务流:链式任务设计实战指南
  • Java实战:从零开发图书管理系统
  • 认知语义学中的隐喻对人工智能自然语言处理的深层语义分析的启示与影响研究报告
  • Mysql数据库事务全解析:概念、操作与隔离级别
  • Halcon 常用算子
  • 基于Spring Boot与Micrometer的系统参数监控指南
  • 【高并发内存池——项目】定长内存池——开胃小菜
  • 作为注册中心zk和nacos如何选型
  • 前置配置3:nacos 配置中心
  • Linux —— 进程的程序替换[进程控制]
  • [Linux] 从YT8531SH出发看Linux网络PHY驱动
  • ArcGIS定向影像(2)——非传统影像轻量级解决方案
  • 分享机械键盘MCU解决方案
  • Unity 性能优化 之 编辑器创建资源优化(UGUI | 物理 | 动画)
  • PostgreSQL——分区表
  • Elastic APM 高级特性:分布式追踪与机器学习优化
  • Ubuntu 服务器配置转发网络访问
  • Redis 数据结构源码剖析(SDS、Dict、Skiplist、Quicklist、Ziplist)
  • C#通讯之网络通讯 TCP UDP
  • 响应时间从5ms到0.8ms:威迈斯AI+DSP协同架构的突破与工程实践
  • 《WINDOWS 环境下32位汇编语言程序设计》第16章 WinSock接口和网络编程(2)
  • 算法--插入排序
  • 领码方案|权限即数据:企业系统中的字段级访问控制架构实战(Ver=1.0)
  • 【面试场景题】支付金融系统与普通业务系统的一些技术和架构上的区别
  • 数证杯顺心借JAVA网站重构详细版(服务器取证基础考点+检材+题目+重构视频)
  • 【Unity】【Photon】Fusion2中的玩家输入系统 学习笔记
  • Vue3 + Three.js 实战:自定义 3D 模型加载与交互全流程
  • 【Leetcode hot 100】102.二叉树的层序遍历
  • [Windows] 微软 .Net 运行库离线安装包 | Microsoft .Net Packages AIO_v09.09.25
  • java通过RESTful API实现两个项目之间相互传输数据