当前位置: 首页 > news >正文

tensorflow图像分类预测

tensorflow图像分类预测

CPU版本和GPU版本二选一

CPU版本

pip -m install --upgrade pip
pip install matplotlib pillow scikit-learn
pip install tensorflow-intel==2.18.0

GPU版本

工具 miniconda

  1. 升级依赖库

    conda update --all
    
  2. 创建目录

    mkdir gpu-tf
    
  3. 进入目录

    cd gpu-tf
    
  4. 创建虚拟环境

    conda create -p tf210-310 python==3.10.16
    
  5. 激活虚拟环境

    conda activate D:\gpu-tf\tf210-310
    
  6. 重新安装pip

    python -m pip uninstall pip
    python -m ensurepip --upgrade
    
  7. 升级 setuptools wheel

    python -m pip install --upgrade pip setuptools wheel
    

安装cudacudnn

conda install cudatoolkit==11.3.1 cudnn==8.2.1

安装 numpy

解决版本兼容

pip install numpy==1.26.4

安装 tensorflow-gpu

pip install tensorflow==2.10.1

只安装GPU版本执行以下命令:

pip install tensorflow-gpu==2.10.1

安装依赖

pip install matplotlib pillow scikit-learn

代码

import tensorflow as tf
from tensorflow.keras import layers, models
import os
import numpy as np
from sklearn.model_selection import train_test_split
from PIL import Image
import matplotlib.pyplot as pltplt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False# 步骤 1: 加载和预处理数据
def load_images_and_labels(data_folder, label_mapping):images = []labels = []invalid_images = []  # 用于存储无效图像的路径# 遍历数据文件夹中的每个子文件夹for root, dirs, files in os.walk(data_folder):for file in files:if file.endswith('.jpg') or file.endswith('.png'):file_path = os.path.join(root, file)try:# 打开图像文件image = Image.open(file_path)# 调整图像大小为统一尺寸image = image.resize((224, 224))# 确保所有图像都是RGB模式if image.mode != 'RGB':image = image.convert('RGB')# 将图像转换为 numpy 数组image_array = np.array(image)# 检查图像数组的形状是否符合预期if image_array.shape != (224, 224, 3):print(f"警告: {file_path} 的形状不符合预期: {image_array.shape}")continue# 归一化图像数据image_array = image_array / 255.0images.append(image_array)# 获取标签名称label_name = os.path.basename(root)# 根据标签映射获取标签索引label = label_mapping[label_name]labels.append(label)except Exception as e:print(f"无法加载图像 {file_path}: {str(e)}")invalid_images.append(file_path)continueprint(f"成功加载 {len(images)} 张图像")print(f"无效图像数量: {len(invalid_images)}")return np.array(images), np.array(labels)# 定义标签映射
label_mapping = {"bus": 0,"taxi": 1,"truck": 2,"family sedan": 3,"minibus": 4,"jeep": 5,"SUV": 6,"heavy truck": 7,"racing car": 8,"fire engine": 9
}# 反向映射,用于将预测结果的索引转换为类别名称
reverse_label_mapping = {v: k for k, v in label_mapping.items()}# 数据文件夹路径
data_folder = 'data/car/train'# 检查是否需要训练模型
train_model = True  # 设置为 False 表示使用已保存的模型,设置为 True 表示重新训练模型if train_model:# 加载图像和标签print("开始加载图像数据...")images, labels = load_images_and_labels(data_folder, label_mapping)if len(images) == 0:print("没有找到有效的图像数据,请检查数据路径和格式!")else:# 划分训练集和测试集X_train, X_test, y_train, y_test = train_test_split(images, labels, test_size=0.2, random_state=42)print(f"训练集大小: {len(X_train)}, 测试集大小: {len(X_test)}")print(f"图像形状: {X_train[0].shape}")# 步骤 2: 构建 CNN 模型model = models.Sequential()# 第一个卷积层,32 个滤波器,卷积核大小为 3x3,激活函数为 ReLUmodel.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)))# 最大池化层,池化窗口大小为 2x2model.add(layers.MaxPooling2D((2, 2)))# 第二个卷积层,64 个滤波器,卷积核大小为 3x3,激活函数为 ReLUmodel.add(layers.Conv2D(64, (3, 3), activation='relu'))# 最大池化层,池化窗口大小为 2x2model.add(layers.MaxPooling2D((2, 2)))# 第三个卷积层,64 个滤波器,卷积核大小为 3x3,激活函数为 ReLUmodel.add(layers.Conv2D(64, (3, 3), activation='relu'))# 将多维数据展平为一维向量model.add(layers.Flatten())# 全连接层,64 个神经元,激活函数为 ReLUmodel.add(layers.Dense(64, activation='relu'))# 输出层,神经元数量等于类别数,激活函数为 Softmaxmodel.add(layers.Dense(len(label_mapping), activation='softmax'))# 显示模型结构model.summary()# 步骤 3: 编译模型model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 步骤 4: 训练模型print("开始训练模型...")history = model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test))# 步骤 5: 评估模型test_loss, test_acc = model.evaluate(X_test, y_test)print(f"Test accuracy: {test_acc}")# 步骤 6: 保存模型model.save('model/vehicle_classification_model.h5')print("模型已保存为 vehicle_classification_model.h5")# 绘制训练历史plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(history.history['accuracy'])plt.plot(history.history['val_accuracy'])plt.title('Model accuracy')plt.ylabel('Accuracy')plt.xlabel('Epoch')plt.legend(['Train', 'Test'], loc='upper left')plt.subplot(1, 2, 2)plt.plot(history.history['loss'])plt.plot(history.history['val_loss'])plt.title('Model loss')plt.ylabel('Loss')plt.xlabel('Epoch')plt.legend(['Train', 'Test'], loc='upper left')plt.tight_layout()plt.show()
else:# 步骤 7: 加载已保存的模型try:model = tf.keras.models.load_model('model/vehicle_classification_model.h5')print("已加载保存的模型")except Exception as e:print(f"加载模型失败: {str(e)}")print("请确保模型文件存在并且格式正确,或者设置 train_model=True 重新训练模型")# 步骤 8: 预测新图片分类
def predict_image(model, image_path, reverse_label_mapping):try:# 打开图像文件image = Image.open(image_path)# 调整图像大小为模型输入所需的尺寸image = image.resize((224, 224))# 确保图像是RGB模式if image.mode != 'RGB':image = image.convert('RGB')# 显示图像plt.imshow(image)plt.axis('off')plt.show()# 将图像转换为 numpy 数组image = np.array(image)# 归一化图像数据image = image / 255.0# 增加一个维度,因为模型期望的输入形状是 (batch_size, height, width, channels)image = np.expand_dims(image, axis=0)# 使用模型进行预测predictions = model.predict(image)# 获取预测结果的索引(即预测的类别)predicted_class_index = np.argmax(predictions[0])# 获取预测的类别名称predicted_class_name = reverse_label_mapping[predicted_class_index]# 获取预测的置信度confidence = np.max(predictions[0]) * 100print(f"预测结果: {predicted_class_name},置信度: {confidence:.2f}%")# 显示预测概率分布plt.figure(figsize=(10, 4))plt.bar(reverse_label_mapping.values(), predictions[0])plt.xticks(rotation=45, ha='right')plt.title('预测概率分布')plt.ylabel('概率')plt.tight_layout()plt.show()return predicted_class_name, confidenceexcept Exception as e:print(f"预测失败: {str(e)}")return None, None# 步骤 9: 使用示例
# 替换为你自己的测试图片路径
test_image_path = 'data/car/val/SUV/a3c4f639c87e59383cfec1062b0ebd1b.jpg'
if 'model' in locals() and os.path.exists(test_image_path):predict_image(model, test_image_path, reverse_label_mapping)print('success')
else:print("请提供有效的测试图片路径并确保模型已成功加载")

运行结果

tensorflow
tensorflow
tensorflow

相关文章:

  • IDEA - Windows IDEA 代码块展开与折叠(基础折叠操作、高级折叠操作)
  • 渗透测试流程-中篇
  • 5、事务和limit补充
  • Linux的内存泄漏问题及排查方法
  • 【通用智能体】Playwright:跨浏览器自动化工具
  • C++学习:六个月从基础到就业——C++20:协程(Coroutines)
  • 【Linux】ELF与动静态库的“暗黑兵法”:程序是如何跑起来的?
  • IDE/IoT/搭建物联网(LiteOS)集成开发环境,基于 LiteOS Studio + GCC + JLink
  • Ansible模块——文件内容修改
  • 【Linux】简易版Shell实现(附源码)
  • Day29 类的装饰器
  • PopSQL:一个支持团队协作的SQL开发工具
  • 机器学习(12)——LGBM(1)
  • 软件架构之--论微服务的开发方法1
  • 一种开源的高斯泼溅实现库——gsplat: An Open-Source Library for Gaussian Splatting
  • Leetcode 3553. Minimum Weighted Subgraph With the Required Paths II
  • EMQX开源版安装指南:Linux/Windows全攻略
  • 初学c语言15(字符和字符串函数)
  • 【图像生成大模型】Wan2.1:下一代开源大规模视频生成模型
  • windows笔记本连接RKNN3588网络配置解析
  • 原核试验基地司令员范如玉逝世,从事核试验研究超40年
  • “80后”萍乡市安源区区长邱伟,拟任县(区)委书记
  • 英德宣布开发射程超2000公里导弹,以防务合作加强安全、促进经济
  • 体坛联播|巴萨提前2轮西甲夺冠,郑钦文不敌高芙止步4强
  • 上海黄浦江挡潮闸工程建设指挥部成立,组成人员名单公布
  • 奥古斯都时代的历史学家李维