当前位置: 首页 > news >正文

2. 手写数字预测 gui版

2. 手写数字预测 gui版

  • 背景
  • 1.界面绘制
  • 2.处理图片
  • 3. 加载模型
  • 4. 预测
  • 5.结果
  • 6.一点小问题

在这里插入图片描述

背景

做了手写数字预测的模型,但是老是跑模型太无聊了,就配合pyqt做了一个可视化界面出来玩一下

源代码可以去这里https://github.com/Leezed525/pytorch_toy拿

1.界面绘制

在这里插入图片描述

整个页面布局逻辑很简单,搭建一下就好了

class MainWindow(QMainWindow):def __init__(self):super().__init__()self.net = self.get_net()  # 获取数字预测模型self.setWindowTitle("PyQt 数字预测")self.setGeometry(100, 100, 500, 550)  # 设置主窗口的初始位置和大小,留出空间给按钮self.setFixedSize(500, 550)self.setWindowFlags(self.windowFlags() & ~Qt.WindowType.WindowMaximizeButtonHint)central_widget = QWidget()  # 创建一个中央 QWidgetself.setCentralWidget(central_widget)  # 设置中央 QWidget 为主窗口的中心部件layout = QVBoxLayout(central_widget)  # 为中央 QWidget 创建一个垂直布局# 创建一个水平布局operation_layer = QHBoxLayout()  # 创建一个水平布局用于放置操作区域left_operation_layer = QVBoxLayout()right_operation_layer = QVBoxLayout()self.canvas = DrawingCanvas(self)  # 创建 DrawingCanvas 实例canvas_label = QLabel("请在此处绘制数字")  # 创建一个标签,提示用户在画布上绘制数字canvas_label.setAlignment(Qt.AlignmentFlag.AlignCenter)canvas_label.setStyleSheet("font-size: 20px;")  # 设置标签的样式left_operation_layer.addWidget(canvas_label)  # 将标签添加到左侧操作区域布局中left_operation_layer.addWidget(self.canvas)left_operation_layer.setStretch(0, 1)left_operation_layer.setStretch(1, 10)  # 设置画布的伸缩比例,使其占据更多空间operation_layer.addLayout(left_operation_layer)  # 将左侧操作区域布局添加到操作层布局中# 右侧操作区域self.predict_label = QLabel("预测结果: ")  # 创建一个标签,显示预测结果right_operation_layer.addWidget(self.predict_label)self.predict_digit_labels = []for i in range(10):predict_digit_label = QLabel(f"数字 {i}: 0.00%")  # 创建标签显示每个数字的预测概率self.predict_digit_labels.append(predict_digit_label)  # 将标签添加到列表中for label in self.predict_digit_labels:right_operation_layer.addWidget(label)operation_layer.addLayout(right_operation_layer)  # 将右侧操作区域布局添加到操作层布局中operation_layer.setStretch(0, 10)operation_layer.setStretch(1, 1)layout.addLayout(operation_layer)  # 将操作层布局添加到主布局中# 按钮区布局button_layout = QHBoxLayout()  # 创建一个垂直布局用于放置按钮clear_button = QPushButton("清空画布")  # 清空画布按钮clear_button.clicked.connect(self.canvas.clear_canvas)  # 连接按钮的点击信号到清空画布方法predict_button = QPushButton("预测")  # 清空画布按钮predict_button.clicked.connect(self.predict)  # 连接按钮的点击信号到预测方法button_layout.addStretch(6)button_layout.addWidget(clear_button)button_layout.addWidget(predict_button)layout.addLayout(button_layout)  # 将按钮布局添加到主布局中

其中稍微有点心智压力的区域就是画图区域,这里配合ai然后再自行修改一下就好了,逻辑就是鼠标按住然后绘制,松开后停止绘制。

canvas代码

class DrawingCanvas(QWidget):"""一个自定义的 QWidget 类,用作绘图画布。用户可以在此画布上用鼠标点击并拖动来绘制线条。"""def __init__(self, parent=None):super().__init__(parent)  # 调用父类 QWidget 的构造函数self.setWindowTitle("绘图画布")  # 设置窗口标题self.setGeometry(100, 100, 280, 280)  # 设置窗口的初始位置和大小 (x, y, width, height)self.setMinimumSize(280, 280)# 创建一个 QImage 对象作为绘图缓冲区# 所有的绘图操作都在这个 QImage 上进行,然后整体绘制到屏幕,可以避免闪烁。# QImage.Format.Format_RGB32 是 PyQt6 中推荐的 RGBA 格式,支持透明度。self.image = QImage(self.size(), QImage.Format.Format_RGB32)# 将 QImage 填充为白色。self.image.fill(Qt.GlobalColor.white)self.drawing = False  # 一个布尔标志,指示当前是否正在进行鼠标拖拽绘图self.last_point = QPoint()  # 存储鼠标上次的位置,用于绘制连续的线条# 同样,颜色常量需要通过 Qt.GlobalColor 访问。self.pen_color = Qt.GlobalColor.blackself.pen_size = 20def paintEvent(self, event):"""绘制事件处理函数。每当窗口需要被重新绘制时(例如,首次显示、窗口大小改变、调用 update() 时),Qt 就会自动调用这个方法。"""painter = QPainter(self)  # 创建一个 QPainter 对象,指定在当前 QWidget (self) 上进行绘制# 将 self.image (绘图缓冲区) 的内容绘制到当前 QWidget 的整个矩形区域内。painter.drawImage(self.rect(), self.image, self.image.rect())def mousePressEvent(self, event):# 检查是否是鼠标左键被按下。if event.button() == Qt.MouseButton.LeftButton:self.drawing = True  # 设置绘图标志为 Trueself.last_point = event.pos()  # 记录当前鼠标位置作为线条的起始点def mouseMoveEvent(self, event):"""鼠标移动事件处理函数。当鼠标在窗口内移动时触发。"""# 只有当正在绘图 (self.drawing 为 True) 并且鼠标左键被按住时才执行绘图操作。# event.buttons() 返回当前按下的所有鼠标按钮的位掩码,Qt.MouseButton.LeftButton 用于检查左键是否按下。if self.drawing and event.buttons() & Qt.MouseButton.LeftButton:painter = QPainter(self.image)  # 在 QImage (绘图缓冲区) 上创建 QPainter 进行绘制# 设置画笔的颜色、粗细和样式。painter.setPen(QPen(QColor(self.pen_color), self.pen_size,Qt.PenStyle.SolidLine, Qt.PenCapStyle.RoundCap, Qt.PenJoinStyle.RoundJoin))# 绘制从上次记录的点到当前鼠标位置的直线painter.drawLine(self.last_point, event.pos())self.last_point = event.pos()  # 更新 last_point 为当前鼠标位置,为下一次绘制做准备self.update()  # 请求窗口重绘。这会间接调用 paintEvent,将 QImage 的最新内容显示到屏幕上。def mouseReleaseEvent(self, event):"""鼠标释放事件处理函数。当用户释放鼠标按钮时触发。"""# 检查是否是鼠标左键被释放。if event.button() == Qt.MouseButton.LeftButton:self.drawing = False  # 停止绘图def resizeEvent(self, event):"""窗口大小改变事件处理函数。当窗口大小改变时触发。"""# 如果新窗口的宽度或高度大于当前 QImage 的尺寸,则需要创建一个新的 QImage。if self.width() > self.image.width() or self.height() > self.image.height():new_image = QImage(self.size(), QImage.Format.Format_RGB32)# 填充新图像为白色new_image.fill(Qt.GlobalColor.white)painter = QPainter(new_image)# 将旧图像的内容绘制到新图像上,以保留已有的绘图。painter.drawImage(QPoint(0, 0), self.image)self.image = new_image  # 更新 self.image 为新的 QImageself.update()  # 请求重绘窗口def clear_canvas(self):"""清空画布内容,将整个 QImage 重新填充为白色。"""self.image.fill(Qt.GlobalColor.white)self.update()  # 请求重绘以显示空白画布def set_pen_size(self, size):"""设置画笔粗细。"""self.pen_size = size

2.处理图片

当布局完成后就只需要处理将图片变成输入的过程就好了,先给代码,在讲解

    def get_image(self):"""获取当前画布上的图像数据。返回一个 QImage 对象,包含当前画布的绘图内容。"""image = self.canvas.image# 将图像缩放到 28x28 像素并转换为灰度图scaled_image = image.scaled(28, 28,Qt.AspectRatioMode.IgnoreAspectRatio,  # 不保持宽高比Qt.TransformationMode.SmoothTransformation  # 平滑缩放)# 转换为 8 位灰度图grayscale_image = scaled_image.convertToFormat(QImage.Format.Format_Grayscale8)# 使用 qimage2ndarray.byte_view() 获取 NumPy 数组arr_3d = qimage2ndarray.byte_view(grayscale_image)arr = arr_3d.squeeze()# 将 NumPy 数组转换为 PyTorch 张量tensor_image = torch.from_numpy(arr).float()# --- 关键修正:添加颜色反转和标准化 ---# 1. 将像素值从 [0, 255] 归一化到 [0.0, 1.0]tensor_image = tensor_image / 255.0# 2. 颜色反转:如果你的模型是基于白色数字黑色背景训练的 而画布是黑色数字白色背景,则需要反转颜色tensor_image = 1.0 - tensor_image# 3. 标准化:应用训练时使用的均值和标准差# MNIST 均值和标准差mean = 0.1307std = 0.3081tensor_image = (tensor_image - mean) / std# 添加批次维度和通道维度,使形状变为 (1, 1, 28, 28)tensor_image = tensor_image.unsqueeze(0).unsqueeze(0).cuda()# --- 可视化 PyTorch 张量 ---# 为了可视化,我们先将其恢复到 [0,1] 范围,否则标准化后的值可能很难看# 逆标准化 (用于可视化,不影响模型输入)# visual_tensor = tensor_image * std + mean# # 确保在 [0,1] 范围内# visual_tensor = torch.clamp(visual_tensor, 0.0, 1.0)# plt.figure(figsize=(2, 2))# plt.imshow(visual_tensor.cpu().squeeze().numpy(), cmap='gray')# plt.title("input")# plt.axis('off')# plt.show()return tensor_image

其中有几个注意点
1.
目前的画布是白色的,画笔是黑色,但是mnist数据集的底是黑色的,画笔是白色的,因此需要使用

tensor_image = 1.0 - tensor_image

来将颜色取反,不然跟训练数据不一样模型无法良好运行。
2.
QT中的image是Qimage,转换成numpy代码有点麻烦,我这里图省事直接用了qimage2ndarray库,因此只需一行代码

arr_3d = qimage2ndarray.byte_view(grayscale_image)

就完成了这个操作。
3.
在输入到模型之前,要进行数据预处理,如上面的代码中

        # 3. 标准化:应用训练时使用的均值和标准差# MNIST 均值和标准差mean = 0.1307std = 0.3081tensor_image = (tensor_image - mean) / std

来优化模型效果。

3. 加载模型

这里的预训练权重就直接用了上一篇文章中训练出来的权重,还给她放到cuda上了,不过这么小的模型其实放不放其实都无所谓,没有太大的影响。

    def get_net(self):"""获取数字预测模型。返回一个 DigitCNN 模型实例。"""# 创建并返回一个 DigitCNN 模型实例net = DigitCNN()net.eval()net.cuda()net.load_state_dict(torch.load('./digit_CNN.pth'))return net

4. 预测

这里就没什么好说的了,就是简单地预测然后将结果同步到gui上了。

    def predict(self):"""预测当前画布上绘制的数字。这里可以调用模型进行预测,并更新预测结果标签。"""input = self.get_image()  # 获取当前画布上的图像数据# 使用模型进行预测with torch.no_grad():output = self.net(input)# 获取预测结果self.update_predict_result(output)def update_predict_result(self, output):_, predict = output.max(1)  # 获取预测的数字类别predict = predict.cpu().numpy()[0]# 更新预测结果标签self.predict_label.setText(f"预测结果: {predict}")# 更新每个数字的预测概率probabilities = torch.softmax(output, dim=1).cpu().numpy()[0]for i, label in enumerate(self.predict_digit_labels):label.setText(f"数字 {i}: {probabilities[i] * 100:.2f}%")

5.结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

6.一点小问题

现在模型是可以用了,但是因为Mnist数据集本身的局限性,已经网络也比较小,泛化性能比较差(但是没差到不能用的地步),所以预测结果又是后会比较奇怪,例如:

.在这里插入图片描述
这是mnist数据集中的数据,可以看出这里的0大部分都是上面闭合,导致模型预测奇怪位置的闭合的0会失准。

还有其中的4大部分都是开口的,并没有闭合4上面的开口,导致写一个很标准的4反倒有时候会预测出错,还有其他的一些问题我就不赘述了。

总之如果想要模型想要获得更好的表现,一是可以增强一下模型的能力,第二个我觉得更重的是把数据好好清洗一下,有些数据真的太差了

相关文章:

  • 声纹技术体系:从理论基础到工程实践的完整技术架构
  • VAE在扩散模型中的技术实现与应用
  • 算法训练第三天
  • 跑步前热身动作
  • Python应用for循环遍历寻b
  • RAGFlow从理论到实战的检索增强生成指南
  • 在win10/11下Node.js安装配置教程
  • Java 认识异常
  • 桥 接 模 式
  • 介绍一种LDPC码译码器
  • uv:现代化的 Python 包和项目管理工具
  • 解常微分方程组
  • GoogLeNet网络模型
  • 西瓜书第五章——感知机
  • 《江西棒球资讯》棒球运动发展·棒球1号位
  • 信息安全之为什么引入公钥密码
  • 5.31 专业课复习笔记 12
  • day42 简单CNN
  • 计算机组织原理第三章
  • C 语言栈实现详解:从原理到动态扩容与工程化应用(含顺序/链式对比、函数调用栈、表达式求值等)
  • wordpress网站空间/合肥关键词排名
  • 用ae做模板下载网站/灰色词排名推广
  • 网站前置审批怎么做/站长工具seo综合查询下载
  • 公众号视频网站怎么做/站长工具的使用seo综合查询排名
  • 华北建设招标网官方网站/百度百家号怎么赚钱
  • 设计类的网站/友情链接又称