当前位置: 首页 > news >正文

机器学习knnlearn5

import numpy as np
from os import listdir
from sklearn.neighbors import KNeighborsClassifier as kNN

# 此函数用于将一个32x32的文本文件转换为一个1x1024的一维向量
def img2vector(filename):
    """
    将32x32的文本文件转换为1x1024的向量
    :param filename: 要转换的文本文件的文件名
    :return: 转换后的1x1024向量,如果出现错误则返回None
    """
    try:
        # 初始化一个1x1024的零向量,用于存储转换后的数据
        returnVect = np.zeros((1, 1024))
        # 以只读模式打开指定的文件
        with open(filename) as fr:
            # 遍历文件的前32行,因为图像是32x32的
            for i in range(32):
                # 读取当前行的内容
                lineStr = fr.readline()
                # 遍历当前行的前32个字符
                for j in range(32):
                    # 将当前字符转换为整数,并存储到向量的相应位置
                    returnVect[0, 32 * i + j] = int(lineStr[j])
        # 返回转换后的向量
        return returnVect
    except FileNotFoundError:
        # 若文件未找到,打印错误信息
        print(f"错误:文件 {filename} 未找到。")
        return None
    except Exception as e:
        # 若发生其他未知错误,打印错误信息
        print(f"错误:处理文件 {filename} 时发生未知错误:{e}")
        return None

# 此函数用于加载训练数据,返回训练数据矩阵和对应的标签列表
def load_training_data():
    """
    加载训练数据
    :return: 训练数据矩阵和对应的标签列表,如果出现错误则返回None, None
    """
    # 用于存储训练数据的标签
    hwLabels = []
    try:
        # 获取训练数据文件夹下的所有文件名
        trainingFileList = listdir('trainingDigits')
        # 计算训练数据的数量
        m = len(trainingFileList)
        # 初始化一个m行1024列的零矩阵,用于存储训练数据
        trainingMat = np.zeros((m, 1024))
        # 遍历训练数据文件夹下的所有文件
        for i in range(m):
            # 获取当前文件名
            fileNameStr = trainingFileList[i]
            # 从文件名中提取出对应的数字标签
            classNumber = int(fileNameStr.split('_')[0])
            # 将标签添加到标签列表中
            hwLabels.append(classNumber)
            # 调用img2vector函数将当前文件转换为向量
            vector = img2vector(f'trainingDigits/{fileNameStr}')
            if vector is not None:
                # 将转换后的向量存储到训练数据矩阵的相应行
                trainingMat[i, :] = vector
        # 返回训练数据矩阵和标签列表
        return trainingMat, hwLabels
    except FileNotFoundError:
        # 若训练数据文件夹未找到,打印错误信息
        print("错误:训练数据文件夹未找到。")
        return None, None
    except Exception as e:
        # 若发生其他未知错误,打印错误信息
        print(f"错误:加载训练数据时发生未知错误:{e}")
        return None, None

# 此函数用于加载测试数据,返回测试数据矩阵和对应的标签列表
def load_test_data():
    """
    加载测试数据
    :return: 测试数据矩阵和对应的标签列表,如果出现错误则返回None, None
    """
    try:
        # 获取测试数据文件夹下的所有文件名
        testFileList = listdir('testDigits')
        # 计算测试数据的数量
        mTest = len(testFileList)
        # 初始化一个mTest行1024列的零矩阵,用于存储测试数据
        testMat = np.zeros((mTest, 1024))
        # 用于存储测试数据的标签
        testLabels = []
        # 遍历测试数据文件夹下的所有文件
        for i in range(mTest):
            # 获取当前文件名
            fileNameStr = testFileList[i]
            # 从文件名中提取出对应的数字标签
            classNumber = int(fileNameStr.split('_')[0])
            # 将标签添加到标签列表中
            testLabels.append(classNumber)
            # 调用img2vector函数将当前文件转换为向量
            vector = img2vector(f'testDigits/{fileNameStr}')
            if vector is not None:
                # 将转换后的向量存储到测试数据矩阵的相应行
                testMat[i, :] = vector
        # 返回测试数据矩阵和标签列表
        return testMat, testLabels
    except FileNotFoundError:
        # 若测试数据文件夹未找到,打印错误信息
        print("错误:测试数据文件夹未找到。")
        return None, None
    except Exception as e:
        # 若发生其他未知错误,打印错误信息
        print(f"错误:加载测试数据时发生未知错误:{e}")
        return None, None

# 此函数用于进行手写数字识别测试,打印分类结果和错误率
def handwritingClassTest():
    """
    手写数字识别测试
    """
    # 调用load_training_data函数加载训练数据
    trainingMat, hwLabels = load_training_data()
    if trainingMat is None or hwLabels is None:
        # 若加载训练数据失败,直接返回
        return
    # 创建一个K近邻分类器对象,设置邻居数量为3,算法为自动选择
    neigh = kNN(n_neighbors=3, algorithm='auto')
    # 使用训练数据和标签对分类器进行训练
    neigh.fit(trainingMat, hwLabels)
    # 调用load_test_data函数加载测试数据
    testMat, testLabels = load_test_data()
    if testMat is None or testLabels is None:
        # 若加载测试数据失败,直接返回
        return
    # 初始化错误计数为0
    errorCount = 0.0
    # 计算测试数据的数量
    mTest = len(testLabels)
    # 遍历测试数据
    for i in range(mTest):
        # 使用训练好的分类器对当前测试数据进行预测
        classifierResult = neigh.predict(testMat[i].reshape(1, -1))
        # 打印分类结果和真实标签
        print(f"分类返回结果为 {classifierResult[0]}\t真实结果为 {testLabels[i]}")
        if classifierResult[0] != testLabels[i]:
            # 若分类结果与真实标签不一致,错误计数加1
            errorCount += 1.0
    # 打印错误的数量和错误率
    print(f"总共错了 {int(errorCount)} 个数据\n错误率为 {errorCount / mTest * 100:.2f}%")

# 程序入口,当脚本作为主程序运行时,调用handwritingClassTest函数进行测试
if __name__ == '__main__':
    handwritingClassTest()
# 首先导入鸢尾花数据载入工具
from sklearn.datasets import load_iris
#导入KNN分类模型
from sklearn.neighbors import KNeighborsClassifier
#为了方便可视化,我们再导入matplotlib和seaborn
import matplotlib.pyplot as plt
import seaborn as sns
#加载鸢尾花数据集,赋值给iris变量
iris = load_iris()
#查看数据集的键名
iris.keys()
#查看数据集的特征名称
iris.feature_names
# 查看数据集中的样本分类
iris.target 
#将样本的特征和标签分别赋值给X和y
x, y = iris.data, iris.target 
#查看是否成功
x.shape
#导入数据集拆分工具
from sklearn.model_selection import train_test_split
#将X和y拆分为训练集和验证集
x_train, x_test, y_train, y_test = train_test_split(x,y)
#查看拆分情况
x_train.shape
#创建KNN分类器,参数保持默认设置
knn_clf = KNeighborsClassifier(n_neighbors=6)
#使用训练集拟合模型
knn_clf.fit(x_train,y_train)
#查看模型在训练集和验证集中的准确率
                            
print('训练集准确率:%.2f'%knn_clf.score(x_train, y_train))
print('验证集准确率:%.2f'%knn_clf.score(x_test, y_test))

# 导入网格搜索
from sklearn.model_selection import GridSearchCV
# 定义一个从1到10的n_neighbors
n_neighbors = tuple(range(1,11,1))
# 创建网格搜索示例,estimator 用knn分类器
# 把刚刚定义的n_neighbors 传入param_grid参数
# cv参数指交叉验证次数为5
cv = GridSearchCV(estimator=KNeighborsClassifier(),
                  param_grid = {'n_neighbors':n_neighbors},
                  cv = 5)
# 使用网络搜索你和数据集
cv.fit(x,y)
# 查看最优参数
cv.best_params_






训练集准确率:0.96
验证集准确率:0.95





{'n_neighbors': 6}

相关文章:

  • 硬件面试问题
  • centos7 linux VMware虚拟机新添加的网卡,能看到网卡名称,但是看不到网卡的配置文件
  • UE4学习笔记 FPS游戏制作30 显示击杀信息 水平框 UI模板(预制体)
  • 纯css实现环形进度条+动画加载效果
  • QScreen 捕获屏幕(截图)
  • 智能舵机:AI融合下的自动化新纪元
  • Postman 如何模拟 Request Payload 发送请求?
  • MySQL 性能优化:索引优化与查询优化
  • Scikit-learn全攻略:从入门到工业级应用
  • MQ的数据一致性,如何保证?
  • 网络基础:五层模型
  • 深入理解Spring Data JPA:简化Java持久层开发
  • 探索 curl ipinfo.io:从命令行获取你的网络身份卡!!!
  • 在Git仓库的Readme上增加目录页
  • 【LLM】Llama Factory:Windows部署全流程
  • linux如何查看系统版本
  • WinDbg. From A to Z! 笔记(一)
  • 项目代码第8讲【数据库基础知识】:SQL(DDL、DML、DQL、DCL);函数(聚合、字符串、数值、日期、流程);约束;多表查询;事务
  • 西域平台商品详情接口设计与实现‌
  • 电容式电压互感器在线监测系统
  • 乌美矿产协议文本公布,明确乌收益及协议优先级
  • 近七成科创板公司2024年营收增长,285家营收创历史新高
  • 周劼已任中国航天科技集团有限公司董事、总经理、党组副书记
  • 中国海警位中国黄岩岛领海及周边区域执法巡查
  • 国泰海通合并后首份业绩报告出炉:一季度净利润增逾391%
  • 量子传感新技术“攻克”退相干难题