tensorflow加载和预处理数据
总的来说,TensorFlow 的数据加载与预处理可以分为两大类:
对于中小型数据集:可以直接在内存中加载并使用 TensorFlow 内置操作进行预处理。
对于大型数据集:使用
tf.data.Dataset
API 从磁盘上动态加载和预处理,构建高效的数据管道。
一、数据预处理
1、tf中的数据预处理——tf.data管道的内容
tf.data主要是将数据转化成一个高效的、可被模型消费的数据流。在深度学习项目中,数据预处理和加载常常是性能瓶颈之一。特别是当模型训练得越来越快(例如使用 GPU/TPU)时,CPU 可能无法及时准备好下一批数据,导致昂贵的加速器空闲等待。tf.data 就是为了解决这个问题而设计的,它允许你构建灵活、高效的并行数据预处理流程。
该API围绕两个核心概念构建:
- Dataset:一个包含一系列元素的集合,其中每个元素可以是一个或多个张量。dataset是数据管道基础,可以通过多种方式创建它
- Transformation:对Dataset对象进行操作,生成一个新的Dataset。map逐元素转换、batch多元素转换、shuffle全局转换
(1)创建dataset数据集的方式
除了以下几种方式还有一种方式是从TFRecord文件中创建(第二部分介绍),这种方法是最高效的格式,对于大规模数据能够更好的进行序列化、存储和高效读取。
a.从NumPy 数组创建
import tensorflow as tf
import numpy as np# 假设你有一些 NumPy 数据
data_np = np.array([1, 2, 3, 4, 5], dtype=np.float32)
labels_np = np.array([0, 1, 0, 1, 0], dtype=np.int32)# 使用 from_tensor_slices 创建 Dataset
dataset = tf.data.Dataset.from_tensor_slices((data_np, labels_np))# 迭代查看(在 Eager Execution 模式下可直接迭代)
for data, label in dataset.take(3): # 取前3个元素print(f"Data: {data.numpy()}, Label: {label.numpy()}")
# 输出: Data: 1.0, Label: 0
# Data: 2.0, Label: 1
# Data: 3.0, Label: 0
b. 从 Python 列表创建
data_list = [10.0, 20.0, 30.0]
labels_list = [True, False, True]dataset = tf.data.Dataset.from_tensor_slices((data_list, labels_list))
c.从字典创建(用于结构化数据)
features_dict = {'feature_a': [1, 2, 3],'feature_b': [0.5, 0.6, 0.7],'label': ['cat', 'dog', 'cat']
}dataset = tf.data.Dataset.from_tensor_slices(features_dict)
from_tensor_slices()函数创建的 tf.data.Dataset中的元素是沿第一个维度的所有切片(要求输入张量的第一个维度的大小必须相同)。
from_tensor():函数将其中的内容视为一个单一的元素放入数据集中。
d.从文本文件中创建(TextLineDataset)
如果你的数据是每行一个样本的文本文件(如 CSV、日志文件、诗歌等),这是理想的选择。
# 创建一个示例文本文件
with open('sample_text.txt', 'w') as f:f.write("Hello, world!\n")f.write("How are you?\n")f.write("TensorFlow is great!\n")# 从文本文件创建 Dataset
text_dataset = tf.data.TextLineDataset('sample_text.txt')# 处理数据
def preprocess_line(line):line = tf.strings.strip(line) # 去除首尾空格和换行符line = tf.strings.lower(line) # 转换为小写return lineprocessed_text_dataset = text_dataset.map(preprocess_line)for line in processed_text_dataset:print(line.numpy().decode('utf-8'))
# 输出: hello, world!
# how are you?
# tensorflow is great!
e. 从CSV文件中创建
# 创建一个示例 CSV 文件
with open('sample_data.csv', 'w') as f:f.write("sepal_length,sepal_width,petal_length,petal_width,species\n")f.write("5.1,3.5,1.4,0.2,setosa\n")f.write("7.0,3.2,4.7,1.4,versicolor\n")# 方法一:使用 tf.data.experimental.make_csv_dataset (推荐,功能强大)
dataset = tf.data.experimental.make_csv_dataset(file_pattern='sample_data.csv', # 文件路径batch_size=2, # 查看批次数label_name='species', # 指定哪一列是标签num_epochs=1,ignore_errors=True)
# 查看一个批次
for batch, labels in dataset.take(1):print("Batch features:", {key: value.numpy() for key, value in batch.items()})print("Batch labels:", labels.numpy())# 方法二:使用 tf.data.TextLineDataset 并手动解析(更灵活)
def parse_csv_line(line, label_name='species'):# 定义列默认值column_defaults = [tf.float32, tf.float32, tf.float32, tf.float32, tf.string]# 解析一行columns = tf.io.decode_csv(line, record_defaults=column_defaults)# 将解析出的列组装成特征字典features = dict(zip(['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species'], columns))# 从特征字典中弹出标签,剩下的就是特征label = features.pop(label_name)return features, labelcsv_dataset = tf.data.TextLineDataset('sample_data.csv').skip(1) # 跳过标题行
parsed_csv_dataset = csv_dataset.map(parse_csv_line)
f. 从文件路径创建图像数据集
这一是一种组合模式,通常先获取所有图片的路径和标签,之后从路径中加载图片。
# 假设你的目录结构如下:
# train/
# cat/
# cat1.jpg
# cat2.jpg
# dog/
# dog1.jpg
# dog2.jpgimport pathlibdata_dir = pathlib.Path('train/')# 1. 获取所有图片路径
all_image_paths = list(data_dir.glob('*/*.jpg'))
all_image_paths = [str(path) for path in all_image_paths]
# 打乱顺序
tf.random.set_seed(42)
random.shuffle(all_image_paths)# 2. 提取标签(从路径中获取文件夹名)
label_names = sorted(item.name for item in data_dir.glob('*/') if item.is_dir())
label_to_index = dict((name, index) for index, name in enumerate(label_names))
all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]# 3. 创建路径和标签的 Dataset
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
label_ds = tf.data.Dataset.from_tensor_slices(all_image_labels)# 4. 定义加载和预处理单张图片的函数
def load_and_preprocess_image(path):image = tf.io.read_file(path)image = tf.image.decode_jpeg(image, channels=3)image = tf.image.resize(image, [192, 192]) # 调整大小image = image / 255.0 # 归一化到 [0,1]return image# 5. 使用 map 将路径 Dataset 转换为图像 Dataset
image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)# 6. 将图像和标签压缩在一起,形成最终的 (image, label) 数据集
dataset = tf.data.Dataset.zip((image_ds, label_ds))# 7. 配置数据集以获得最佳性能
batch_size = 32
dataset = dataset.shuffle(buffer_size=1000) # 打乱顺序
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE) # 预取,重叠数据预处理和模型执行
数据来源 | 推荐方法 | 适用场景 |
---|---|---|
小规模内存数据 | from_tensor_slices() | 实验、小批量数据、简单示例 |
文本行数据 | TextLineDataset() | 日志、诗歌、自定义格式文本 |
CSV 表格数据 | experimental.make_csv_dataset() | 结构化数据、表格数据 |
图像文件 | from_tensor_slices(paths).map(load_fn) | 计算机视觉任务 |
大规模数据集 | TFRecordDataset() | 生产环境、追求最高性能 |
(2)链式转换
有了数据集后可以通过调用转换方法对其进行各种转换。每个方法都会返回一个新的数据集,因此可以进行链式转换(对dataset对象连续调用多个方法)。数据集方法不会修改数据集,而是创建新的数据集,因此需要保留对新数据集的引用。
通用性能优化技巧:
.map(map_func, num_parallel_calls=tf.data.AUTOTUNE)
: 并行化数据转换。.cache()
: 将数据集缓存到内存或磁盘,避免每个 epoch 重复计算。.shuffle(buffer_size)
: 打乱数据顺序,buffer_size
应尽可能大。.batch(batch_size)
: 将样本组合成批次。.prefetch(buffer_size=tf.data.AUTOTUNE)
: 最重要的一条,让数据预处理和模型训练过程重叠,几乎总是应该使用。
典型的转换链顺序为:shuffle → map → batch → repeat → prefetch
转换链需要根据具体需求进行调整,例如:
- 缓存优化:如果预处理很耗时且数据集能够放入到内存中,可以在map方法后添加cache()方法。shuffle -> map -> cache() -> batch -> repeat -> prefetch
- 过滤数据:如果需要过滤某些样本,可以再map前后使用filter方法。shuffle -> filter -> map -> batch -> repeat -> prefetch
- 批处理优先:某些特定的数据增强需要在批次级别进行,这时需要shuffle -> map -> batch -> map(批处理增强) -> repeat -> prefetch
总之,这个流程有着从轻量级操作到重量级操作,在最合适的层级进行并行处理,保证随机性,优化性能等优点。
a. shuffle乱序数据
乱序数据,最先shuffle轻量级元素成本远低于shuffle已经开始加载和预测里后的重量级数据
shuffle()方法打乱数据的原始顺序,确保模型在每个epoch中看到的数据都是随机分布,对训练的效果至关重要。
suffle()方法会创建一个新的数据集,该数据集首先将源数据集的第一项元素填充到缓冲区中。然后,无论何时要求提供一个元素,它都会从缓冲区中随机取出一个元素,并用源数据集中的新元素替换它,直到完全遍历完源数据集为止。它将继续从缓冲区中随机抽取元素直到其为空。必须指定缓冲区buffer_size的大小,重要的是要使其足够大,否则乱序处理不会非常有效。
不要超出所拥有的RAM的数量(因为缓冲区是保存在内存中的,过大会导致内存不足甚至程序崩溃),即使有足够的RAM,也不要超出数据集的大小(因为超出样本集的缓冲区会永远用不上导致空间浪费)。如果每次运行程序都需要相同的随机顺序,那么可以提供随机种子。
# 乱序数据
dataset = tf.data.Dataset.range(10).repeat(2)
dataset = dataset.shuffle(buffer_size=4, seed=42).batch(7)
for item in dataset:print(item)
如果在经过乱序处