当前位置: 首页 > wzjs >正文

便宜手机网站建设国内做会展比较好的公司

便宜手机网站建设,国内做会展比较好的公司,uc极速版福利一天能赚多少钱,电脑如何做穿透外网网站我这里准备了imgs文件夹,里面有0名字为0-9的9个目录,每个目录内的所有图片的数字和目录名相同, 比方说5目录中图片内容都是5. mod.py是模型内容 xl.py调用模型进行训练 a.py调用模型进行测试 mod.py import torch import torch.nn…

我这里准备了imgs文件夹,里面有0名字为0-9的9个目录,每个目录内的所有图片的数字和目录名相同,  比方说5目录中图片内容都是5.

mod.py是模型内容      xl.py调用模型进行训练    a.py调用模型进行测试

mod.py

import torch
import torch.nn as nn
import torch.optim as optim
import os
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np# 定义卷积神经网络模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)  # 10类数字def forward(self, x):x = torch.relu(self.conv1(x))x = torch.max_pool2d(x, 2)x = torch.relu(self.conv2(x))x = torch.max_pool2d(x, 2)x = x.view(-1, 64 * 7 * 7)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 保存模型
def save_model(model, filepath):torch.save(model.state_dict(), filepath)print(f"Model saved to {filepath}")# 加载模型  这里是从选定文件夹加载  图片路径和对应的 内容数字(label)
#这里目录  是有0-9  9个目录 比如说2文件夹里的图片内容都是2
def load_model_from_file(filepath):model = SimpleCNN()if os.path.exists(filepath):model.load_state_dict(torch.load(filepath))model.eval()print(f"Model loaded from {filepath}")return modelelse:print(f"Model file {filepath} does not exist!")return None# 定义数据集类
class DigitDataset(Dataset):def __init__(self, img_folder, transform=None):self.img_folder = img_folderself.transform = transformself.img_paths = []self.labels = []for d in os.listdir(img_folder):label = int(d)L=os.listdir(f'./{img_folder}/{d}')for i in L:p=f'./{img_folder}/{d}/{i}'self.img_paths.append(p)self.labels.append(label)def __len__(self):return len(self.img_paths)def __getitem__(self, idx):img_path = self.img_paths[idx]label = self.labels[idx]img = Image.open(img_path).convert('L')if self.transform:img = self.transform(img)return img, label# 数据加载的转换
def get_transform():return transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),  # 灰度图归一化])

xl.py

import os
import torch
from torch.utils.data import DataLoader
from mod import SimpleCNN, save_model, DigitDataset, get_transform, load_model_from_file
from torch.optim import Adam
from torch.nn import CrossEntropyLossdef train_model(img_folder, model_path, epochs=5, batch_size=64):# 准备数据transform = get_transform()dataset = DigitDataset(img_folder, transform)dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 创建模型model = SimpleCNN()# 如果提供了模型路径,并且该文件存在,则加载模型if model_path and os.path.exists(model_path):model = load_model_from_file(model_path)  # 加载已有的模型else:print("No existing model found, training from scratch.")optimizer = Adam(model.parameters(), lr=0.001)criterion = CrossEntropyLoss()# 开始训练for epoch in range(epochs):model.train()running_loss = 0.0for inputs, labels in dataloader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(dataloader)}")# 最终保存模型到一个文件save_model(model, model_path)if __name__ == "__main__":img_folder = 'imgs'  # 图像文件夹路径#save_dir = 'saved_models'  # 保存模型的文件夹model_path = 'res'  # 预先保存的模型路径# 如果存在模型路径,传入模型路径进行加载train_model(img_folder, model_path)

a.py

import torch
from mod import load_model_from_file
from torchvision import transforms
from PIL import Image,ImageOpsdef predict_digit(model, img_path):"""加载并预测单个数字图像"""transform = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),])img = Image.open(img_path).convert('L')#黑白反转  因为我训练的图都是黑底白字    测试时用的相反img=ImageOps.invert(img)img = transform(img).unsqueeze(0)  # 增加批量维度with torch.no_grad():outputs = model(img)_, predicted = torch.max(outputs, 1)digit = predicted.item()return digitdef test_model(model_path, img_path):"""加载模型并对图像进行预测"""model = load_model_from_file(model_path)if model is not None:digit = predict_digit(model, img_path)print(f"Predicted digit: {digit}")else:print("Failed to load model!")if __name__ == "__main__":model_path = 'res'  # 最终保存的模型路径img_path = '7.bmp'  # 测试图片路径test_model(model_path, img_path)

http://www.dtcms.com/wzjs/536379.html

相关文章:

  • 资金盘网站开发费用京东官网
  • 建设网站你认为需要注意有什么做家纺的网站
  • 廊坊建设质量监督局网站深圳设计网站培训班
  • jsp租房网站开发郑州营销策划公司排行榜
  • 你建立的网站使用了那些营销方法英文网站建设服务合同模板下载
  • 怎么搜索整个网站内容营销网站定制的优势
  • 网站搜索功能怎么做如何做电影网站才不侵权
  • 网站用户维度鄂州网络推广
  • 建设工程安全管理网站做课内教学网站
  • 自定义表单网站在网站上签失业保险怎样做
  • 淮安网站建设推广网站后台发布新闻
  • 网站备案多少岁可以做网站的查询系统怎么做
  • 自己怎么做网站游戏宜兴建设局的网站
  • 文化企业网站模板网站开发图片压缩
  • 什么是网站栏目标题腾讯官方网站
  • 娄底网站建设网站菠菜网站模板
  • 龙岩网站设计大概价格代理财务记账公司
  • 大学高校网站建设栏目织梦cms网站模板修改
  • 青海西宁做网站多少钱工行gcms系统
  • 余姚市城乡建设局网站做网站必须购买空间吗
  • 网站引导动画怎么做贵金属交易app下载
  • 网站内页检测装修公司联系方式汇总
  • asp网站 证书哪个网站可以查公司注册信息
  • 仿唧唧帝笑话门户网站源码带多条采集规则 织梦搞笑图片视频模板磁力吧
  • 湖北长欣建设有限公司网站怎么知道网站是哪个公司做的
  • 怎么优化网站关键词排名范县网站建设
  • 阿里云网站搭建服务器如何配置php网站
  • 模板网站建设清单网站应该怎么建设
  • 泰安网站建设dxkjw网络搭建安全分析
  • 网站的死链微信小程序商城怎样做