TensorFlow数据流水线(tf.data)

深度学习专题 · 高效数据加载与预处理

专题:深度学习系统学习

关键词:深度学习, tf.data, Dataset, TFRecord, prefetch, cache, 高性能, 数据流水线, FeatureColumn

一、tf.data 概述

tf.data 是 TensorFlow 2.x 中官方推荐的数据输入流水线 API,它提供了一套高性能、可扩展的数据处理工具,能够将原始数据(文件、张量、生成器等)构建为高效的输入流水线。与传统使用 feed_dict 或手动编写数据加载循环的方式相比,tf.data 能更好地利用硬件资源(CPU、GPU、TPU),实现数据加载与模型训练的并行化。

tf.data 的核心概念是 tf.data.Dataset,它表示一个元素的序列,每个元素包含一个或多个 Tensor 对象。Dataset 可以像 Python 迭代器一样使用,但背后实现了大量性能优化:流水线重叠、并行化数据变换、自动调优参数等。

核心理念:tf.data 将数据加载与模型训练解耦,通过声明式的 API 链式调用构建数据流水线,让框架自动处理并发、预取和缓存等复杂任务,开发者只需关注数据变换逻辑。

从更高的视角看,tf.data 不仅仅是一个数据加载工具,它提供了完整的 ETL(Extract, Transform, Load)流程:从各种数据源提取原始数据,执行清洗和增强变换,最后以最优方式加载到训练循环中。数据流水线的性能直接影响 GPU 利用率和整体训练速度,因此掌握 tf.data 是深度学习工程实践中的关键技能。

二、Dataset 的创建方式

tf.data.Dataset 提供了多种静态和动态创建方法,覆盖了绝大部分实际场景。根据数据源的不同,选用的创建方式也各不相同。

2.1 from_tensor_slices

最基础的创建方式,将 NumPy 数组或 TensorFlow 张量按第一个维度切分成多个样本。常用于小数据集或内存数据。

import tensorflow as tf import numpy as np # 从 NumPy 数组创建 features = np.random.randn(100, 28, 28, 3).astype(np.float32) labels = np.random.randint(0, 10, size=(100,)) dataset = tf.data.Dataset.from_tensor_slices((features, labels)) # 从字典创建(结构化数据) data_dict = { "image": features, "label": labels, "weight": np.ones((100,), dtype=np.float32) } dict_dataset = tf.data.Dataset.from_tensor_slices(data_dict) for batch in dict_dataset.batch(4).take(1): print(batch["image"].shape) # (4, 28, 28, 3)

2.2 from_generator

当数据无法全部加载到内存中,或者数据来自外部流时,可以通过 Python 生成器动态产生数据。生成器是一个 Python 函数,每次 yield 一个样本或一批样本。

def data_generator(): for i in range(1000): yield np.random.randn(28, 28, 3).astype(np.float32), i dataset = tf.data.Dataset.from_generator( data_generator, output_signature=( tf.TensorSpec(shape=(28, 28, 3), dtype=tf.float32), tf.TensorSpec(shape=(), dtype=tf.int32) ) ) # 重要:from_generator 的序列化限制 # 生成器函数不能依赖外部 Python 状态,否则无法序列化(如用于 tf.function)

2.3 list_files

当数据以文件形式存储在磁盘上时,使用 list_files 获取文件路径列表,并结合 map 函数读取。支持通配符模式匹配,可以按比例分片。

# 获取所有图像文件路径 file_pattern = "/data/images/train/*.jpg" file_dataset = tf.data.Dataset.list_files( file_pattern, shuffle=True, # 是否打乱文件顺序 seed=42 ) # 定义文件读取函数 def parse_image(file_path): image = tf.io.read_file(file_path) image = tf.image.decode_jpeg(image, channels=3) image = tf.image.resize(image, [224, 224]) image = tf.cast(image, tf.float32) / 255.0 return image image_dataset = file_dataset.map(parse_image) # 按文件数量分配到各 worker # num_parallel_reads 控制读取并行度 files = tf.data.Dataset.list_files("/data/*.tfrecord") dataset = files.interleave( lambda f: tf.data.TFRecordDataset(f), cycle_length=4, num_parallel_calls=tf.data.AUTOTUNE )

2.4 TextLineDataset

专门用于读取文本文件,每行作为一个元素。常用于 CSV 数据、日志文件等纯文本格式。

# 读取文本文件,每行一个元素 text_dataset = tf.data.TextLineDataset( ["file1.csv", "file2.csv"], compression_type="GZIP" # 支持压缩文件 ) # 跳过 CSV 表头 text_dataset = text_dataset.skip(1) # 解析 CSV 行 def parse_csv_line(line): fields = tf.io.decode_csv( line, record_defaults=[tf.constant("", dtype=tf.string)] * 5 ) return fields parsed_dataset = text_dataset.map(parse_csv_line)

2.5 TFRecordDataset

专为 TFRecord 格式设计的读取器,是 TensorFlow 生态中最推荐的大规模数据存储格式。TFRecord 将数据序列化为二进制记录,读取效率远高于文本文件。

# 读取 TFRecord 文件 filenames = ["/data/train-00000-of-00100", "/data/train-00001-of-00100"] raw_dataset = tf.data.TFRecordDataset( filenames, compression_type="GZIP", # 压缩类型 buffer_size=1024 * 1024 * 16 # 16MB 读取缓冲区 ) # 需要配合解析函数使用(详见第五节 TFRecord) dataset = raw_dataset.map(parse_tfrecord_fn)
创建方法 适用场景 数据量 读取速度
from_tensor_slices 小数据内存加载、快速原型 <10GB ★★★
from_generator 动态生成、流式数据、外部数据源 无上限 ★★
list_files + map 图像文件、自定义二进制文件 中等 ★★★
TextLineDataset CSV、日志等纯文本 中等 ★★★
TFRecordDataset 大规模训练数据(推荐) TB级 ★★★★★

三、数据变换操作

Dataset 创建完成后,可以链式调用多种变换方法对其进行处理。这些变换是"惰性"的,只在迭代时才实际执行,从而允许 TensorFlow 对流水线进行整体优化。

3.1 map

对数据集的每个元素应用一个函数。这是最常用的变换,用于数据预处理、解析、增强等。map 支持并行化执行,是性能优化的关键点。

# 图像增强示例 def augment(image, label): image = tf.image.random_flip_left_right(image) image = tf.image.random_brightness(image, max_delta=0.1) image = tf.image.random_contrast(image, lower=0.8, upper=1.2) return image, label # 并行 map(关键性能优化) dataset = dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

3.2 batch

将连续元素组合成批次。最后一个批次可能小于 batch_size,可通过 drop_remainder 控制。

dataset = dataset.batch( batch_size=32, drop_remainder=True # 丢弃最后一个不完整批次 )

3.3 shuffle

以大小为 buffer_size 的缓冲区实现数据打乱。buffer_size 越大,随机性越好,但内存消耗也越大。经验法则是 buffer_size 至少为样本总数的 10%。

# shuffle 是"采样式"打乱,不是全局重排 dataset = dataset.shuffle( buffer_size=10000, # 缓冲区大小(越大越随机) seed=42, reshuffle_each_iteration=True # 每轮重新打乱 )

3.4 repeat

重复数据集多次,用于多轮训练。不传参数时无限重复。

# epoch 由 repeat 控制 dataset = dataset.repeat(count=10) # 重复10个epoch # 或与 batch 联合:无限重复,由训练步数控制 dataset = dataset.repeat().batch(32)

3.5 prefetch

预取数据到内存中,让数据预处理和模型训练重叠。这通常是提升 GPU 利用率最有效的单一操作。

# 预取1个批次 dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE) # 实现原理: # CPU 在 GPU 处理当前批次时,已经在准备下一个批次 # 从而消除"GPU 等 CPU"的空闲时间

3.6 cache

将数据集缓存到内存或磁盘,避免重复计算。在第一个 epoch 中执行所有变换并缓存结果,后续 epoch 直接读取缓存。

# 缓存到内存(适合小数据集) dataset = dataset.cache() # 缓存到磁盘(适合大数据集,注意磁盘空间) dataset = dataset.cache(filename="/tmp/cache_dir") # 典型用法:map/cache/shuffle/batch/prefetch dataset = dataset.map(preprocess, num_parallel_calls=AUTOTUNE) dataset = dataset.cache() # 第一次 epoch 后缓存预处理结果 dataset = dataset.shuffle(1024) dataset = dataset.batch(64) dataset = dataset.prefetch(AUTOTUNE)

3.7 filter

根据条件函数过滤元素。常用于移除无效样本或筛选特定类别的数据。

# 过滤掉标签为 -1 的无效样本 dataset = dataset.filter(lambda x, y: y >= 0) # 筛选出特定类别的数据 dataset = dataset.filter(lambda x, y: tf.reduce_any(tf.equal(y, [0, 1, 2])))

3.8 unbatch

将批次数据重新展开为单个元素,是 batch 的逆操作。常用于变长序列数据的处理。

# unbatch 示例 batched = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).batch(2) unbatched = batched.unbatch() for item in unbatched: print(item.numpy()) # 1, 2, 3, 4

3.9 window

将连续元素组合为"窗口",每个窗口本身是一个子数据集。常用于序列模型的时间窗口特征构造。

# 时间序列窗口 time_series = tf.data.Dataset.range(100) windows = time_series.window(size=5, shift=1, drop_remainder=True) # 将窗口展平为特征-标签对 def window_to_dataset(window): window = window.batch(5, drop_remainder=True) return window.map(lambda w: (w[:-1], w[-1:])) seq_dataset = windows.flat_map(window_to_dataset)

四、性能优化技巧

tf.data 提供了多种性能优化机制,合理配置可以使数据加载速度提升数倍至数十倍。以下是经过实践验证的核心优化策略。

4.1 使用 prefetch 实现流水线重叠

prefetch 是 tf.data 中最简单但最有效的优化手段。它将数据准备与模型训练在时间上重叠,让 CPU 在 GPU 执行前向/反向传播的同时准备下一批数据。使用 tf.data.AUTOTUNE 让框架自动选择最佳的预取缓冲区大小。

# 始终在流水线末尾添加 prefetch dataset = dataset.batch(32) dataset = dataset.prefetch(tf.data.AUTOTUNE) # 可视化理解: # 无 prefetch: [train][load][train][load][train][load] # 有 prefetch: [train]──[train]──[train]── # [load]──[load]──[load]── (并行执行)

4.2 并行 map(num_parallel_calls)

map 操作默认是串行执行的,通过设置 num_parallel_calls 可以在多个 CPU 线程上并行处理元素。对于计算密集型预处理(如图像解码、数据增强),并行度可以显著降低延迟。

# 串行 map(慢) dataset = dataset.map(preprocess_fn) # 并行 map(快) dataset = dataset.map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE) # 并行 map 的注意事项: # 1. preprocess_fn 必须是纯函数(无副作用) # 2. 不是 num_parallel_calls 越大越好,需考虑 CPU 核心数 # 3. AUTOTUNE 让 tf.data 自动选择最优并行度

4.3 使用 AUTOTUNE 自动调优

tf.data.AUTOTUNE 是一个特殊常量,告知 tf.data 运行时根据系统负载动态调整并行度、预取大小和缓冲区大小等参数。这是从手动调优迈向自动调优的关键一步。

# 完全自动调优的流水线 dataset = ( tf.data.Dataset.list_files(pattern) .shuffle(buffer_size=10000) .interleave( lambda f: tf.data.TFRecordDataset(f), num_parallel_calls=tf.data.AUTOTUNE ) .map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE) .map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE) .batch(64, drop_remainder=True) .prefetch(tf.data.AUTOTUNE) )

4.4 缓存到内存

对于预处理计算密集型但数据集较小(能放入内存)的场景,在 map 之后添加 cache() 可以极大加速后续 epoch。第一个 epoch 执行所有 map 变换并缓存结果,后续 epoch 直接复用。

# 策略1: 预处理后立即缓存(适合 < 10GB) dataset = dataset.map(heavy_preprocessing, num_parallel_calls=AUTOTUNE) dataset = dataset.cache() dataset = dataset.shuffle(1024).batch(64).prefetch(AUTOTUNE) # 策略2: 分阶段缓存 dataset = dataset.map(light_preprocess).cache() dataset = dataset.map(random_augment, num_parallel_calls=AUTOTUNE) dataset = dataset.batch(64).prefetch(AUTOTUNE) # 确定性预处理被缓存,随机增强每次重新计算

4.5 并行化文件读取(interleave)

interleave 允许同时从多个文件中读取数据,避免因单个大文件的 I/O 等待而阻塞流水线。cycle_length 控制并行读取的文件数,block_length 控制每个文件连续读取的记录数。

# interleave 与 flat_map 对比 files = tf.data.Dataset.list_files("/data/*.tfrecord") # flat_map:串行读取(慢) dataset = files.flat_map(lambda f: tf.data.TFRecordDataset(f)) # interleave:并行读取(快) dataset = files.interleave( lambda f: tf.data.TFRecordDataset(f).map(parse_fn), cycle_length=4, # 同时读4个文件 block_length=16, # 每个文件连续取16条 num_parallel_calls=tf.data.AUTOTUNE ) # 注意:当文件数量远大于 cycle_length 时, # interleave 会动态切换文件,保持并行度

4.6 文件分片(shard)

在分布式训练中,每个 worker 只处理总文件的一个子集,避免数据重复。shard 根据 worker 数量和索引对数据集进行均匀划分。

# 分布式训练中的文件分片 num_workers = 8 worker_index = 0 # 当前 worker 编号 dataset = tf.data.Dataset.list_files("/data/*.tfrecord") dataset = dataset.shard(num_workers, worker_index) # 分片后每个 worker 只处理 1/8 的文件 # 配合 interleave 进一步提升读取效率 # 最佳实践:先分片,后 shuffle dataset = dataset.shard(num_workers, worker_index) dataset = dataset.shuffle(buffer_size=10000)

4.7 性能优化最佳实践总结

推荐流水线结构(按顺序):

1. list_files(获取文件列表)→ 2. shard(分布式分片)→ 3. shuffle(文件级打乱)→ 4. interleave(并行读取)→ 5. map(预处理,并行)→ 6. cache(缓存预处理结果)→ 7. map(随机增强,并行)→ 8. shuffle(样本级打乱)→ 9. batch(组批)→ 10. prefetch(预取)

注意:cache 之前的 shuffle/random 操作会被固定,因此随机增强应在 cache 之后进行。

五、TFRecord 格式详解

TFRecord 是 TensorFlow 原生推荐的二进制数据存储格式。它将数据序列化为一条条记录(Record),每条记录由字节串构成,读取时无需解析文件格式,I/O 效率极高。在大规模分布式训练场景中,TFRecord 几乎是标准配置。

5.1 Example 与 Feature 协议

TFRecord 的核心数据结构是 tf.train.Example,它本质上是一个字符串到 tf.train.Feature 的映射。Feature 支持三种类型:BytesList(字符串/字节)、FloatList(浮点数)、Int64List(整数)。

# 构建一个 Example 消息 def create_example(image_bytes, label, height, width): feature = { "image/encoded": tf.train.Feature( bytes_list=tf.train.BytesList(value=[image_bytes])), "image/height": tf.train.Feature( int64_list=tf.train.Int64List(value=[height])), "image/width": tf.train.Feature( int64_list=tf.train.Int64List(value=[width])), "image/label": tf.train.Feature( int64_list=tf.train.Int64List(value=[label])), "image/format": tf.train.Feature( bytes_list=tf.train.BytesList(value=[b"jpeg"])), } return tf.train.Example(features=tf.train.Features(feature=feature))

5.2 序列化与写入

创建 Example 后,需要通过 SerializeToString 序列化为字节串,再写入 TFRecord 文件。写入时可以使用 tf.io.TFRecordWriter。

# 写入 TFRecord 文件 output_file = "/data/train.tfrecord" with tf.io.TFRecordWriter(output_file) as writer: for i in range(1000): image_bytes = read_image_as_bytes(i) # 假设的函数 example = create_example( image_bytes, label=i % 10, height=224, width=224 ) writer.write(example.SerializeToString()) # 写入多个文件(推荐用于大规模数据) num_shards = 100 for shard_id in range(num_shards): filename = f"/data/train-{shard_id:05d}-of-{num_shards:05d}" with tf.io.TFRecordWriter(filename) as writer: for i in range(1000): writer.write(create_example_for_index(i).SerializeToString())

5.3 解析 TFRecord

读取时需要编写解析函数,使用 tf.io.parse_single_example 将序列化的字节串还原为特征张量。解析函数的特征描述必须与写入时完全一致。

# 定义特征描述 feature_description = { "image/encoded": tf.io.FixedLenFeature([], tf.string), "image/height": tf.io.FixedLenFeature([], tf.int64), "image/width": tf.io.FixedLenFeature([], tf.int64), "image/label": tf.io.FixedLenFeature([], tf.int64), "image/format": tf.io.FixedLenFeature([], tf.string), } def parse_tfrecord_fn(example_proto): # 解析 Example parsed = tf.io.parse_single_example(example_proto, feature_description) # 解码图像 image = tf.image.decode_jpeg(parsed["image/encoded"], channels=3) image = tf.image.resize(image, [224, 224]) image = tf.cast(image, tf.float32) / 255.0 # 返回 (特征, 标签) 元组 return image, parsed["image/label"] # 组装流水线 dataset = tf.data.TFRecordDataset(filenames) dataset = dataset.map(parse_tfrecord_fn, num_parallel_calls=AUTOTUNE) dataset = dataset.batch(64).prefetch(AUTOTUNE)

5.4 VarLenFeature 与序列特征

当特征的长度不固定时(如变长序列、多值标签),使用 VarLenFeature 代替 FixedLenFeature,它会返回 SparseTensor。

# 示例:多标签分类 feature_desc = { "image/encoded": tf.io.FixedLenFeature([], tf.string), "image/labels": tf.io.VarLenFeature(tf.int64), # 变长标签列表 "text/tokens": tf.io.VarLenFeature(tf.string), # 变长文本 token } def parse_fn(proto): parsed = tf.io.parse_single_example(proto, feature_desc) image = tf.image.decode_jpeg(parsed["image/encoded"]) labels = tf.sparse.to_dense(parsed["image/labels"]) # 转为稠密 return image, labels

TFRecord 优势总结:

1. 二进制格式,读取效率远超文本文件

2. 天然支持压缩(GZIP/ZLIB),减少存储和 I/O

3. 与 tf.data 深度集成,支持分片、并行读取

4. 良好的跨平台兼容性(Python、Java、Go 等均支持)

5. 适合分布式训练中的数据分发

6. ProtoBuf 序列化,schema 定义清晰

六、特征列 FeatureColumns

FeatureColumns 是 TensorFlow Estimator API(和 Keras)中用于描述特征结构的组件。它将各种原始数据类型(数值、类别、文本、嵌入等)统一转换为模型可以处理的数值张量。虽然在纯 Keras 中可以直接预处理,但 FeatureColumns 提供了一种声明式的、可序列化的特征工程方案,特别适合结构化数据的特征处理流水线。

6.1 数值列(numeric_column)

最基本的特征列,将数值特征直接传递给模型。可以指定归一化函数。

# 基础数值列 age = tf.feature_column.numeric_column("age") price = tf.feature_column.numeric_column("price", dtype=tf.float32) # 带归一化的数值列 def normalize_income(x): return (x - 50000.0) / 30000.0 income = tf.feature_column.numeric_column( "income", normalizer_fn=normalize_income )

6.2 类别列(categorical_column)

将离散值(字符串、整数)映射为类别 ID。支持词汇表映射、哈希分桶等方式。

# 词汇表映射 color = tf.feature_column.categorical_column_with_vocabulary_list( "color", vocabulary_list=["red", "green", "blue", "yellow"] ) # 哈希分桶(适合类别数不确定的场景) city = tf.feature_column.categorical_column_with_hash_bucket( "city", hash_bucket_size=1000 # 哈希到 1000 个桶 ) # 整数标识 category_id = tf.feature_column.categorical_column_with_identity( "category_id", num_buckets=50 ) # 分桶连续特征 age_buckets = tf.feature_column.bucketized_column( source_column=age, boundaries=[18, 25, 35, 45, 55, 65] )

6.3 嵌入列(embedding_column)

将高维稀疏类别特征映射到低维稠密向量。是处理大规模类别特征(如用户 ID、商品 ID)的标准方法。

# 嵌入列 city_embedding = tf.feature_column.embedding_column( categorical_column=city, dimension=16 # 嵌入维度 ) # 嵌入维度经验公式 # dimension = min(50, (num_categories + 1) // 2) # 或 dimension = int(num_categories ** 0.25) * 4 # 使用嵌入列处理 ID 特征 user_id = tf.feature_column.categorical_column_with_hash_bucket( "user_id", hash_bucket_size=100000 ) user_embedding = tf.feature_column.embedding_column( user_id, dimension=32 )

6.4 交叉列(crossed_column)

对两个或多个类别特征进行笛卡尔积交叉,产生新的组合特征。自动进行哈希分桶,避免组合爆炸。

# 特征交叉:年龄段 x 城市 age_x_city = tf.feature_column.crossed_column( [age_buckets, city], hash_bucket_size=5000 # 交叉后哈希到 5000 个桶 ) # 交叉列通常需要配合嵌入列使用 age_x_city_embedding = tf.feature_column.embedding_column( age_x_city, dimension=8 ) # 更高阶交叉 user_x_item = tf.feature_column.crossed_column( [user_id, category_id], hash_bucket_size=50000 )

6.5 指示列(indicator_column)

将类别特征转换为 one-hot 编码的多-hot 指示向量。适用于类别数较少(通常 < 100)的场景。

# One-hot 编码 color_indicator = tf.feature_column.indicator_column(color) # 嵌入列 vs 指示列的选择 # 类别数少(< 50):使用 indicator_column(one-hot) # 类别数多(>= 50):使用 embedding_column # 数学上,embedding_column 可视为 indicator_column 的低秩近似

6.6 在 Keras 中使用 FeatureColumns

# 组合所有特征列 feature_columns = [ age, # 数值列 income, # 数值列(带归一化) age_buckets, # 分桶列 color_indicator, # 指示列(one-hot) city_embedding, # 嵌入列 age_x_city_embedding, # 交叉列 + 嵌入 ] # 在 Keras 中使用 feature_layer = tf.keras.layers.DenseFeatures(feature_columns) model = tf.keras.Sequential([ feature_layer, tf.keras.layers.Dense(128, activation="relu"), tf.keras.layers.Dense(64, activation="relu"), tf.keras.layers.Dense(1, activation="sigmoid") ]) model.compile(optimizer="adam", loss="binary_crossentropy") # 训练数据用字典格式输入 model.fit(dataset_dict, epochs=10)

七、Keras 与 tf.data 集成

Keras 的 model.fit()model.evaluate()model.predict() 均支持直接接收 tf.data.Dataset 作为输入。这是 tf.data 最强大的应用场景之一,将数据流水线与训练循环无缝衔接。

7.1 基础集成

# 准备 Dataset (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() x_train = x_train.astype(np.float32) / 255.0 x_test = x_test.astype(np.float32) / 255.0 train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = train_dataset.shuffle(60000).batch(32).prefetch(AUTOTUNE) test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) test_dataset = test_dataset.batch(32).prefetch(AUTOTUNE) # 直接传入 Dataset model.fit(train_dataset, validation_data=test_dataset, epochs=10) # evaluate 和 predict 同理 loss, acc = model.evaluate(test_dataset) predictions = model.predict(test_dataset)

7.2 steps_per_epoch 控制

当 Dataset 无限重复(repeat() 无参数)时,必须指定 steps_per_epoch 来决定每个 epoch 的步数。

# 无限重复的 Dataset train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = train_dataset.repeat() # 无限重复 train_dataset = train_dataset.shuffle(10000) train_dataset = train_dataset.batch(32).prefetch(AUTOTUNE) model.fit( train_dataset, epochs=10, steps_per_epoch=1875, # 60000 / 32 ≈ 1875 validation_data=test_dataset )

7.3 多输入模型

当模型有多个输入时,Dataset 的每个元素应为 (inputs_dict, label)(inputs_tuple, label) 格式。

# 多输入 Dataset def gen_multi_input(): for i in range(1000): yield ( {"image_input": np.random.randn(224, 224, 3), "text_input": np.random.randint(0, 1000, size=(50,))}, np.random.randint(0, 10) ) dataset = tf.data.Dataset.from_generator( gen_multi_input, output_signature=( {"image_input": tf.TensorSpec(shape=(224, 224, 3), dtype=tf.float32), "text_input": tf.TensorSpec(shape=(50,), dtype=tf.int32)}, tf.TensorSpec(shape=(), dtype=tf.int32) ) ) dataset = dataset.batch(32) # 多输入 Keras 模型 image_input = tf.keras.Input(shape=(224, 224, 3), name="image_input") text_input = tf.keras.Input(shape=(50,), name="text_input") # ... 模型定义 ... model.fit(dataset, epochs=10)

7.4 自定义训练循环

在需要更精细控制时,可以直接在 tf.GradientTape 中迭代 Dataset。

# 自定义训练循环 optimizer = tf.keras.optimizers.Adam() loss_fn = tf.keras.losses.SparseCategoricalCrossentropy() for epoch in range(10): for step, (x_batch, y_batch) in enumerate(train_dataset): with tf.GradientTape() as tape: logits = model(x_batch, training=True) loss = loss_fn(y_batch, logits) grads = tape.gradient(loss, model.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables)) if step % 100 == 0: print(f"Epoch {epoch}, Step {step}, Loss: {loss.numpy():.4f}")

八、完整实战案例

以下是一个端到端的实战案例,完整演示了从 TFRecord 创建到 Keras 模型训练的完整数据流水线。

# ========= 步骤1: 创建 TFRecord 文件 ========= import tensorflow as tf import numpy as np AUTOTUNE = tf.data.AUTOTUNE def serialize_example(feature0, feature1, feature2, label): feature = { "feature0": tf.train.Feature( float_list=tf.train.FloatList(value=[feature0])), "feature1": tf.train.Feature( float_list=tf.train.FloatList(value=feature1.flatten())), "feature2": tf.train.Feature( int64_list=tf.train.Int64List(value=[feature2])), "label": tf.train.Feature( int64_list=tf.train.Int64List(value=[label])), } return tf.train.Example( features=tf.train.Features(feature=feature) ).SerializeToString() # 生成模拟数据并写入 TFRecord with tf.io.TFRecordWriter("/tmp/train.tfrecord") as writer: for i in range(10000): f0 = np.random.randn() f1 = np.random.randn(10) f2 = np.random.randint(0, 100) label = np.random.randint(0, 10) writer.write(serialize_example(f0, f1, f2, label)) # ========= 步骤2: 定义解析函数 ========= feature_description = { "feature0": tf.io.FixedLenFeature([], tf.float32), "feature1": tf.io.FixedLenFeature([10], tf.float32), "feature2": tf.io.FixedLenFeature([], tf.int64), "label": tf.io.FixedLenFeature([], tf.int64), } def parse_fn(proto): parsed = tf.io.parse_single_example(proto, feature_description) features = { "f0": parsed["feature0"], "f1": parsed["feature1"], "f2": tf.cast(parsed["feature2"], tf.float32), } label = parsed["label"] return features, label # ========= 步骤3: 构建高性能流水线 ========= dataset = tf.data.TFRecordDataset(["/tmp/train.tfrecord"]) dataset = dataset.map(parse_fn, num_parallel_calls=AUTOTUNE) dataset = dataset.cache() # 缓存解析结果 dataset = dataset.shuffle(buffer_size=5000) dataset = dataset.batch(64, drop_remainder=True) dataset = dataset.prefetch(AUTOTUNE) # ========= 步骤4: 构建并训练模型 ========= feature_columns = [ tf.feature_column.numeric_column("f0"), tf.feature_column.numeric_column("f1", shape=(10,)), tf.feature_column.numeric_column("f2"), ] model = tf.keras.Sequential([ tf.keras.layers.DenseFeatures(feature_columns), tf.keras.layers.Dense(64, activation="relu"), tf.keras.layers.Dense(32, activation="relu"), tf.keras.layers.Dense(10, activation="softmax"), ]) model.compile( optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"] ) model.fit(dataset, epochs=10, verbose=2) # ========= 步骤5: 保存模型 ========= model.save("/tmp/tfdata_demo_model")

九、总结

tf.data 是 TensorFlow 生态中不可或缺的数据处理组件,它通过声明式的 API 将数据加载、预处理、增强、组批和预取等环节无缝衔接。掌握 tf.data 的核心概念和优化技巧,对于提升深度学习训练效率和 GPU 利用率至关重要。

核心要点回顾:

1. Dataset 创建:根据数据规模和格式选择合适的创建方法(from_tensor_slices、from_generator、list_files、TextLineDataset、TFRecordDataset)

2. 数据变换:灵活组合 map、batch、shuffle、repeat、cache、prefetch、filter、unbatch、window 等操作

3. 性能优化:始终使用 prefetch(AUTOTUNE);对 map 设置 num_parallel_calls=AUTOTUNE;使用 interleave 并行读取文件;在 map 后使用 cache 避免重复计算

4. TFRecord:大规模数据推荐使用 TFRecord 格式,通过 Example/Feature 协议序列化,利用分片实现分布式读取

5. FeatureColumns:声明式特征工程,覆盖数值列、类别列、嵌入列、交叉列和指示列

6. Keras 集成:model.fit() 直接接收 Dataset,支持多输入、多输出和自定义训练循环

最佳实践口诀:

数据源选 TFRecord,解析 map 加并行;

shuffle 在前 batch 后,预取 cache 不能漏;

AUTOTUNE 来调优,GPU 满载不空等;

特征列声明式写,Keras 集成直接训。