当前位置: 首页 > news >正文

Kaggle-Digit Recognizer-(多分类+卷积神经网络CNN)

Digit Recognizer

题意:

给你每个图片的dataframe类型的数据,让你预测出每个图片可能是多少。

思考:

数据处理

1.首先把数据从dadaframe转换成numpy,数据类型改为float32,并且并且展开为1维的28×28×1的形状,也就是28宽28高灰色通道。并且都要/255,因为灰度值是0-255,把灰度值压缩成0-1。
2.把train的image和label分开,然后split成训练集合和验证集合。

建立模型:

Sequential(),构建顺序模型,选择按层顺序堆叠。
添加卷积层,32个卷积核,内核为3×3,线性处理。
添加池化层,池化窗口2×2,步长为2。
添加卷积层,32个卷积核,内核为3×3,线性处理。
添加池化层,池化窗口2×2,步长为2。

Flatten(),将二维特征图展平为一维向量。
128个全神经元,relu激活函数。
输出层10个神经元对应10个答案,softmax激活函数进行分类。

代码:
import sys
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense


if __name__ == "__main__":
    data_train = pd.read_csv("train.csv")
    data_test = pd.read_csv("test.csv")
    X = data_train.drop(['label'],axis=1).values.reshape(-1, 28, 28, 1).astype('float32')
    Y = data_train['label'].values
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=2)
    X_train /= 255.0
    X_test /= 255.0
    data_test = data_test.values.reshape(-1, 28, 28, 1).astype('float32')/255.0

    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3), padding='valid', activation='relu', input_shape=(28, 28, 1)))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=2, padding='valid'))
    model.add(Conv2D(32, kernel_size=(3, 3), padding='valid', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=2, padding='valid'))

    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(10, activation='softmax'))

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy']
    )
    history = model.fit(X_train, Y_train, epochs=10, validation_split=0.2)
    y_pred =model.predict(data_test).argmax(axis=1)
    anw = pd.DataFrame({"ImageId": range(1,len(y_pred)+1), "Label": y_pred})
    anw.to_csv("Submission.csv", index=False)

相关文章:

  • 集成学习+泰坦尼克号案例+红酒品质预测
  • pipe匿名管道实操(Linux)
  • SpringBoot集成Ollama本地模型
  • AllData数据中台升级发布 | 支持K8S数据平台2.0版本
  • 系统变量和用户变量的区别是什么
  • Android WiFi获取动态IP地址
  • python函数的定义与使用
  • Docker Harbor
  • 连表查询的时候,子查询的条件应该写到子查询里面,不能放到外面
  • 大模型在网络安全领域的七大应用
  • qml之锚点Anchors
  • Google Cloud Next‘25大会 Gemini 支持 Anthropic MCP 协议及推出 A2A 协议剑指医疗AI情况分析
  • QBitmap、QPixmap、QImage 和 QPicture 使用方法和特点以及转换
  • Windows10 ssh无输出 sshd服务启动失败 1067报错 公钥无法认证链接 解决办法
  • Android 中绕过hwbinder 实现跨模块对audio 的HAL调用
  • Java面试黄金宝典45
  • POSIX线程(pthread)库:线程的终止与管理
  • C#异步方法返回Task<T>的同步调用
  • LLM相关代码笔记
  • 【Docker基础】容器技术详解:生命周期、命令与实战案例