数据集划分示例代码(图片、txt标注文档)
目录
- 数据集说明
- 代码示例
数据集说明
本次划分的数据集包含两个文件夹:一个存放图片(代码中默认为 images),另一个存放 txt 格式的标注文件(代码中默认为 annfiles)。
下面的代码适用于划分由图片和同名 txt 标注文件组成的数据集。
代码示例
import os
import random
import shutil
from pathlib import Path


def _collect_image_files(images_dir, extensions):
    """Map every sample stem in *images_dir* to its real image file name.

    Args:
        images_dir: Directory that holds the image files.
        extensions: Lower-case extensions (with leading dot) to accept.

    Returns:
        A dict ``{stem: file_name}``. The directory listing is sorted first,
        so the insertion order does not depend on the filesystem, and a stem
        that exists with several accepted extensions is kept only once.
    """
    stem_to_file = {}
    for file_name in sorted(os.listdir(images_dir)):
        stem, ext = os.path.splitext(file_name)
        if ext.lower() in extensions:
            # Keep the name exactly as it is on disk (e.g. "a.PNG"), so the
            # later copy also works on case-sensitive filesystems.
            stem_to_file.setdefault(stem, file_name)
    return stem_to_file


def _copy_split(stems, split_name, stem_to_file, images_dir, labels_dir,
                output_dir):
    """Copy the images and labels of one split into its output folders.

    Args:
        stems: Sample stems that belong to this split.
        split_name: Name of the split folder ('train', 'val' or 'test').
        stem_to_file: Mapping from stem to the real image file name.
        images_dir: Source image directory.
        labels_dir: Source label directory (``<stem>.txt`` files).
        output_dir: Root output directory containing the split folders.

    Returns:
        The number of samples for which both image and label were copied.
    """
    print(f"\n正在复制 {split_name} 集文件...")
    split_images_dir = os.path.join(output_dir, split_name, 'images')
    split_labels_dir = os.path.join(output_dir, split_name, 'labels')

    copied_count = 0
    missing_files = []
    for stem in stems:
        image_name = stem_to_file[stem]
        image_src = os.path.join(images_dir, image_name)
        image_found = os.path.exists(image_src)
        if image_found:
            shutil.copy2(image_src, os.path.join(split_images_dir, image_name))

        label_name = stem + '.txt'
        label_src = os.path.join(labels_dir, label_name)
        if not os.path.exists(label_src):
            # The image (if any) stays in the split without a label; the
            # sample is only reported as incomplete, not counted as copied.
            missing_files.append(label_name)
            continue
        shutil.copy2(label_src, os.path.join(split_labels_dir, label_name))

        if image_found:
            copied_count += 1
        else:
            missing_files.append(stem + ' (image)')

    print(f"✓ {split_name}集: 成功复制 {copied_count} 个样本")
    if missing_files:
        print(f"警告: 缺少 {len(missing_files)} 个文件")
    return copied_count


def _write_name_list(path, names):
    """Write one sample stem per line to *path* as UTF-8 text."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f"{name}\n" for name in names)


def split_image_dataset(
    dataset_path=r"D:\BaiduNetdiskDownload\train",
    output_dir=r"D:\DATAROOT",
    image_extensions=('.png',),
    train_ratio=0.8,
    val_ratio=0.1,
    test_ratio=0.1,
    seed=42,
    images_subdir="images",
    labels_subdir="annfiles",
):
    """Split an image + txt-label dataset into train/val/test folders.

    The source dataset must contain an image folder and a label folder in
    which every image ``<stem>.<ext>`` has a label file ``<stem>.txt``.
    Files are copied (not moved) into ``<output_dir>/<split>/images`` and
    ``<output_dir>/<split>/labels``; ``train.txt``, ``val.txt``, ``test.txt``
    and ``split_statistics.txt`` are written to ``output_dir``.

    Calling the function without arguments uses the original script's
    hard-coded configuration.

    Args:
        dataset_path: Root folder of the source dataset.
        output_dir: Root folder that receives the split dataset.
        image_extensions: Accepted image extensions (with leading dot);
            matching is case-insensitive.
        train_ratio: Fraction of samples for the training split.
        val_ratio: Fraction of samples for the validation split.
        test_ratio: Fraction of samples for the test split. The test split
            receives every sample left over after train and val are cut.
        seed: Seed of the private random generator used for shuffling.
        images_subdir: Name of the image folder inside ``dataset_path``.
        labels_subdir: Name of the label folder inside ``dataset_path``.

    Returns:
        None. Problems (missing folders, invalid ratios, no images) are
        reported on stdout and the function returns early.
    """
    extensions = [ext.lower() for ext in image_extensions]

    ratio_sum = train_ratio + val_ratio + test_ratio
    if min(train_ratio, val_ratio, test_ratio) < 0 or abs(ratio_sum - 1.0) > 1e-6:
        print(
            f"错误: 划分比例无效 (train={train_ratio}, val={val_ratio}, "
            f"test={test_ratio}), 各比例不能为负且三者之和必须为 1"
        )
        return

    # Check that the source folders exist before touching the output folder.
    images_dir = os.path.join(dataset_path, images_subdir)
    labels_dir = os.path.join(dataset_path, labels_subdir)
    if not os.path.exists(images_dir):
        print(f"错误: 图片目录不存在: {images_dir}")
        return
    if not os.path.exists(labels_dir):
        print(f"错误: 标签目录不存在: {labels_dir}")
        return

    print("=== PyCharm数据集划分工具 ===")
    print(f"原始数据路径: {dataset_path}")
    print(f"输出路径: {output_dir}")

    splits = ('train', 'val', 'test')
    for split in splits:
        os.makedirs(os.path.join(output_dir, split, 'images'), exist_ok=True)
        os.makedirs(os.path.join(output_dir, split, 'labels'), exist_ok=True)

    stem_to_file = _collect_image_files(images_dir, extensions)
    all_files = list(stem_to_file)
    total_count = len(all_files)
    print(f"\n找到 {total_count} 个有效样本")

    if total_count == 0:
        print("错误: 未找到任何图片文件!")
        print(f"支持的格式: {extensions}")
        print(f"请检查路径: {images_dir}")
        return

    # A private generator gives the same shuffle as seeding the global one,
    # without changing the random state of the calling program.
    random.Random(seed).shuffle(all_files)

    train_count = int(total_count * train_ratio)
    val_count = int(total_count * val_ratio)
    split_files = {
        'train': all_files[:train_count],
        'val': all_files[train_count:train_count + val_count],
        # The test split takes the remainder so that no sample is dropped
        # by the integer truncation above.
        'test': all_files[train_count + val_count:],
    }
    train_files = split_files['train']
    val_files = split_files['val']
    test_files = split_files['test']

    print("\n=== 划分结果 ===")
    print(f"训练集: {len(train_files)} 样本 ({len(train_files) / total_count:.1%})")
    print(f"验证集: {len(val_files)} 样本 ({len(val_files) / total_count:.1%})")
    print(f"测试集: {len(test_files)} 样本 ({len(test_files) / total_count:.1%})")

    for split in splits:
        _copy_split(split_files[split], split, stem_to_file, images_dir,
                    labels_dir, output_dir)

    # Count what actually landed on disk as a sanity check.
    print("\n=== 验证划分结果 ===")
    for split in splits:
        split_images_dir = os.path.join(output_dir, split, 'images')
        split_labels_dir = os.path.join(output_dir, split, 'labels')
        images_count = sum(
            1 for f in os.listdir(split_images_dir)
            if os.path.splitext(f)[1].lower() in extensions
        )
        labels_count = sum(
            1 for f in os.listdir(split_labels_dir) if f.endswith('.txt')
        )
        print(f"{split}集: 图片={images_count}, 标签={labels_count}")

    print("\n=== 保存划分信息 ===")
    for split in splits:
        _write_name_list(os.path.join(output_dir, f'{split}.txt'),
                         split_files[split])

    stats_content = (
        "数据集划分统计信息\n"
        "==================\n"
        "\n"
        f"总样本数: {total_count}\n"
        f"训练集: {len(train_files)} ({len(train_files) / total_count:.1%})\n"
        f"验证集: {len(val_files)} ({len(val_files) / total_count:.1%})\n"
        f"测试集: {len(test_files)} ({len(test_files) / total_count:.1%})\n"
        "\n"
        f"划分比例: {train_ratio:.0%}/{val_ratio:.0%}/{test_ratio:.0%}\n"
        f"随机种子: {seed}\n"
    )
    # The text is Chinese, so the encoding must not depend on the locale.
    with open(os.path.join(output_dir, 'split_statistics.txt'), 'w',
              encoding='utf-8') as f:
        f.write(stats_content)

    print("\n🎉 数据集划分完成!")
    print(f"📍 输出目录: {output_dir}")
    print(f"📊 划分统计: {len(train_files)}/{len(val_files)}/{len(test_files)} (训练/验证/测试)")


# Run the split with the default configuration when executed as a script.
if __name__ == "__main__":
    split_image_dataset()
