当前位置: 首页 > news >正文

tensorflow加载和预处理数据

总的来说,TensorFlow 的数据加载与预处理可以分为两大类:

  1. 对于中小型数据集:可以直接在内存中加载并使用 TensorFlow 内置操作进行预处理。

  2. 对于大型数据集:使用 tf.data.Dataset API 从磁盘上动态加载和预处理,构建高效的数据管道。

一、数据预处理

1、tf中的数据预处理——tf.data管道的内容

tf.data主要是将数据转化成一个高效的、可被模型消费的数据流。在深度学习项目中,数据预处理和加载常常是性能瓶颈之一。特别是当模型训练得越来越快(例如使用 GPU/TPU)时,CPU 可能无法及时准备好下一批数据,导致昂贵的加速器空闲等待。tf.data 就是为了解决这个问题而设计的,它允许你构建灵活、高效的并行数据预处理流程。

该API围绕两个核心概念构建:

  • Dataset:一个包含一系列元素的集合,其中每个元素可以是一个或多个张量。dataset是数据管道基础,可以通过多种方式创建它
  • Transformation:对Dataset对象进行操作,生成一个新的Dataset。map逐元素转换、batch多元素转换、shuffle全局转换
(1)创建dataset数据集的方式

除了以下几种方式还有一种方式是从TFRecord文件中创建(第二部分介绍),这种方法是最高效的格式,对于大规模数据能够更好的进行序列化、存储和高效读取。

a.从NumPy 数组创建
import tensorflow as tf
import numpy as np# 假设你有一些 NumPy 数据
data_np = np.array([1, 2, 3, 4, 5], dtype=np.float32)
labels_np = np.array([0, 1, 0, 1, 0], dtype=np.int32)# 使用 from_tensor_slices 创建 Dataset
dataset = tf.data.Dataset.from_tensor_slices((data_np, labels_np))# 迭代查看(在 Eager Execution 模式下可直接迭代)
for data, label in dataset.take(3): # 取前3个元素print(f"Data: {data.numpy()}, Label: {label.numpy()}")
# 输出: Data: 1.0, Label: 0
#       Data: 2.0, Label: 1
#       Data: 3.0, Label: 0
b. 从 Python 列表创建
data_list = [10.0, 20.0, 30.0]
labels_list = [True, False, True]dataset = tf.data.Dataset.from_tensor_slices((data_list, labels_list))
c.从字典创建(用于结构化数据)
features_dict = {'feature_a': [1, 2, 3],'feature_b': [0.5, 0.6, 0.7],'label': ['cat', 'dog', 'cat']
}dataset = tf.data.Dataset.from_tensor_slices(features_dict)

from_tensor_slices()函数创建的 tf.data.Dataset中的元素是沿第一个维度的所有切片(要求输入张量的第一个维度的大小必须相同)。

from_tensor():函数将其中的内容视为一个单一的元素放入数据集中。

d.从文本文件中创建(TextLineDataset)

如果你的数据是每行一个样本的文本文件(如 CSV、日志文件、诗歌等),这是理想的选择。

# 创建一个示例文本文件
with open('sample_text.txt', 'w') as f:f.write("Hello, world!\n")f.write("How are you?\n")f.write("TensorFlow is great!\n")# 从文本文件创建 Dataset
text_dataset = tf.data.TextLineDataset('sample_text.txt')# 处理数据
def preprocess_line(line):line = tf.strings.strip(line) # 去除首尾空格和换行符line = tf.strings.lower(line) # 转换为小写return lineprocessed_text_dataset = text_dataset.map(preprocess_line)for line in processed_text_dataset:print(line.numpy().decode('utf-8'))
# 输出: hello, world!
#       how are you?
#       tensorflow is great!
e. 从CSV文件中创建
# 创建一个示例 CSV 文件
with open('sample_data.csv', 'w') as f:f.write("sepal_length,sepal_width,petal_length,petal_width,species\n")f.write("5.1,3.5,1.4,0.2,setosa\n")f.write("7.0,3.2,4.7,1.4,versicolor\n")# 方法一:使用 tf.data.experimental.make_csv_dataset (推荐,功能强大)
dataset = tf.data.experimental.make_csv_dataset(file_pattern='sample_data.csv',    # 文件路径batch_size=2,                      # 查看批次数label_name='species',              # 指定哪一列是标签num_epochs=1,ignore_errors=True)
# 查看一个批次
for batch, labels in dataset.take(1):print("Batch features:", {key: value.numpy() for key, value in batch.items()})print("Batch labels:", labels.numpy())# 方法二:使用 tf.data.TextLineDataset 并手动解析(更灵活)
def parse_csv_line(line, label_name='species'):# 定义列默认值column_defaults = [tf.float32, tf.float32, tf.float32, tf.float32, tf.string]# 解析一行columns = tf.io.decode_csv(line, record_defaults=column_defaults)# 将解析出的列组装成特征字典features = dict(zip(['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species'], columns))# 从特征字典中弹出标签,剩下的就是特征label = features.pop(label_name)return features, labelcsv_dataset = tf.data.TextLineDataset('sample_data.csv').skip(1) # 跳过标题行
parsed_csv_dataset = csv_dataset.map(parse_csv_line)
f. 从文件路径创建图像数据集

这一是一种组合模式,通常先获取所有图片的路径和标签,之后从路径中加载图片。

# 假设你的目录结构如下:
# train/
#   cat/
#     cat1.jpg
#     cat2.jpg
#   dog/
#     dog1.jpg
#     dog2.jpgimport pathlibdata_dir = pathlib.Path('train/')# 1. 获取所有图片路径
all_image_paths = list(data_dir.glob('*/*.jpg'))
all_image_paths = [str(path) for path in all_image_paths]
# 打乱顺序
tf.random.set_seed(42)
random.shuffle(all_image_paths)# 2. 提取标签(从路径中获取文件夹名)
label_names = sorted(item.name for item in data_dir.glob('*/') if item.is_dir())
label_to_index = dict((name, index) for index, name in enumerate(label_names))
all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]# 3. 创建路径和标签的 Dataset
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
label_ds = tf.data.Dataset.from_tensor_slices(all_image_labels)# 4. 定义加载和预处理单张图片的函数
def load_and_preprocess_image(path):image = tf.io.read_file(path)image = tf.image.decode_jpeg(image, channels=3)image = tf.image.resize(image, [192, 192]) # 调整大小image = image / 255.0 # 归一化到 [0,1]return image# 5. 使用 map 将路径 Dataset 转换为图像 Dataset
image_ds = path_ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)# 6. 将图像和标签压缩在一起,形成最终的 (image, label) 数据集
dataset = tf.data.Dataset.zip((image_ds, label_ds))# 7. 配置数据集以获得最佳性能
batch_size = 32
dataset = dataset.shuffle(buffer_size=1000) # 打乱顺序
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE) # 预取,重叠数据预处理和模型执行
数据来源 推荐方法 适用场景
小规模内存数据 from_tensor_slices() 实验、小批量数据、简单示例
文本行数据 TextLineDataset() 日志、诗歌、自定义格式文本
CSV 表格数据 experimental.make_csv_dataset() 结构化数据、表格数据
图像文件 from_tensor_slices(paths).map(load_fn) 计算机视觉任务
大规模数据集 TFRecordDataset() 生产环境、追求最高性能
(2)链式转换

有了数据集后可以通过调用转换方法对其进行各种转换。每个方法都会返回一个新的数据集,因此可以进行链式转换(对dataset对象连续调用多个方法)。数据集方法不会修改数据集,而是创建新的数据集,因此需要保留对新数据集的引用。

通用性能优化技巧:

  • .map(map_func, num_parallel_calls=tf.data.AUTOTUNE): 并行化数据转换。
  • .cache(): 将数据集缓存到内存或磁盘,避免每个 epoch 重复计算。
  • .shuffle(buffer_size): 打乱数据顺序,buffer_size 应尽可能大。
  • .batch(batch_size): 将样本组合成批次。
  • .prefetch(buffer_size=tf.data.AUTOTUNE): 最重要的一条,让数据预处理和模型训练过程重叠,几乎总是应该使用。

典型的转换链顺序为:shuffle → map → batch → repeat → prefetch

转换链需要根据具体需求进行调整,例如:

  • 缓存优化:如果预处理很耗时且数据集能够放入到内存中,可以在map方法后添加cache()方法。shuffle -> map -> cache() -> batch -> repeat -> prefetch
  • 过滤数据:如果需要过滤某些样本,可以再map前后使用filter方法。shuffle -> filter -> map -> batch -> repeat -> prefetch
  • 批处理优先:某些特定的数据增强需要在批次级别进行,这时需要shuffle -> map -> batch -> map(批处理增强) -> repeat -> prefetch

总之,这个流程有着从轻量级操作到重量级操作,在最合适的层级进行并行处理,保证随机性,优化性能等优点。

a. shuffle乱序数据

乱序数据,最先shuffle轻量级元素成本远低于shuffle已经开始加载和预测里后的重量级数据

shuffle()方法打乱数据的原始顺序,确保模型在每个epoch中看到的数据都是随机分布,对训练的效果至关重要。

suffle()方法会创建一个新的数据集,该数据集首先将源数据集的第一项元素填充到缓冲区中。然后,无论何时要求提供一个元素,它都会从缓冲区中随机取出一个元素,并用源数据集中的新元素替换它,直到完全遍历完源数据集为止。它将继续从缓冲区中随机抽取元素直到其为空。必须指定缓冲区buffer_size的大小,重要的是要使其足够大,否则乱序处理不会非常有效。

不要超出所拥有的RAM的数量(因为缓冲区是保存在内存中的,过大会导致内存不足甚至程序崩溃),即使有足够的RAM,也不要超出数据集的大小(因为超出样本集的缓冲区会永远用不上导致空间浪费)。如果每次运行程序都需要相同的随机顺序,那么可以提供随机种子。

# 乱序数据
dataset = tf.data.Dataset.range(10).repeat(2)
dataset = dataset.shuffle(buffer_size=4, seed=42).batch(7)
for item in dataset:print(item)

如果在经过乱序处

http://www.dtcms.com/a/403576.html

相关文章:

  • DAY 03 CSS的认识
  • 黑群晖做php网站pc网站手机网站
  • Jakarta EE 实验 — Web 聊天室(过滤器、监听器版)
  • 做js题目的网站知乎抖音代运营公司合法吗
  • MyBatis的最佳搭档(MyBatis-Plus)
  • 无用知识研究:和普通函数不同,返回类型也参与了模板函数的signature
  • 简单小结类与对象
  • Java 大视界 -- Java 大数据机器学习模型在金融风险传染路径分析与防控策略制定中的应用
  • 【C++】Template:深入理解特化与分离编译,破解编译难题
  • 【把15v方波转为±7.5v的方波】2022-12-21
  • 自己可以做一个网站吗自己怎么做直播网站吗
  • 嵌入式开发常见问题解决:Keil头文件路径与MCUXpresso外设配置错误
  • 从Android到iOS:启动监控实现的跨平台技术对比
  • 数据开放网站建设内容大连可以做网站的公司
  • lesson67:JavaScript事件绑定全解析:从基础到高级实践
  • 软件开发还是网站开发好惠州seo招聘
  • ARM芯片架构之CoreSight系统架构规范
  • 品牌网站建设黑白I狼J足球比赛直播网
  • 支持向量机深度解析:从数学原理到工程实践的完整指南——核技巧与凸优化视角下的模式识别革命
  • FPGA有什么作用和功能,主副关系是什么,跟通道有什么关系
  • 怎么做整蛊网站dw自己做的网站手机进不去
  • Udp 和 Tcp socket的一般编程套路(笔记)
  • C++_STL和数据结构《3》_仿函数作为STL中算法参数的用法、匿名函数、序列容器使用、关联容器使用、无关联容器使用、容器适配器使用
  • php基础-流程控制(第12天)
  • 怎样建设尧都水果网站网页游戏网站556pk游戏福利平台
  • logo做ppt模板下载网站简历制作官网
  • LeetCode:51.岛屿数量
  • English Around the House and Farm
  • 目标速度估计中MLE和CRLB运用(二)
  • 沈阳网站建设找思路做区位分析的地图网站