当前位置: 首页 > news >正文

机器学习之数字识别

这是一个基于MNIST数据集的手写数字识别程序演示。程序使用Keras构建了一个简单的CNN模型,包含两个卷积层和池化层,可自动训练或加载已有模型。通过摄像头实时捕捉画面,在画面中央200×200区域识别数字,并显示识别结果及置信度。程序实现了图像预处理(灰度化、缩放、二值化)和实时预测功能,按q键可退出。该演示展示了从模型训练到实际应用的完整流程,适用于课堂演示数字识别的基本原理。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ----------------------------------------------------------------------------
# @File     : demo3
# @Time     : 2025/10/20 17:22
# @Author   : CWB
# @Desc     : 
# ----------------------------------------------------------------------------
"""
这里写文件描述...
"""
import cv2
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical# 1. 加载或训练模型(这里直接用 MNIST 训练一个简单 CNN)
def load_or_train_model():try:model = load_model("digit_model.h5")print("✅ 加载已有模型")except:print("🔄 训练新模型...")(x_train, y_train), (x_test, y_test) = mnist.load_data()x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255.0x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255.0y_train = to_categorical(y_train, 10)y_test = to_categorical(y_test, 10)from tensorflow.keras.models import Sequentialfrom tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Densemodel = Sequential([Conv2D(32, (3,3), activation='relu', input_shape=(28,28,1)),MaxPooling2D((2,2)),Conv2D(64, (3,3), activation='relu'),MaxPooling2D((2,2)),Flatten(),Dense(64, activation='relu'),Dense(10, activation='softmax')])model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])model.fit(x_train, y_train, epochs=3, validation_data=(x_test, y_test))model.save("digit_model.h5")return model# 2. 预处理图像:裁剪、缩放、二值化
def preprocess_roi(roi):roi_gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)roi_resized = cv2.resize(roi_gray, (28, 28), interpolation=cv2.INTER_AREA)roi_blur = cv2.GaussianBlur(roi_resized, (5, 5), 0)_, roi_thresh = cv2.threshold(roi_blur, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)roi_normalized = roi_thresh / 255.0return roi_normalized.reshape(1, 28, 28, 1)# 3. 主函数:打开摄像头,识别数字
def main():model = load_or_train_model()cap = cv2.VideoCapture(0)print("📷 摄像头已打开,按 'q' 退出")while True:ret, frame = cap.read()if not ret:break# 定义 ROI 区域(中央 200x200)x, y, w, h = 100, 100, 200, 200cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)roi = frame[y:y+h, x:x+w]# 预处理processed = preprocess_roi(roi)prediction = model.predict(processed)digit = np.argmax(prediction)confidence = np.max(prediction)# 显示结果cv2.putText(frame, f"Digit: {digit} ({confidence:.2f})", (x, y-10),cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)cv2.imshow("Digit Recognition", frame)if cv2.waitKey(1) & 0xFF == ord('q'):breakcap.release()cv2.destroyAllWindows()if __name__ == "__main__":main()

应老师要求,demo一个数字识别的程序,用的是公共数据集,通过摄像头获取并打印数字

http://www.dtcms.com/a/520149.html

相关文章:

  • 网站开发群安阳网站设计多少钱
  • 7. Prometheus告警配置-alertmanger
  • 自动签到之实现掘金模拟签到
  • 【探寻C++之旅】C++11 深度解析:重塑现代 C++ 的关键特性
  • 【unity】运行时加载并修改ScriptableObject类型资源对象的值会怎样
  • Spring Boot 实现 DOCX 转 PDF(基于 docx4j 的轻量级开源方案)
  • 服装企业官方网站建设网站运营收入
  • Spring Boot Actuator深度解析与实战
  • 如何做 行业社交类网站网站 建设在作用
  • 线程3 JavaEE(阻塞队列,线程池)
  • K8s中,deployment 是如何从 yaml 文件最终部署成功 pod 的
  • RK3588 使用 FFmpeg 硬件解码输出到 DRM Prime (DMA Buf) 加速数据传输
  • 基于蚁群算法的PID参数整定方法及MATLAB实现
  • 排序算法大全——插入排序
  • 手搓一个CUDA JIT编译器
  • 网站引导页模板互联网公司排名全球
  • JDK 9 List.of(...)
  • 做一个vue3 v-model 双向绑定的弹窗
  • 为超过10亿条记录的订单表新增字段
  • 哪里做网站最便宜WordPress功能模块排版
  • 每日算法刷题Day78:10.23:leetcode 一般树7道题,用时1h30min
  • 薄膜测厚选CWL法还是触针法?针对不同厚度与材质的台阶仪技术选型指南
  • WPF-MVVM的简单入门(第一个MVVM程序)
  • blender拓扑建模教程
  • asp.net手机网站开发教程翻译网站建设方案
  • 佛山建设网站公司哪家好特斯拉ceo进厂拧螺丝
  • 如何做新网站保留域名wordpress基础
  • C# 实现 Modbus TCP 通信
  • 《Git:从入门到精通(七)——Git分支管理与协作开发实战》
  • 超越传统工具:利用Reddit发现关键词的独特视角与前沿方法