当前位置: 首页 > wzjs >正文

网站建设 服饰鞋帽惠州seo计费

网站建设 服饰鞋帽,惠州seo计费,微商城网站建设怎么样,免费商城版网站我这里准备了imgs文件夹,里面有0名字为0-9的9个目录,每个目录内的所有图片的数字和目录名相同, 比方说5目录中图片内容都是5. mod.py是模型内容 xl.py调用模型进行训练 a.py调用模型进行测试 mod.py import torch import torch.nn…

我这里准备了imgs文件夹,里面有0名字为0-9的9个目录,每个目录内的所有图片的数字和目录名相同,  比方说5目录中图片内容都是5.

mod.py是模型内容      xl.py调用模型进行训练    a.py调用模型进行测试

mod.py

import torch
import torch.nn as nn
import torch.optim as optim
import os
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np# 定义卷积神经网络模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)  # 10类数字def forward(self, x):x = torch.relu(self.conv1(x))x = torch.max_pool2d(x, 2)x = torch.relu(self.conv2(x))x = torch.max_pool2d(x, 2)x = x.view(-1, 64 * 7 * 7)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 保存模型
def save_model(model, filepath):torch.save(model.state_dict(), filepath)print(f"Model saved to {filepath}")# 加载模型  这里是从选定文件夹加载  图片路径和对应的 内容数字(label)
#这里目录  是有0-9  9个目录 比如说2文件夹里的图片内容都是2
def load_model_from_file(filepath):model = SimpleCNN()if os.path.exists(filepath):model.load_state_dict(torch.load(filepath))model.eval()print(f"Model loaded from {filepath}")return modelelse:print(f"Model file {filepath} does not exist!")return None# 定义数据集类
class DigitDataset(Dataset):def __init__(self, img_folder, transform=None):self.img_folder = img_folderself.transform = transformself.img_paths = []self.labels = []for d in os.listdir(img_folder):label = int(d)L=os.listdir(f'./{img_folder}/{d}')for i in L:p=f'./{img_folder}/{d}/{i}'self.img_paths.append(p)self.labels.append(label)def __len__(self):return len(self.img_paths)def __getitem__(self, idx):img_path = self.img_paths[idx]label = self.labels[idx]img = Image.open(img_path).convert('L')if self.transform:img = self.transform(img)return img, label# 数据加载的转换
def get_transform():return transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),  # 灰度图归一化])

xl.py

import os
import torch
from torch.utils.data import DataLoader
from mod import SimpleCNN, save_model, DigitDataset, get_transform, load_model_from_file
from torch.optim import Adam
from torch.nn import CrossEntropyLossdef train_model(img_folder, model_path, epochs=5, batch_size=64):# 准备数据transform = get_transform()dataset = DigitDataset(img_folder, transform)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 创建模型model = SimpleCNN()# 如果提供了模型路径,并且该文件存在,则加载模型if model_path and os.path.exists(model_path):model = load_model_from_file(model_path)  # 加载已有的模型else:print("No existing model found, training from scratch.")optimizer = Adam(model.parameters(), lr=0.001)criterion = CrossEntropyLoss()# 开始训练for epoch in range(epochs):model.train()running_loss = 0.0for inputs, labels in dataloader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(dataloader)}")# 最终保存模型到一个文件save_model(model, model_path)if __name__ == "__main__":img_folder = 'imgs'  # 图像文件夹路径#save_dir = 'saved_models'  # 保存模型的文件夹model_path = 'res'  # 预先保存的模型路径# 如果存在模型路径,传入模型路径进行加载train_model(img_folder, model_path)

a.py

import torch
from mod import load_model_from_file
from torchvision import transforms
from PIL import Image,ImageOpsdef predict_digit(model, img_path):"""加载并预测单个数字图像"""transform = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),])img = Image.open(img_path).convert('L')#黑白反转  因为我训练的图都是黑底白字    测试时用的相反img=ImageOps.invert(img)img = transform(img).unsqueeze(0)  # 增加批量维度with torch.no_grad():outputs = model(img)_, predicted = torch.max(outputs, 1)digit = predicted.item()return digitdef test_model(model_path, img_path):"""加载模型并对图像进行预测"""model = load_model_from_file(model_path)if model is not None:digit = predict_digit(model, img_path)print(f"Predicted digit: {digit}")else:print("Failed to load model!")if __name__ == "__main__":model_path = 'res'  # 最终保存的模型路径img_path = '7.bmp'  # 测试图片路径test_model(model_path, img_path)

http://www.dtcms.com/wzjs/67965.html

相关文章:

  • 怎样学好网站开发今日热点新闻事件及评论
  • wamp做网站seo视频教程百度云
  • 哪些网站开业做简单海报品牌全网推广
  • 一级a做爰片免费网站今日最新足球推荐
  • 网站策划报告书怎么做百度引流怎么推广
  • 在网上做网站百度客服在线咨询人工服务
  • wordpress调用图标搜索引擎seo如何赚钱
  • 惠州个人做网站联系人哈尔滨网站优化流程
  • 广州网站建设优化电商网站开发平台有哪些
  • 建设网站的目标客户群武汉新闻最新消息
  • 坊网站建设宁波seo免费优化软件
  • 嘉兴外贸网站建设环球贸易网
  • 做网站还 淘宝网络运营seo是什么
  • 设计网站behance市场调研方案范文
  • wordpress专业站内优化主要从哪些方面进行
  • 什么叫网站定位天津网站排名提升多少钱
  • 动态网站开发课程设计怎样在百度做广告宣传
  • b2c网站技术架构怎么开展网络营销推广
  • 兰州做网站公司有哪些网络营销比较成功的企业
  • 建设彩票网站需要多少投资营销策划方案公司
  • 网站的百度地图怎么做的百度手机助手app安卓版官方下载
  • wordpress删除重装谷歌网站优化推广
  • 做网站建设的网站郴州网站seo
  • 郑州哪有做网站的seo网页优化培训
  • 韩国购物网站有哪些百度推广合作
  • 免费ppt模板下载有哪些天津百度快速优化排名
  • 工控主机做网站服务器网络营销的基本流程
  • 电影网站建设java企业网络推广技巧
  • 张家口网站建设价格搜索排名影响因素
  • 中国高定十大品牌成都seo优化