当前位置: 首页 > news >正文

机器学习——KNN数据集划分

一、主要函数

sklearn.datasets.my_train_test_split()

该函数为Scikit-learn 中用于将数据集划分为训练集和测试集的函数,适用于机器学习模型的训练和验证。以下是详细解释:


​1、函数签名

train_test_split(
    *arrays,                  # 输入的数据集(可以是多个数组,如 X, y)
    test_size=None,           # 测试集的比例或样本数
    train_size=None,          # 训练集的比例或样本数
    random_state=None,        # 随机种子(确保结果可复现)
    shuffle=True,             # 是否打乱数据顺序
    stratify=None             # 是否分层抽样(保持类别比例)
)
返回值
  • 分割后的数据集,例如 X_train, X_test, y_train, y_test(顺序与输入一致)。

​2、参数详解

1. ​***arrays**​(必填)
  • 输入的数据集,可以是多个数组(如特征矩阵 X 和标签 y),支持同时拆分多个数据集。
  • 例如:X_train, X_test, y_train, y_test = train_test_split(X, y)
2. ​**test_size**​(默认 None
  • 浮点数:表示测试集占总数据的比例(如 test_size=0.2 表示 20% 作为测试集)。
  • 整数:表示测试集的绝对样本数(如 test_size=100)。
  • 如果 test_size 和 train_size 均为 None,默认 test_size=0.25
3. ​**train_size**​(默认 None
  • 类似 test_size,但指定训练集的比例或样本数。
  • 通常只需指定 test_size 或 train_size 中的一个。
4. ​**random_state**​(默认 None
  • 随机种子,保证每次分割结果一致。
  • 例如:random_state=42 使结果可复现。
5. ​**shuffle**​(默认 True
  • 是否在分割前打乱数据顺序。
  • 时间序列数据 需设置为 shuffle=False,避免破坏时间依赖性。
6. ​**stratify**​(默认 None
  • 指定分层抽样的参考标签,保持训练集和测试集的类别分布与原始数据一致。
  • 适用于分类任务中类别不平衡的数据。
  • 例如:stratify=y 会根据 y 的类别比例分割数据。

二、手动实现数据集划分

1.按比例计算

import numpy as np
import matplotlib.pyplot as plt #绘图模块
from sklearn.datasets import make_blobs  #聚类划分模块

#300样本,2个标签,3个聚类
x, y = make_blobs(
    n_samples = 300,
    n_features = 2,
    centers = 3,
    cluster_std = 1,
    center_box = (-10, 10),
    random_state = 233,
    return_centers = False
)

#打印显示
plt.scatter(x[:,0], x[:,1], c = y,s = 15)
plt.show()

#设置随机种子,将x打乱,并返回索引值
np.random.seed(233)
shuffle = np.random.permutation(len(x))

#设置划分比例
train_size=0.7  
test_size =0.3

#得到对应比例的数据索引
train_index = shuffle[:int(len(x) * train_size)]
test_index = shuffle[int(len(x) * train_size):]

#得到数据集
x[train_index].shape, y[train_index].shape #结果:((210, 2), (210,))
x[test_index].shape, y[test_index].shape  #结果:((90, 2), (90,))

2、上述过程封装

import numpy as np
from matplotlib import pyplot as plt

#函数封装
def my_train_test_split(x, y, train_size = 0.7, random_state = None):
    if random_state:
        np.random.seed(random_state)
    shuffle = np.random.permutation(len(x))
    train_index = shuffle[:int(len(x) * train_size)]
    test_index = shuffle[int(len(x) * train_size):]
    return x[train_index], x[test_index], y[train_index], y[test_index]

#调用
x_train, x_test, y_train, y_test = my_train_test_split(x, y, train_size = 0.7, random_state = 233)
x_train.shape, x_test.shape, y_train.shape, y_test.shape #结果:((210, 2), (90, 2), (210,), (90,))

#显示结果(训练集)
plt.scatter(x_train[:, 0], x_train[:, 1], c = y_train, s = 15)
plt.show()
#显示结果(测试集)
plt.scatter(x_test[:, 0], x_test[:, 1], c = y_test, s = 15)
plt.show()

三、KNN方法实现数据集快速划分

from sklearn.model_selection import train_test_split
from sklearn.datasets import make_blobs

#制作数据集
x, y = make_blobs(
    n_samples = 300,
    n_features = 2,
    centers = 3,
    cluster_std = 1,
    center_box = (-10, 10),
    random_state = 233,
    return_centers = False
)


#调用
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size = 0.7, random_state = 233)

#显示形状
x_train.shape, x_test.shape, y_train.shape, y_test.shape #结果:((210, 2), (90, 2), (210,), (90,))

#统计y_test标签
from collections import Counter
Counter(y_test) #结果:Counter({2: 34, 0: 25, 1: 31}),发现标签并不均匀

#加stratify = y限制,使其划分和标签类型一样
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size = 0.7, random_state = 233, stratify = y) 
print(Counter(y_test))  # 结果:Counter({2: 30, 0: 30, 1: 30})
print(Counter(y_train)) # 结果:Counter({0: 70, 2: 70, 1: 70})


相关文章:

  • 深度学习1—Python基础
  • 「一起学后端」Nest.js + MySQL 查询方法教学文档
  • Docker Compose 常用命令详解
  • Cursor平替免费软件开发工具使用感受和推荐
  • vim的一般操作(分屏操作) 和 Makefile 和 gdb
  • 从零到一开发一款 DeepSeek 聊天机器人
  • 【支持二次开发】基于YOLO系列的车辆行人检测 | 含完整源码、数据集、环境配置和训练教程
  • 程序算法基础
  • 思源配置阿里云 OSS 踩坑记
  • 寻找左边第一个更小值
  • RAG(Retrieval-Augmented Generation)基建之PDF解析的“魔法”与“陷阱”
  • 感知识别算法Jetson环境部署测试记录
  • 【AVRCP】深度剖析 AVRCP 中 Generic Access Profile 的要求与应用
  • RHCE 使用nginx搭建网站
  • Linux进程信号(下:补充)
  • 分布式任务调度框架XXl-job
  • 蓝桥杯备考:二分答案之路标设置
  • 大模型-提示词工程与架构
  • RK3588开发笔记-RTL8852wifi6模块驱动编译报错解决
  • Linux操作系统7- 线程同步与互斥4(基于POSIX条件变量的生产者消费者模型)
  • 中行一季度净赚超543亿降2.9%,利息净收入降逾4%
  • 我国将开展市场准入壁垒清理整治行动
  • 众信旅游:去年盈利1.06亿元,同比增长228.18%
  • 中共中央、国务院关于表彰全国劳动模范和先进工作者的决定
  • 庆祝中华全国总工会成立100周年暨全国劳动模范和先进工作者表彰大会隆重举行,习近平发表重要讲话
  • 人社部:就业政策储备充足,将会根据形势变化及时推出