pytourch训练识别单个数字的图片
我这里准备了imgs文件夹,里面有0名字为0-9的9个目录,每个目录内的所有图片的数字和目录名相同, 比方说5目录中图片内容都是5.
mod.py是模型内容 xl.py调用模型进行训练 a.py调用模型进行测试
mod.py
import torch
import torch.nn as nn
import torch.optim as optim
import os
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
# 定义卷积神经网络模型
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10) # 10类数字
def forward(self, x):
x = torch.relu(self.conv1(x))
x = torch.max_pool2d(x, 2)
x = torch.relu(self.conv2(x))
x = torch.max_pool2d(x, 2)
x = x.view(-1, 64 * 7 * 7)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 保存模型
def save_model(model, filepath):
torch.save(model.state_dict(), filepath)
print(f"Model saved to {filepath}")
# 加载模型 这里是从选定文件夹加载 图片路径和对应的 内容数字(label)
#这里目录 是有0-9 9个目录 比如说2文件夹里的图片内容都是2
def load_model_from_file(filepath):
model = SimpleCNN()
if os.path.exists(filepath):
model.load_state_dict(torch.load(filepath))
model.eval()
print(f"Model loaded from {filepath}")
return model
else:
print(f"Model file {filepath} does not exist!")
return None
# 定义数据集类
class DigitDataset(Dataset):
def __init__(self, img_folder, transform=None):
self.img_folder = img_folder
self.transform = transform
self.img_paths = []
self.labels = []
for d in os.listdir(img_folder):
label = int(d)
L=os.listdir(f'./{img_folder}/{d}')
for i in L:
p=f'./{img_folder}/{d}/{i}'
self.img_paths.append(p)
self.labels.append(label)
def __len__(self):
return len(self.img_paths)
def __getitem__(self, idx):
img_path = self.img_paths[idx]
label = self.labels[idx]
img = Image.open(img_path).convert('L')
if self.transform:
img = self.transform(img)
return img, label
# 数据加载的转换
def get_transform():
return transforms.Compose([
transforms.Resize((28, 28)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)), # 灰度图归一化
])
xl.py
import os
import torch
from torch.utils.data import DataLoader
from mod import SimpleCNN, save_model, DigitDataset, get_transform, load_model_from_file
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
def train_model(img_folder, model_path, epochs=5, batch_size=64):
# 准备数据
transform = get_transform()
dataset = DigitDataset(img_folder, transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 创建模型
model = SimpleCNN()
# 如果提供了模型路径,并且该文件存在,则加载模型
if model_path and os.path.exists(model_path):
model = load_model_from_file(model_path) # 加载已有的模型
else:
print("No existing model found, training from scratch.")
optimizer = Adam(model.parameters(), lr=0.001)
criterion = CrossEntropyLoss()
# 开始训练
for epoch in range(epochs):
model.train()
running_loss = 0.0
for inputs, labels in dataloader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(dataloader)}")
# 最终保存模型到一个文件
save_model(model, model_path)
if __name__ == "__main__":
img_folder = 'imgs' # 图像文件夹路径
#save_dir = 'saved_models' # 保存模型的文件夹
model_path = 'res' # 预先保存的模型路径
# 如果存在模型路径,传入模型路径进行加载
train_model(img_folder, model_path)
a.py
import torch
from mod import load_model_from_file
from torchvision import transforms
from PIL import Image,ImageOps
def predict_digit(model, img_path):
"""加载并预测单个数字图像"""
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
])
img = Image.open(img_path).convert('L')
#黑白反转 因为我训练的图都是黑底白字 测试时用的相反
img=ImageOps.invert(img)
img = transform(img).unsqueeze(0) # 增加批量维度
with torch.no_grad():
outputs = model(img)
_, predicted = torch.max(outputs, 1)
digit = predicted.item()
return digit
def test_model(model_path, img_path):
"""加载模型并对图像进行预测"""
model = load_model_from_file(model_path)
if model is not None:
digit = predict_digit(model, img_path)
print(f"Predicted digit: {digit}")
else:
print("Failed to load model!")
if __name__ == "__main__":
model_path = 'res' # 最终保存的模型路径
img_path = '7.bmp' # 测试图片路径
test_model(model_path, img_path)