当前位置: 首页 > news >正文

AutoKeras 处理图像回归预测

AutoKeras 是一个自动机器学习库,在处理图像回归预测问题时,它可以自动选择最佳的模型架构和超参数,从而简化深度学习模型的构建过程。 AutoKeras 主要用于分类和回归任务,它同样可以进行图像数据的回归预测。

步骤 1: 安装 AutoKeras

首先,确保你已经安装了 autokeras。如果没有安装,可以通过 pip 安装:

pip install autokeras

步骤 2: 准备数据

图像数据通常需要被预处理为适合模型输入的格式。这包括将图像调整为统一的大小、归一化等。

import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
 
# 假设你有图像路径列表和相应的标签(回归值)
image_paths = ['path/to/image1.jpg', 'path/to/image2.jpg', ...]
labels = np.array([label1, label2, ...])  # 回归标签
 
# 加载和预处理图像
def load_and_preprocess_image(path, target_size=(128, 128)):
    from tensorflow.keras.preprocessing.image import load_img, img_to_array
    img = load_img(path, target_size=target_size)
    img = img_to_array(img) / 255.0  # 归一化到[0, 1]
    return img
 
images = np.array([load_and_preprocess_image(path) for path in image_paths])

步骤 3: 使用 AutoKeras 进行模型搜索

接下来,使用 AutoKeras 的 ImageRegressor 来自动寻找最佳的模型架构。

import autokeras as ak
 
# 初始化图像回归模型搜索器
regressor = ak.ImageRegressor(overwrite=True, max_trials=10)  # max_trials 控制搜索的模型数量
 
# 训练模型
regressor.fit(images, labels, epochs=10)  # 这里epochs的数量可以根据需要调整

步骤 4: 评估和预测

训练完成后,你可以评估模型的性能,并进行预测。

# 评估模型
loss, accuracy = regressor.evaluate(images, labels)
print(f"Loss: {loss}, Accuracy: {accuracy}")  # 注意:回归通常不使用准确率,此处仅为示例,实际应为回归指标如MSE或MAE
 
# 进行预测
predictions = regressor.predict(images)
print(predictions)

注意:

  • 评价指标:对于回归任务,常用的评价指标是均方误差(MSE)或平均绝对误差(MAE)。在 AutoKeras 中,你可以通过自定义回调函数来指定这些指标。例如:

  • 性能调优AutoKeras 的 max_trials 参数控制了搜索空间的大小。增加这个值可以提供更好的模型性能,但同时也会增加训练时间。根据具体问题的复杂性和计算资源,适当调整这个值。

  • 数据划分:在实际应用中,通常会将数据划分为训练集、验证集和测试集,以确保模型的泛化能力。你可以使用 sklearn.model_selection.train_test_split 来划分数据。

通过上述步骤,你可以使用 AutoKeras 来处理图像回归预测问题。

--------------------------------------------------------------------------------------------------------------------------------

完整示例代码

注:以下代码通过测试运行,运行环境Python 3.8、AutoKeras 1.1.0、TensorFlow  2.13.1

import os
import autokeras as ak
import numpy as np
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from sklearn.model_selection import train_test_split

# 压缩图像数据和归一化
# 注意:load_img以来PIL库,通过pip install pillow安装
def load_and_preprocess_img(img_path, target_size=(200, 200)):
    img = load_img(img_path, target_size=target_size)
    img = img_to_array(img) / 255.0
    return img


def make_input_data():
    images_path = '/root/images'
    files = os.listdir(images_path)

    png_data_list = []
    sate_count_list = []
    total_count = len(files)
    index = 0
    for file in files:
        index += 1
        if not file.endswith('.png'):
            continue

        png_path = os.path.join(images_path, file)
        index_s = file.find('S')
        index_dot = file.find('.')
        sate_count = float(file[index_s + 1:index_dot])

        img_array = load_and_preprocess_img(png_path)
        png_data_list.append(img_array)

        sate_count_list.append(sate_count)
        print('Progress: {}/{} {} {}'.format(index, total_count, file, sate_count))

    return np.array(png_data_list), np.array(sate_count_list)



def do_work():
    X, Y = make_input_data()
    x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2)
    x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2)
	del X,Y

    # 定义AutoKeras图像回归模型
    regressor = ak.ImageRegressor(
        overwrite=True,  # 覆盖之前的训练结果
        max_trials=3,  # 最大超参数优化次数
        metrics=["mae"]  # 回归任务常用指标:平均绝对误差、均方误差
    )
    regressor.fit(x_train, y_train, epochs=5, validation_data=(x_val, y_val))
    print('Finish fiting')

    y_pred = regressor.predict(x_test)
    print('Finish predicting')

    MAE = mean_absolute_error(y_test, y_pred)
    print(f"测试集MAE: {MAE:.4f}")

if __name__ == '__main__':
    print('OK')
    do_work()
    print('END')

http://www.dtcms.com/a/123345.html

相关文章:

  • spark-core学习内容总结
  • 【完美解决】VSCode连接HPC节点,已配置密钥却还是提示需要输入密码
  • 京华幻梦:科技自然共生诗篇
  • 【蓝桥杯】二分查找
  • springcloud进阶
  • SkyWalking + ELK 全链路监控系统整合指南
  • FPGA_DDR(二)
  • Go语言编写一个进销存Web软件的demo
  • python基础语法1:输入输出
  • Java 基础 - 反射(1)
  • Java学习——day26(线程同步与共享资源保护)
  • FastAPI用户认证系统开发指南:从零构建安全API
  • Cloudflare 缓存工作原理
  • ComfyUI_Echomimic部署问题集合
  • 企业信息化-系统架构师(九十八)
  • 玩转Docker | 使用Docker搭建pinry图片展示系统
  • swagger + Document
  • 修改 docker 工作目录
  • MySQL的索引下推是什么
  • opengrok使用指南
  • 了解 DeFi:去中心化金融的入门指南与未来展望
  • JS—防抖和节流:1分钟掌握防抖和节流
  • 【ctfplus】python靶场记录-任意文件读取+tornado模板注入+yaml反序列化(新手向)
  • 良渚实验室郭国骥/夏宏光团队合作开发单细胞水平筛选抗肿瘤药物的深度学习框架——“神农”
  • 蓝桥杯C++组算法知识点整理 · 考前突击(上)【小白适用】
  • Java 面试总结
  • 数据结构 | 证明链表环结构是否存在
  • ubuntu设备磁盘空间不足 处理办法
  • WinForm真入门(12)——RadioButton控件详解
  • C++中static与private继承关系解析