当前位置: 首页 > news >正文

pytourch训练识别单个数字的图片

我这里准备了imgs文件夹,里面有0名字为0-9的9个目录,每个目录内的所有图片的数字和目录名相同,  比方说5目录中图片内容都是5.

mod.py是模型内容      xl.py调用模型进行训练    a.py调用模型进行测试

mod.py

import torch
import torch.nn as nn
import torch.optim as optim
import os
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np

# 定义卷积神经网络模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)  # 10类数字

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = x.view(-1, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 保存模型
def save_model(model, filepath):
    torch.save(model.state_dict(), filepath)
    print(f"Model saved to {filepath}")

# 加载模型  这里是从选定文件夹加载  图片路径和对应的 内容数字(label)
#这里目录  是有0-9  9个目录 比如说2文件夹里的图片内容都是2
def load_model_from_file(filepath):
    model = SimpleCNN()
    if os.path.exists(filepath):
        model.load_state_dict(torch.load(filepath))
        model.eval()
        print(f"Model loaded from {filepath}")
        return model
    else:
        print(f"Model file {filepath} does not exist!")
        return None

# 定义数据集类
class DigitDataset(Dataset):
    def __init__(self, img_folder, transform=None):
        self.img_folder = img_folder
        self.transform = transform
        self.img_paths = []
        self.labels = []
        for d in os.listdir(img_folder):
            label = int(d)
            L=os.listdir(f'./{img_folder}/{d}')
            for i in L:
                p=f'./{img_folder}/{d}/{i}'
                self.img_paths.append(p)
                self.labels.append(label)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        label = self.labels[idx]
        img = Image.open(img_path).convert('L')
        if self.transform:
            img = self.transform(img)
        return img, label

# 数据加载的转换
def get_transform():
    return transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # 灰度图归一化
    ])

xl.py

import os
import torch
from torch.utils.data import DataLoader
from mod import SimpleCNN, save_model, DigitDataset, get_transform, load_model_from_file
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

def train_model(img_folder, model_path, epochs=5, batch_size=64):
    # 准备数据
    transform = get_transform()
    dataset = DigitDataset(img_folder, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # 创建模型
    model = SimpleCNN()
    
    # 如果提供了模型路径,并且该文件存在,则加载模型
    if model_path and os.path.exists(model_path):
        model = load_model_from_file(model_path)  # 加载已有的模型
    else:
        print("No existing model found, training from scratch.")

    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = CrossEntropyLoss()



    # 开始训练
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(dataloader)}")

    # 最终保存模型到一个文件
    save_model(model, model_path)

if __name__ == "__main__":
    img_folder = 'imgs'  # 图像文件夹路径
    #save_dir = 'saved_models'  # 保存模型的文件夹
    model_path = 'res'  # 预先保存的模型路径
    
    # 如果存在模型路径,传入模型路径进行加载
    train_model(img_folder, model_path)

a.py

import torch
from mod import load_model_from_file
from torchvision import transforms
from PIL import Image,ImageOps

def predict_digit(model, img_path):
    """加载并预测单个数字图像"""
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    
    img = Image.open(img_path).convert('L')
    #黑白反转  因为我训练的图都是黑底白字    测试时用的相反
    img=ImageOps.invert(img)
    img = transform(img).unsqueeze(0)  # 增加批量维度
    with torch.no_grad():
        outputs = model(img)
        _, predicted = torch.max(outputs, 1)
        digit = predicted.item()
    
    return digit

def test_model(model_path, img_path):
    """加载模型并对图像进行预测"""
    model = load_model_from_file(model_path)
    if model is not None:
        digit = predict_digit(model, img_path)
        print(f"Predicted digit: {digit}")
    else:
        print("Failed to load model!")

if __name__ == "__main__":
    model_path = 'res'  # 最终保存的模型路径
    img_path = '7.bmp'  # 测试图片路径
    
    test_model(model_path, img_path)

相关文章:

  • 【STM32】DRV8833驱动电机
  • APlayer - APlayer 初识(APlayer 初识案例、APlayer 常用事件)
  • C++中常用的十大排序方法之4——希尔排序
  • 代码随想录算法训练营第三十九天| 动态规划03
  • 19.Python实战:实现对博客文章的点赞系统
  • 微信小程序中缓存数据全方位解惑
  • Unity 编辑器热更C# FastScriptReload
  • 安卓基础(Adapter)
  • JVM 底层探秘:对象创建的详细流程、内存分配机制解析以及线程安全保障策略
  • React生产环境下使用mock.js
  • VueRouter 实例
  • 单、双 链 表
  • MIMO信号检测ZF算法和MMSE算法
  • 深度求索—DeepSeek API的简单调用(Java)
  • 简单的异步图片上传
  • 游戏引擎学习第104天
  • ABB能源自动化选用宏集Cogent DataHub避免DCOM问题,实现高效、安全的数据传输
  • cuML机器学习GPU库
  • vue3的响应式的理解,与普通对象的区别
  • ROS基本功能
  • 华泰柏瑞基金总经理韩勇因工作调整卸任,董事长贾波代为履职
  • 2025年度上海市住房城乡建设管理委工程系列中级职称评审工作启动
  • 青年与人工智能共未来,上海创新创业青年50人论坛徐汇分论坛举办
  • 眉山“笑气”迷局:草莓熊瓶背后的隐秘与危机
  • 洛杉矶奥组委确认2028年奥运会和残奥会开闭幕式场地
  • 见微知沪|优化营商环境,上海为何要当“细节控”自我加压?