tensorflow image_dataset_from_directory 训练数据集构建
以数据集 https://www.kaggle.com/datasets/vipoooool/new-plant-diseases-dataset 为例
目录结构
训练图像数据集要求:
- 主目录下包含多个子目录,每个子目录代表一个类别。
- 每个子目录中存储属于该类别的图像文件。
例如
main_directory/
...cat/
......cat_image_1.jpg
......cat_image_2.jpg
...dog/
......dog_image_1.jpg
......dog_image_2.jpg
main_directory
是主目录。cat
和dog
是两个类别对应的子目录。- 子目录中的文件是属于该类别的图像文件。
在 TensorFlow 和 Keras 中,image_dataset_from_directory
是一个用于从文件系统中加载图像数据的便捷函数。它可以从目录结构中自动推断标签,并生成一个 tf.data.Dataset
对象,便于模型训练和评估。
download_path = kagglehub.dataset_download("vipoooool/new-plant-diseases-dataset")
print("Path to dataset files:", download_path)# 定义数据集路径
dataset_path = f"{download_path}/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)"
# 定义训练集目录
trainDir = os.path.join(dataset_path, "train")# 开启 TensorFlow 设备放置日志,方便调试时查看运算在哪个设备上执行
# tf.debugging.set_log_device_placement(True)print("trainDir:", trainDir)
# 从训练集目录加载图像数据集
training_set = keras_utils.image_dataset_from_directory(trainDir,labels="inferred", # 从目录结构推断图像标签label_mode="categorical", # 使用独热编码的标签class_names=None, # 自动推断类别名称color_mode="rgb", # 处理 RGB 图像batch_size=32, # 每个批次包含 32 张图像image_size=(128, 128), # 将图像大小调整为 128x128shuffle=True, # 打乱数据集seed=None, # 不设置随机种子validation_split=None, # 不进行数据集划分subset=None, # 不指定子集interpolation="bilinear", # 使用双线性插值调整图像大小follow_links=False, # 不跟随符号链接crop_to_aspect_ratio=False, # 不按纵横比裁剪图像
)