TensorFlow数据流水线(tf.data)
深度学习专题 · 高效数据加载与预处理
专题:深度学习系统学习
关键词:深度学习, tf.data, Dataset, TFRecord, prefetch, cache, 高性能, 数据流水线, FeatureColumn
一、tf.data 概述
tf.data 是 TensorFlow 2.x 中官方推荐的数据输入流水线 API,它提供了一套高性能、可扩展的数据处理工具,能够将原始数据(文件、张量、生成器等)构建为高效的输入流水线。与传统使用 feed_dict 或手动编写数据加载循环的方式相比,tf.data 能更好地利用硬件资源(CPU、GPU、TPU),实现数据加载与模型训练的并行化。
tf.data 的核心概念是 tf.data.Dataset,它表示一个元素的序列,每个元素包含一个或多个 Tensor 对象。Dataset 可以像 Python 迭代器一样使用,但背后实现了大量性能优化:流水线重叠、并行化数据变换、自动调优参数等。
核心理念:tf.data 将数据加载与模型训练解耦,通过声明式的 API 链式调用构建数据流水线,让框架自动处理并发、预取和缓存等复杂任务,开发者只需关注数据变换逻辑。
从更高的视角看,tf.data 不仅仅是一个数据加载工具,它提供了完整的 ETL(Extract, Transform, Load)流程:从各种数据源提取原始数据,执行清洗和增强变换,最后以最优方式加载到训练循环中。数据流水线的性能直接影响 GPU 利用率和整体训练速度,因此掌握 tf.data 是深度学习工程实践中的关键技能。
二、Dataset 的创建方式
tf.data.Dataset 提供了多种静态和动态创建方法,覆盖了绝大部分实际场景。根据数据源的不同,选用的创建方式也各不相同。
2.1 from_tensor_slices
最基础的创建方式,将 NumPy 数组或 TensorFlow 张量按第一个维度切分成多个样本。常用于小数据集或内存数据。
import tensorflow as tf
import numpy as np
# 从 NumPy 数组创建
features = np.random.randn(100, 28, 28, 3).astype(np.float32)
labels = np.random.randint(0, 10, size=(100,))
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
# 从字典创建(结构化数据)
data_dict = {
"image": features,
"label": labels,
"weight": np.ones((100,), dtype=np.float32)
}
dict_dataset = tf.data.Dataset.from_tensor_slices(data_dict)
for batch in dict_dataset.batch(4).take(1):
print(batch["image"].shape) # (4, 28, 28, 3)
2.2 from_generator
当数据无法全部加载到内存中,或者数据来自外部流时,可以通过 Python 生成器动态产生数据。生成器是一个 Python 函数,每次 yield 一个样本或一批样本。
def data_generator():
for i in range(1000):
yield np.random.randn(28, 28, 3).astype(np.float32), i
dataset = tf.data.Dataset.from_generator(
data_generator,
output_signature=(
tf.TensorSpec(shape=(28, 28, 3), dtype=tf.float32),
tf.TensorSpec(shape=(), dtype=tf.int32)
)
)
# 重要:from_generator 的序列化限制
# 生成器函数不能依赖外部 Python 状态,否则无法序列化(如用于 tf.function)
2.3 list_files
当数据以文件形式存储在磁盘上时,使用 list_files 获取文件路径列表,并结合 map 函数读取。支持通配符模式匹配,可以按比例分片。
# 获取所有图像文件路径
file_pattern = "/data/images/train/*.jpg"
file_dataset = tf.data.Dataset.list_files(
file_pattern,
shuffle=True, # 是否打乱文件顺序
seed=42
)
# 定义文件读取函数
def parse_image(file_path):
image = tf.io.read_file(file_path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize(image, [224, 224])
image = tf.cast(image, tf.float32) / 255.0
return image
image_dataset = file_dataset.map(parse_image)
# 按文件数量分配到各 worker
# num_parallel_reads 控制读取并行度
files = tf.data.Dataset.list_files("/data/*.tfrecord")
dataset = files.interleave(
lambda f: tf.data.TFRecordDataset(f),
cycle_length=4,
num_parallel_calls=tf.data.AUTOTUNE
)
2.4 TextLineDataset
专门用于读取文本文件,每行作为一个元素。常用于 CSV 数据、日志文件等纯文本格式。
# 读取文本文件,每行一个元素
text_dataset = tf.data.TextLineDataset(
["file1.csv", "file2.csv"],
compression_type="GZIP" # 支持压缩文件
)
# 跳过 CSV 表头
text_dataset = text_dataset.skip(1)
# 解析 CSV 行
def parse_csv_line(line):
fields = tf.io.decode_csv(
line,
record_defaults=[tf.constant("", dtype=tf.string)] * 5
)
return fields
parsed_dataset = text_dataset.map(parse_csv_line)
2.5 TFRecordDataset
专为 TFRecord 格式设计的读取器,是 TensorFlow 生态中最推荐的大规模数据存储格式。TFRecord 将数据序列化为二进制记录,读取效率远高于文本文件。
# 读取 TFRecord 文件
filenames = ["/data/train-00000-of-00100", "/data/train-00001-of-00100"]
raw_dataset = tf.data.TFRecordDataset(
filenames,
compression_type="GZIP", # 压缩类型
buffer_size=1024 * 1024 * 16 # 16MB 读取缓冲区
)
# 需要配合解析函数使用(详见第五节 TFRecord)
dataset = raw_dataset.map(parse_tfrecord_fn)
| 创建方法 |
适用场景 |
数据量 |
读取速度 |
| from_tensor_slices |
小数据内存加载、快速原型 |
<10GB |
★★★ |
| from_generator |
动态生成、流式数据、外部数据源 |
无上限 |
★★ |
| list_files + map |
图像文件、自定义二进制文件 |
中等 |
★★★ |
| TextLineDataset |
CSV、日志等纯文本 |
中等 |
★★★ |
| TFRecordDataset |
大规模训练数据(推荐) |
TB级 |
★★★★★ |
三、数据变换操作
Dataset 创建完成后,可以链式调用多种变换方法对其进行处理。这些变换是"惰性"的,只在迭代时才实际执行,从而允许 TensorFlow 对流水线进行整体优化。
3.1 map
对数据集的每个元素应用一个函数。这是最常用的变换,用于数据预处理、解析、增强等。map 支持并行化执行,是性能优化的关键点。
# 图像增强示例
def augment(image, label):
image = tf.image.random_flip_left_right(image)
image = tf.image.random_brightness(image, max_delta=0.1)
image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
return image, label
# 并行 map(关键性能优化)
dataset = dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
3.2 batch
将连续元素组合成批次。最后一个批次可能小于 batch_size,可通过 drop_remainder 控制。
dataset = dataset.batch(
batch_size=32,
drop_remainder=True # 丢弃最后一个不完整批次
)
3.3 shuffle
以大小为 buffer_size 的缓冲区实现数据打乱。buffer_size 越大,随机性越好,但内存消耗也越大。经验法则是 buffer_size 至少为样本总数的 10%。
# shuffle 是"采样式"打乱,不是全局重排
dataset = dataset.shuffle(
buffer_size=10000, # 缓冲区大小(越大越随机)
seed=42,
reshuffle_each_iteration=True # 每轮重新打乱
)
3.4 repeat
重复数据集多次,用于多轮训练。不传参数时无限重复。
# epoch 由 repeat 控制
dataset = dataset.repeat(count=10) # 重复10个epoch
# 或与 batch 联合:无限重复,由训练步数控制
dataset = dataset.repeat().batch(32)
3.5 prefetch
预取数据到内存中,让数据预处理和模型训练重叠。这通常是提升 GPU 利用率最有效的单一操作。
# 预取1个批次
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
# 实现原理:
# CPU 在 GPU 处理当前批次时,已经在准备下一个批次
# 从而消除"GPU 等 CPU"的空闲时间
3.6 cache
将数据集缓存到内存或磁盘,避免重复计算。在第一个 epoch 中执行所有变换并缓存结果,后续 epoch 直接读取缓存。
# 缓存到内存(适合小数据集)
dataset = dataset.cache()
# 缓存到磁盘(适合大数据集,注意磁盘空间)
dataset = dataset.cache(filename="/tmp/cache_dir")
# 典型用法:map/cache/shuffle/batch/prefetch
dataset = dataset.map(preprocess, num_parallel_calls=AUTOTUNE)
dataset = dataset.cache() # 第一次 epoch 后缓存预处理结果
dataset = dataset.shuffle(1024)
dataset = dataset.batch(64)
dataset = dataset.prefetch(AUTOTUNE)
3.7 filter
根据条件函数过滤元素。常用于移除无效样本或筛选特定类别的数据。
# 过滤掉标签为 -1 的无效样本
dataset = dataset.filter(lambda x, y: y >= 0)
# 筛选出特定类别的数据
dataset = dataset.filter(lambda x, y: tf.reduce_any(tf.equal(y, [0, 1, 2])))
3.8 unbatch
将批次数据重新展开为单个元素,是 batch 的逆操作。常用于变长序列数据的处理。
# unbatch 示例
batched = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).batch(2)
unbatched = batched.unbatch()
for item in unbatched:
print(item.numpy()) # 1, 2, 3, 4
3.9 window
将连续元素组合为"窗口",每个窗口本身是一个子数据集。常用于序列模型的时间窗口特征构造。
# 时间序列窗口
time_series = tf.data.Dataset.range(100)
windows = time_series.window(size=5, shift=1, drop_remainder=True)
# 将窗口展平为特征-标签对
def window_to_dataset(window):
window = window.batch(5, drop_remainder=True)
return window.map(lambda w: (w[:-1], w[-1:]))
seq_dataset = windows.flat_map(window_to_dataset)
四、性能优化技巧
tf.data 提供了多种性能优化机制,合理配置可以使数据加载速度提升数倍至数十倍。以下是经过实践验证的核心优化策略。
4.1 使用 prefetch 实现流水线重叠
prefetch 是 tf.data 中最简单但最有效的优化手段。它将数据准备与模型训练在时间上重叠,让 CPU 在 GPU 执行前向/反向传播的同时准备下一批数据。使用 tf.data.AUTOTUNE 让框架自动选择最佳的预取缓冲区大小。
# 始终在流水线末尾添加 prefetch
dataset = dataset.batch(32)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
# 可视化理解:
# 无 prefetch: [train][load][train][load][train][load]
# 有 prefetch: [train]──[train]──[train]──
# [load]──[load]──[load]── (并行执行)
4.2 并行 map(num_parallel_calls)
map 操作默认是串行执行的,通过设置 num_parallel_calls 可以在多个 CPU 线程上并行处理元素。对于计算密集型预处理(如图像解码、数据增强),并行度可以显著降低延迟。
# 串行 map(慢)
dataset = dataset.map(preprocess_fn)
# 并行 map(快)
dataset = dataset.map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)
# 并行 map 的注意事项:
# 1. preprocess_fn 必须是纯函数(无副作用)
# 2. 不是 num_parallel_calls 越大越好,需考虑 CPU 核心数
# 3. AUTOTUNE 让 tf.data 自动选择最优并行度
4.3 使用 AUTOTUNE 自动调优
tf.data.AUTOTUNE 是一个特殊常量,告知 tf.data 运行时根据系统负载动态调整并行度、预取大小和缓冲区大小等参数。这是从手动调优迈向自动调优的关键一步。
# 完全自动调优的流水线
dataset = (
tf.data.Dataset.list_files(pattern)
.shuffle(buffer_size=10000)
.interleave(
lambda f: tf.data.TFRecordDataset(f),
num_parallel_calls=tf.data.AUTOTUNE
)
.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)
.batch(64, drop_remainder=True)
.prefetch(tf.data.AUTOTUNE)
)
4.4 缓存到内存
对于预处理计算密集型但数据集较小(能放入内存)的场景,在 map 之后添加 cache() 可以极大加速后续 epoch。第一个 epoch 执行所有 map 变换并缓存结果,后续 epoch 直接复用。
# 策略1: 预处理后立即缓存(适合 < 10GB)
dataset = dataset.map(heavy_preprocessing, num_parallel_calls=AUTOTUNE)
dataset = dataset.cache()
dataset = dataset.shuffle(1024).batch(64).prefetch(AUTOTUNE)
# 策略2: 分阶段缓存
dataset = dataset.map(light_preprocess).cache()
dataset = dataset.map(random_augment, num_parallel_calls=AUTOTUNE)
dataset = dataset.batch(64).prefetch(AUTOTUNE)
# 确定性预处理被缓存,随机增强每次重新计算
4.5 并行化文件读取(interleave)
interleave 允许同时从多个文件中读取数据,避免因单个大文件的 I/O 等待而阻塞流水线。cycle_length 控制并行读取的文件数,block_length 控制每个文件连续读取的记录数。
# interleave 与 flat_map 对比
files = tf.data.Dataset.list_files("/data/*.tfrecord")
# flat_map:串行读取(慢)
dataset = files.flat_map(lambda f: tf.data.TFRecordDataset(f))
# interleave:并行读取(快)
dataset = files.interleave(
lambda f: tf.data.TFRecordDataset(f).map(parse_fn),
cycle_length=4, # 同时读4个文件
block_length=16, # 每个文件连续取16条
num_parallel_calls=tf.data.AUTOTUNE
)
# 注意:当文件数量远大于 cycle_length 时,
# interleave 会动态切换文件,保持并行度
4.6 文件分片(shard)
在分布式训练中,每个 worker 只处理总文件的一个子集,避免数据重复。shard 根据 worker 数量和索引对数据集进行均匀划分。
# 分布式训练中的文件分片
num_workers = 8
worker_index = 0 # 当前 worker 编号
dataset = tf.data.Dataset.list_files("/data/*.tfrecord")
dataset = dataset.shard(num_workers, worker_index)
# 分片后每个 worker 只处理 1/8 的文件
# 配合 interleave 进一步提升读取效率
# 最佳实践:先分片,后 shuffle
dataset = dataset.shard(num_workers, worker_index)
dataset = dataset.shuffle(buffer_size=10000)
4.7 性能优化最佳实践总结
推荐流水线结构(按顺序):
1. list_files(获取文件列表)→ 2. shard(分布式分片)→ 3. shuffle(文件级打乱)→ 4. interleave(并行读取)→ 5. map(预处理,并行)→ 6. cache(缓存预处理结果)→ 7. map(随机增强,并行)→ 8. shuffle(样本级打乱)→ 9. batch(组批)→ 10. prefetch(预取)
注意:cache 之前的 shuffle/random 操作会被固定,因此随机增强应在 cache 之后进行。
五、TFRecord 格式详解
TFRecord 是 TensorFlow 原生推荐的二进制数据存储格式。它将数据序列化为一条条记录(Record),每条记录由字节串构成,读取时无需解析文件格式,I/O 效率极高。在大规模分布式训练场景中,TFRecord 几乎是标准配置。
5.1 Example 与 Feature 协议
TFRecord 的核心数据结构是 tf.train.Example,它本质上是一个字符串到 tf.train.Feature 的映射。Feature 支持三种类型:BytesList(字符串/字节)、FloatList(浮点数)、Int64List(整数)。
# 构建一个 Example 消息
def create_example(image_bytes, label, height, width):
feature = {
"image/encoded": tf.train.Feature(
bytes_list=tf.train.BytesList(value=[image_bytes])),
"image/height": tf.train.Feature(
int64_list=tf.train.Int64List(value=[height])),
"image/width": tf.train.Feature(
int64_list=tf.train.Int64List(value=[width])),
"image/label": tf.train.Feature(
int64_list=tf.train.Int64List(value=[label])),
"image/format": tf.train.Feature(
bytes_list=tf.train.BytesList(value=[b"jpeg"])),
}
return tf.train.Example(features=tf.train.Features(feature=feature))
5.2 序列化与写入
创建 Example 后,需要通过 SerializeToString 序列化为字节串,再写入 TFRecord 文件。写入时可以使用 tf.io.TFRecordWriter。
# 写入 TFRecord 文件
output_file = "/data/train.tfrecord"
with tf.io.TFRecordWriter(output_file) as writer:
for i in range(1000):
image_bytes = read_image_as_bytes(i) # 假设的函数
example = create_example(
image_bytes, label=i % 10,
height=224, width=224
)
writer.write(example.SerializeToString())
# 写入多个文件(推荐用于大规模数据)
num_shards = 100
for shard_id in range(num_shards):
filename = f"/data/train-{shard_id:05d}-of-{num_shards:05d}"
with tf.io.TFRecordWriter(filename) as writer:
for i in range(1000):
writer.write(create_example_for_index(i).SerializeToString())
5.3 解析 TFRecord
读取时需要编写解析函数,使用 tf.io.parse_single_example 将序列化的字节串还原为特征张量。解析函数的特征描述必须与写入时完全一致。
# 定义特征描述
feature_description = {
"image/encoded": tf.io.FixedLenFeature([], tf.string),
"image/height": tf.io.FixedLenFeature([], tf.int64),
"image/width": tf.io.FixedLenFeature([], tf.int64),
"image/label": tf.io.FixedLenFeature([], tf.int64),
"image/format": tf.io.FixedLenFeature([], tf.string),
}
def parse_tfrecord_fn(example_proto):
# 解析 Example
parsed = tf.io.parse_single_example(example_proto, feature_description)
# 解码图像
image = tf.image.decode_jpeg(parsed["image/encoded"], channels=3)
image = tf.image.resize(image, [224, 224])
image = tf.cast(image, tf.float32) / 255.0
# 返回 (特征, 标签) 元组
return image, parsed["image/label"]
# 组装流水线
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(parse_tfrecord_fn, num_parallel_calls=AUTOTUNE)
dataset = dataset.batch(64).prefetch(AUTOTUNE)
5.4 VarLenFeature 与序列特征
当特征的长度不固定时(如变长序列、多值标签),使用 VarLenFeature 代替 FixedLenFeature,它会返回 SparseTensor。
# 示例:多标签分类
feature_desc = {
"image/encoded": tf.io.FixedLenFeature([], tf.string),
"image/labels": tf.io.VarLenFeature(tf.int64), # 变长标签列表
"text/tokens": tf.io.VarLenFeature(tf.string), # 变长文本 token
}
def parse_fn(proto):
parsed = tf.io.parse_single_example(proto, feature_desc)
image = tf.image.decode_jpeg(parsed["image/encoded"])
labels = tf.sparse.to_dense(parsed["image/labels"]) # 转为稠密
return image, labels
TFRecord 优势总结:
1. 二进制格式,读取效率远超文本文件
2. 天然支持压缩(GZIP/ZLIB),减少存储和 I/O
3. 与 tf.data 深度集成,支持分片、并行读取
4. 良好的跨平台兼容性(Python、Java、Go 等均支持)
5. 适合分布式训练中的数据分发
6. ProtoBuf 序列化,schema 定义清晰
六、特征列 FeatureColumns
FeatureColumns 是 TensorFlow Estimator API(和 Keras)中用于描述特征结构的组件。它将各种原始数据类型(数值、类别、文本、嵌入等)统一转换为模型可以处理的数值张量。虽然在纯 Keras 中可以直接预处理,但 FeatureColumns 提供了一种声明式的、可序列化的特征工程方案,特别适合结构化数据的特征处理流水线。
6.1 数值列(numeric_column)
最基本的特征列,将数值特征直接传递给模型。可以指定归一化函数。
# 基础数值列
age = tf.feature_column.numeric_column("age")
price = tf.feature_column.numeric_column("price", dtype=tf.float32)
# 带归一化的数值列
def normalize_income(x):
return (x - 50000.0) / 30000.0
income = tf.feature_column.numeric_column(
"income", normalizer_fn=normalize_income
)
6.2 类别列(categorical_column)
将离散值(字符串、整数)映射为类别 ID。支持词汇表映射、哈希分桶等方式。
# 词汇表映射
color = tf.feature_column.categorical_column_with_vocabulary_list(
"color",
vocabulary_list=["red", "green", "blue", "yellow"]
)
# 哈希分桶(适合类别数不确定的场景)
city = tf.feature_column.categorical_column_with_hash_bucket(
"city", hash_bucket_size=1000 # 哈希到 1000 个桶
)
# 整数标识
category_id = tf.feature_column.categorical_column_with_identity(
"category_id", num_buckets=50
)
# 分桶连续特征
age_buckets = tf.feature_column.bucketized_column(
source_column=age,
boundaries=[18, 25, 35, 45, 55, 65]
)
6.3 嵌入列(embedding_column)
将高维稀疏类别特征映射到低维稠密向量。是处理大规模类别特征(如用户 ID、商品 ID)的标准方法。
# 嵌入列
city_embedding = tf.feature_column.embedding_column(
categorical_column=city,
dimension=16 # 嵌入维度
)
# 嵌入维度经验公式
# dimension = min(50, (num_categories + 1) // 2)
# 或 dimension = int(num_categories ** 0.25) * 4
# 使用嵌入列处理 ID 特征
user_id = tf.feature_column.categorical_column_with_hash_bucket(
"user_id", hash_bucket_size=100000
)
user_embedding = tf.feature_column.embedding_column(
user_id, dimension=32
)
6.4 交叉列(crossed_column)
对两个或多个类别特征进行笛卡尔积交叉,产生新的组合特征。自动进行哈希分桶,避免组合爆炸。
# 特征交叉:年龄段 x 城市
age_x_city = tf.feature_column.crossed_column(
[age_buckets, city],
hash_bucket_size=5000 # 交叉后哈希到 5000 个桶
)
# 交叉列通常需要配合嵌入列使用
age_x_city_embedding = tf.feature_column.embedding_column(
age_x_city, dimension=8
)
# 更高阶交叉
user_x_item = tf.feature_column.crossed_column(
[user_id, category_id],
hash_bucket_size=50000
)
6.5 指示列(indicator_column)
将类别特征转换为 one-hot 编码的多-hot 指示向量。适用于类别数较少(通常 < 100)的场景。
# One-hot 编码
color_indicator = tf.feature_column.indicator_column(color)
# 嵌入列 vs 指示列的选择
# 类别数少(< 50):使用 indicator_column(one-hot)
# 类别数多(>= 50):使用 embedding_column
# 数学上,embedding_column 可视为 indicator_column 的低秩近似
6.6 在 Keras 中使用 FeatureColumns
# 组合所有特征列
feature_columns = [
age, # 数值列
income, # 数值列(带归一化)
age_buckets, # 分桶列
color_indicator, # 指示列(one-hot)
city_embedding, # 嵌入列
age_x_city_embedding, # 交叉列 + 嵌入
]
# 在 Keras 中使用
feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
model = tf.keras.Sequential([
feature_layer,
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(64, activation="relu"),
tf.keras.layers.Dense(1, activation="sigmoid")
])
model.compile(optimizer="adam", loss="binary_crossentropy")
# 训练数据用字典格式输入
model.fit(dataset_dict, epochs=10)
七、Keras 与 tf.data 集成
Keras 的 model.fit()、model.evaluate() 和 model.predict() 均支持直接接收 tf.data.Dataset 作为输入。这是 tf.data 最强大的应用场景之一,将数据流水线与训练循环无缝衔接。
7.1 基础集成
# 准备 Dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(60000).batch(32).prefetch(AUTOTUNE)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(32).prefetch(AUTOTUNE)
# 直接传入 Dataset
model.fit(train_dataset, validation_data=test_dataset, epochs=10)
# evaluate 和 predict 同理
loss, acc = model.evaluate(test_dataset)
predictions = model.predict(test_dataset)
7.2 steps_per_epoch 控制
当 Dataset 无限重复(repeat() 无参数)时,必须指定 steps_per_epoch 来决定每个 epoch 的步数。
# 无限重复的 Dataset
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.repeat() # 无限重复
train_dataset = train_dataset.shuffle(10000)
train_dataset = train_dataset.batch(32).prefetch(AUTOTUNE)
model.fit(
train_dataset,
epochs=10,
steps_per_epoch=1875, # 60000 / 32 ≈ 1875
validation_data=test_dataset
)
7.3 多输入模型
当模型有多个输入时,Dataset 的每个元素应为 (inputs_dict, label) 或 (inputs_tuple, label) 格式。
# 多输入 Dataset
def gen_multi_input():
for i in range(1000):
yield (
{"image_input": np.random.randn(224, 224, 3),
"text_input": np.random.randint(0, 1000, size=(50,))},
np.random.randint(0, 10)
)
dataset = tf.data.Dataset.from_generator(
gen_multi_input,
output_signature=(
{"image_input": tf.TensorSpec(shape=(224, 224, 3), dtype=tf.float32),
"text_input": tf.TensorSpec(shape=(50,), dtype=tf.int32)},
tf.TensorSpec(shape=(), dtype=tf.int32)
)
)
dataset = dataset.batch(32)
# 多输入 Keras 模型
image_input = tf.keras.Input(shape=(224, 224, 3), name="image_input")
text_input = tf.keras.Input(shape=(50,), name="text_input")
# ... 模型定义 ...
model.fit(dataset, epochs=10)
7.4 自定义训练循环
在需要更精细控制时,可以直接在 tf.GradientTape 中迭代 Dataset。
# 自定义训练循环
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
for epoch in range(10):
for step, (x_batch, y_batch) in enumerate(train_dataset):
with tf.GradientTape() as tape:
logits = model(x_batch, training=True)
loss = loss_fn(y_batch, logits)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
if step % 100 == 0:
print(f"Epoch {epoch}, Step {step}, Loss: {loss.numpy():.4f}")
八、完整实战案例
以下是一个端到端的实战案例,完整演示了从 TFRecord 创建到 Keras 模型训练的完整数据流水线。
# ========= 步骤1: 创建 TFRecord 文件 =========
import tensorflow as tf
import numpy as np
AUTOTUNE = tf.data.AUTOTUNE
def serialize_example(feature0, feature1, feature2, label):
feature = {
"feature0": tf.train.Feature(
float_list=tf.train.FloatList(value=[feature0])),
"feature1": tf.train.Feature(
float_list=tf.train.FloatList(value=feature1.flatten())),
"feature2": tf.train.Feature(
int64_list=tf.train.Int64List(value=[feature2])),
"label": tf.train.Feature(
int64_list=tf.train.Int64List(value=[label])),
}
return tf.train.Example(
features=tf.train.Features(feature=feature)
).SerializeToString()
# 生成模拟数据并写入 TFRecord
with tf.io.TFRecordWriter("/tmp/train.tfrecord") as writer:
for i in range(10000):
f0 = np.random.randn()
f1 = np.random.randn(10)
f2 = np.random.randint(0, 100)
label = np.random.randint(0, 10)
writer.write(serialize_example(f0, f1, f2, label))
# ========= 步骤2: 定义解析函数 =========
feature_description = {
"feature0": tf.io.FixedLenFeature([], tf.float32),
"feature1": tf.io.FixedLenFeature([10], tf.float32),
"feature2": tf.io.FixedLenFeature([], tf.int64),
"label": tf.io.FixedLenFeature([], tf.int64),
}
def parse_fn(proto):
parsed = tf.io.parse_single_example(proto, feature_description)
features = {
"f0": parsed["feature0"],
"f1": parsed["feature1"],
"f2": tf.cast(parsed["feature2"], tf.float32),
}
label = parsed["label"]
return features, label
# ========= 步骤3: 构建高性能流水线 =========
dataset = tf.data.TFRecordDataset(["/tmp/train.tfrecord"])
dataset = dataset.map(parse_fn, num_parallel_calls=AUTOTUNE)
dataset = dataset.cache() # 缓存解析结果
dataset = dataset.shuffle(buffer_size=5000)
dataset = dataset.batch(64, drop_remainder=True)
dataset = dataset.prefetch(AUTOTUNE)
# ========= 步骤4: 构建并训练模型 =========
feature_columns = [
tf.feature_column.numeric_column("f0"),
tf.feature_column.numeric_column("f1", shape=(10,)),
tf.feature_column.numeric_column("f2"),
]
model = tf.keras.Sequential([
tf.keras.layers.DenseFeatures(feature_columns),
tf.keras.layers.Dense(64, activation="relu"),
tf.keras.layers.Dense(32, activation="relu"),
tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(
optimizer="adam",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"]
)
model.fit(dataset, epochs=10, verbose=2)
# ========= 步骤5: 保存模型 =========
model.save("/tmp/tfdata_demo_model")
九、总结
tf.data 是 TensorFlow 生态中不可或缺的数据处理组件,它通过声明式的 API 将数据加载、预处理、增强、组批和预取等环节无缝衔接。掌握 tf.data 的核心概念和优化技巧,对于提升深度学习训练效率和 GPU 利用率至关重要。
核心要点回顾:
1. Dataset 创建:根据数据规模和格式选择合适的创建方法(from_tensor_slices、from_generator、list_files、TextLineDataset、TFRecordDataset)
2. 数据变换:灵活组合 map、batch、shuffle、repeat、cache、prefetch、filter、unbatch、window 等操作
3. 性能优化:始终使用 prefetch(AUTOTUNE);对 map 设置 num_parallel_calls=AUTOTUNE;使用 interleave 并行读取文件;在 map 后使用 cache 避免重复计算
4. TFRecord:大规模数据推荐使用 TFRecord 格式,通过 Example/Feature 协议序列化,利用分片实现分布式读取
5. FeatureColumns:声明式特征工程,覆盖数值列、类别列、嵌入列、交叉列和指示列
6. Keras 集成:model.fit() 直接接收 Dataset,支持多输入、多输出和自定义训练循环
最佳实践口诀:
数据源选 TFRecord,解析 map 加并行;
shuffle 在前 batch 后,预取 cache 不能漏;
AUTOTUNE 来调优,GPU 满载不空等;
特征列声明式写,Keras 集成直接训。