模型保存与加载部署

从训练到生产的完整流程 —— Keras / PyTorch / ONNX / TensorRT 全链路详解

核心主题: 深度学习模型从训练到部署的全流程技术栈详解

涉及框架: TensorFlow / Keras / PyTorch / ONNX / TensorRT / TensorFlow Serving / Triton / TFLite

主要内容: Keras模型保存与加载、PyTorch模型序列化、ONNX导出与优化、TensorRT高性能推理、模型服务部署、移动端与边缘端部署

关键词: 模型部署, SavedModel, ONNX, TensorRT, TFLite, TensorFlow Serving, 模型量化, 模型序列化

一、概述

模型保存与加载部署是深度学习工程化中最关键的环节之一。一个模型在训练完成后,如果不能被有效地保存、加载并以高效的方式在生产环境中提供服务,其学术价值就无法转化为实际生产力。本文将系统梳理从模型保存到生产部署的完整技术链路,涵盖当前业界主流的技术方案和最佳实践。

整个模型部署生命周期可以分为三个核心阶段:模型序列化(将训练好的模型权重和结构持久化到磁盘)、模型转换与优化(将模型转换为目标部署格式并进行推理优化)、模型服务(将优化后的模型部署到生产环境提供在线推理服务)。每个阶段都有多种技术方案可选,具体选择取决于模型框架、部署平台和性能要求。

核心挑战

  • 框架兼容性: 训练框架(PyTorch/TF)与推理框架(ONNX Runtime/TensorRT)之间的格式转换
  • 性能优化: 模型量化、算子融合、内存优化等推理加速技术
  • 部署环境: 服务器端(GPU/CPU)、移动端(Android/iOS)、边缘设备(IoT/嵌入式)
  • 服务化: 高并发、低延迟、弹性伸缩的生产级推理服务架构
阶段 技术方案 适用场景 关键考量
模型保存 SavedModel / HDF5 / state_dict / TorchScript 训练后存档、迁移学习、模型分发 完整性、可移植性、框架依赖
模型转换 ONNX / TensorRT / TFLite / CoreML 跨平台部署、推理加速 算子支持、精度损失、优化空间
模型服务 TF Serving / Triton / TorchServe / BentoML 生产环境在线推理 吞吐量、延迟、可扩展性

二、Keras / TensorFlow模型保存与加载

Keras作为TensorFlow的高级API,提供了多种模型保存与加载的方式。选择合适的保存格式和使用方式对于模型的生产部署至关重要。

2.1 SavedModel格式(推荐)

SavedModel是TensorFlow推荐的默认保存格式,它同时保存模型的网络结构、权重参数以及计算图,具备良好的跨版本兼容性和部署友好性。SavedModel目录包含 saved_model.pb(序列化的计算图)、variables/(权重文件)和 assets/(辅助资源)三个部分。

import tensorflow as tf # 构建并训练模型 model = tf.keras.Sequential([ tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(64, activation='relu'), tf.keras.layers.Dense(10, activation='softmax') ]) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # 训练完成后保存为SavedModel格式 model.save('my_model_savedmodel/', save_format='tf') # 加载SavedModel loaded_model = tf.keras.models.load_model('my_model_savedmodel/') # 验证加载后的模型预测结果一致 import numpy as np test_input = np.random.randn(1, 784) assert np.allclose(model(test_input), loaded_model(test_input))

2.2 HDF5格式(.h5)

HDF5格式是Keras早期版本的默认保存格式,将模型结构、权重和训练配置打包到一个单一文件中。虽然HDF5格式便于分发(只有一个文件),但在跨版本兼容性和自定义层支持方面不如SavedModel灵活。

# 保存为HDF5格式 model.save('my_model.h5') # 从HDF5文件加载模型 loaded_model = tf.keras.models.load_model('my_model.h5') # HDF5格式适合做模型版本归档 # 注意:包含自定义层时需要提供custom_objects参数 class MyLayer(tf.keras.layers.Layer): def call(self, inputs): return tf.nn.swish(inputs) model_with_custom = tf.keras.Sequential([ tf.keras.layers.Dense(64), MyLayer() ]) model_with_custom.save('custom_model.h5') # 加载时需要注册自定义层 loaded_custom = tf.keras.models.load_model( 'custom_model.h5', custom_objects={'MyLayer': MyLayer} )

2.3 仅保存权重(save_weights / load_weights)

当只需要保存模型参数而不需要完整网络结构时(如迁移学习场景),可以使用 save_weights 方法。这种方式保存的文件体积最小,但加载时必须先构建完全相同的模型结构。

# 仅保存权重 model.save_weights('model_weights.h5') # 加载权重前必须先重建模型结构 new_model = tf.keras.Sequential([ tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(64, activation='relu'), tf.keras.layers.Dense(10, activation='softmax') ]) # 先编译(对函数式模型和子类化模型是必需的) new_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') # 然后加载权重 new_model.load_weights('model_weights.h5') # 权重格式也可以保存为checkpoint格式,支持断点续训 model.save_weights('checkpoints/model_ckpt') model.load_weights('checkpoints/model_ckpt')

2.4 保存仅网络结构

在某些场景下,我们可能需要单独保存网络结构(JSON/YAML格式)或仅保存网络配置,实现结构与权重的分离管理。

# 仅保存网络结构为JSON json_config = model.to_json() with open('model_architecture.json', 'w') as f: f.write(json_config) # 从JSON重建模型结构 from keras.models import model_from_json with open('model_architecture.json', 'r') as f: loaded_json = f.read() reconstructed_model = tf.keras.models.model_from_json(loaded_json) # 然后再加载权重 reconstructed_model.load_weights('model_weights.h5') # 保存为YAML格式(需要安装pyyaml) yaml_config = model.to_yaml() with open('model_architecture.yaml', 'w') as f: f.write(yaml_config)

最佳实践建议

  • 新项目默认使用SavedModel格式,它是最标准和最具备前瞻性的选择
  • 迁移学习使用保存/加载权重(save_weights/load_weights),减少磁盘占用和加载时间
  • 模型存档推荐HDF5,单文件便于管理和备份
  • 生产部署必须使用SavedModel,因为TensorFlow Serving原生支持该格式

三、PyTorch模型保存与加载

PyTorch提供了灵活多样的模型序列化方案,从最简单的参数保存到完整的TorchScript编译,覆盖了从研究到生产的不同需求层次。

3.1 state_dict:推荐的方式

state_dict 是一个Python字典对象,将每一层可学习参数(权重和偏置)映射到其对应的张量。这是PyTorch官方推荐的模型保存方式,具有最小的文件体积和最佳的安全性。

import torch import torch.nn as nn class SimpleModel(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(784, 128) self.relu = nn.ReLU() self.dropout = nn.Dropout(0.2) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 10) def forward(self, x): x = self.relu(self.fc1(x)) x = self.dropout(x) x = self.relu(self.fc2(x)) return self.fc3(x) model = SimpleModel() # 保存state_dict(推荐) torch.save(model.state_dict(), 'model_state_dict.pth') # 加载state_dict前必须创建模型实例 loaded_model = SimpleModel() loaded_model.load_state_dict(torch.load('model_state_dict.pth')) loaded_model.eval() # 切换到评估模式

3.2 保存完整模型

PyTorch也支持将完整模型(结构+权重)序列化到一个文件中。这种方式虽然使用方便,但安全性较低(因为序列化的是Python对象),且在不同版本的PyTorch之间可能存在兼容性问题。

# 保存完整模型(包含结构和权重) # 注意:这种方式通过pickle序列化整个模型对象 torch.save(model, 'complete_model.pth') # 加载完整模型 # 警告:存在安全风险,不要加载不可信的模型文件 loaded_model = torch.load('complete_model.pth') loaded_model.eval() # 从文件后缀来看,.pth, .pt, .pkl 都是常见的PyTorch模型文件后缀 # 推荐使用 .pth 作为统一后缀

3.3 完整模型 vs 参数保存对比

对比维度 state_dict(推荐) 完整模型
文件体积 小(仅参数) 较大(包含模型定义)
安全性 高(纯张量数据) 低(pickle反序列化有代码执行风险)
跨版本兼容性 差(需相同环境)
使用便捷性 需要额外创建模型结构 直接加载即可使用
推荐场景 生产环境、模型分发 快速原型、个人实验

3.4 检查点(Checkpoint)保存与断点续训

在训练大型模型时,定期保存检查点是一个良好的实践。一个好的检查点应该包含模型参数、优化器状态、当前epoch和最佳性能指标,以便从中断处恢复训练。

# 保存包含优化器状态和训练进度的检查点 checkpoint = { 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': best_loss, 'accuracy': best_acc } torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth') # 从中断处恢复训练 checkpoint = torch.load('checkpoint_epoch_10.pth') model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) start_epoch = checkpoint['epoch'] + 1 best_loss = checkpoint['loss'] # 继续训练循环 for epoch in range(start_epoch, num_epochs): # 训练代码... pass

3.5 TorchScript:从研究到生产的桥梁

TorchScript是PyTorch模型从研究到生产部署的桥樑技术。通过 torch.jit.scripttorch.jit.trace 两种方式,将动态的Python模型编译为静态的计算图,使其可以脱离Python运行时在C++环境中独立执行。

import torch import torchvision.models as models # ===== torch.jit.trace(跟踪)===== # 适用于不含控制流的模型 resnet = models.resnet18(pretrained=True) resnet.eval() # 提供示例输入进行跟踪 example_input = torch.randn(1, 3, 224, 224) traced_model = torch.jit.trace(resnet, example_input) traced_model.save('resnet18_traced.pt') # ===== torch.jit.script(脚本)===== # 适用于包含条件/循环等控制流的模型 class AdaptiveModel(nn.Module): def __init__(self): super().__init__() self.fc = nn.Linear(10, 5) def forward(self, x, use_dropout: bool = False): x = self.fc(x) if use_dropout: # 控制流 - 必须使用script x = torch.dropout(x, 0.5, train=self.training) return x adaptive_model = AdaptiveModel() scripted_model = torch.jit.script(adaptive_model) scripted_model.save('adaptive_model_scripted.pt') # 在C++环境中加载TorchScript模型 # torch::jit::script::Module module = torch::jit::load("resnet18_traced.pt"); # std::vector inputs; # inputs.push_back(torch::randn({1, 3, 224, 224})); # at::Tensor output = module.forward(inputs).toTensor();

script vs trace 选择指南

  • torch.jit.trace: 适用于无数据依赖控制流的模型(如CNN、ResNet、EfficientNet)。简单可靠,推荐优先使用。
  • torch.jit.script: 适用于包含条件分支或循环的模型(如Transformer解码器、RNN with attention)。功能更强大,但需要模型代码兼容TorchScript语法。
  • torch.onnx.export: 内部也使用trace机制,因此同样要求模型不含动态控制流。

3.6 多GPU模型保存的特殊处理

当使用 DataParallelDistributedDataParallel 时,模型参数名称会带有 module. 前缀,需要在保存或加载时特殊处理。

# 多GPU训练时的模型保存 if isinstance(model, nn.DataParallel): # 去掉module.前缀,保存原始模型参数 torch.save(model.module.state_dict(), 'model_ddp.pth') else: torch.save(model.state_dict(), 'model_ddp.pth') # 加载时的处理 raw_state_dict = torch.load('model_ddp.pth') # 如果state_dict的key带有module.前缀,创建新的不包含该前缀的字典 from collections import OrderedDict new_state_dict = OrderedDict() for k, v in raw_state_dict.items(): name = k[7:] if k.startswith('module.') else k # 移除module.前缀 new_state_dict[name] = v model.load_state_dict(new_state_dict)

四、ONNX导出与优化

ONNX(Open Neural Network Exchange)是一种开放的模型交换格式,旨在实现不同深度学习框架之间的模型互操作性。它充当了训练框架和推理引擎之间的中间表示层,使得我们可以用PyTorch训练模型,然后导出为ONNX格式,最后在各种推理引擎(ONNX Runtime、TensorRT、OpenVINO)上运行。

4.1 torch.onnx.export 基本用法

import torch import torchvision.models as models # 加载预训练模型 model = models.resnet18(pretrained=True) model.eval() # 创建示例输入 dummy_input = torch.randn(1, 3, 224, 224) # 导出为ONNX格式 torch.onnx.export( model, # 模型 dummy_input, # 示例输入 'resnet18.onnx', # 输出文件名 export_params=True, # 导出模型参数 opset_version=17, # ONNX算子集版本 do_constant_folding=True, # 常量折叠优化 input_names=['input'], # 输入张量名称 output_names=['output'], # 输出张量名称 dynamic_axes={ # 动态轴定义 'input': {0: 'batch_size'}, 'output': {0: 'batch_size'} } )

4.2 opset版本选择

ONNX算子集(opset)定义了支持的算子集合和版本。选择opset版本时需要权衡:新版本支持更多算子但兼容性可能受限,旧版本兼容性好但可能缺少某些算子支持。

# 查看ONNX版本和默认opset import onnx print(f"ONNX version: {onnx.__version__}") print(f"Default opset: {onnx.defs.onnx_opset_version()}") # 不同opset版本的显式指定 # opset 11: 支持Split、Pad、Slice等算子的改进版本 # opset 13: 支持Softmax的axis参数、ReduceOps的noop_with_empty_axes # opset 15: 支持BatchNormalization、Trainable # opset 17: 支持DFT、STFT等信号处理算子 # opset 18: 支持BitwiseAnd等位运算算子 torch.onnx.export( model, dummy_input, 'model.onnx', opset_version=17, # 推荐使用17及以上版本 operator_export_type=torch.onnx.OperatorExportTypes.ONNX )

4.3 动态轴配置

在生产部署中,推理请求的batch size通常是可变的。通过配置动态轴,可以让ONNX模型接受不同大小的输入,而不需要为每个batch size重新导出模型。

# 配置多个动态轴 dynamic_axes = { 'input': { 0: 'batch_size', # batch维度可变 2: 'height', # 高度可变(适用于可变尺寸图像) 3: 'width' # 宽度可变 }, 'output': { 0: 'batch_size' } } torch.onnx.export( model, dummy_input, 'dynamic_model.onnx', dynamic_axes=dynamic_axes, input_names=['input'], output_names=['output'] ) # 使用动态轴时,ONNX Runtime中需要动态指定输入尺寸 import onnxruntime as ort session = ort.InferenceSession('dynamic_model.onnx') # 可以用不同的batch size进行推理 for batch_size in [1, 4, 8, 16]: dynamic_input = np.random.randn(batch_size, 3, 224, 224).astype(np.float32) outputs = session.run(None, {'input': dynamic_input})

4.4 ONNX模型验证

导出ONNX模型后,必须进行完整验证以确保导出正确性和数值一致性。ONNX提供了标准的验证工具。

import onnx import onnxruntime as ort import numpy as np # 步骤1:检查模型结构完整性 onnx_model = onnx.load('resnet18.onnx') onnx.checker.check_model(onnx_model) print("ONNX模型结构检查通过") # 步骤2:打印模型信息 print(f"Opset版本: {onnx_model.opset_import[0].version}") print(f"输入节点数: {len(onnx_model.graph.input)}") print(f"输出节点数: {len(onnx_model.graph.output)}") # 步骤3:使用ONNX Runtime验证数值一致性 ort_session = ort.InferenceSession('resnet18.onnx') # 准备测试数据 test_input = np.random.randn(1, 3, 224, 224).astype(np.float32) # PyTorch原始模型推理 with torch.no_grad(): torch_output = model(torch.from_numpy(test_input)).numpy() # ONNX Runtime推理 ort_inputs = {ort_session.get_inputs()[0].name: test_input} ort_output = ort_session.run(None, ort_inputs)[0] # 数值一致性检查 np.testing.assert_allclose(torch_output, ort_output, rtol=1e-03, atol=1e-05) print(f"数值一致性验证通过,最大绝对误差: {np.max(np.abs(torch_output - ort_output))}")

4.5 ONNX模型优化

ONNX Runtime提供了图优化和级别设置。通过 onnxoptimizeronnxruntime.GraphOptimizationLevel,可以进一步提升推理性能。

# 使用ONNX Runtime的图优化 import onnxruntime as ort # 创建推理会话时配置优化级别 options = ort.SessionOptions() options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL options.optimized_model_filepath = 'resnet18_optimized.onnx' options.intra_op_num_threads = 4 # intra-op并行线程数 options.inter_op_num_threads = 2 # inter-op并行线程数 session = ort.InferenceSession('resnet18.onnx', sess_options=options) # 使用onnx-simplifier简化模型 # pip install onnx-simplifier import onnxsim import onnx # 简化模型:合并常量、消除冗余算子 model_simp, check = onnxsim.simplify('resnet18.onnx') assert check, "Simplification failed" onnx.save(model_simp, 'resnet18_simplified.onnx') print(f"简化后模型大小减小: " f"{os.path.getsize('resnet18.onnx')/1024:.1f}KB -> " f"{os.path.getsize('resnet18_simplified.onnx')/1024:.1f}KB")

4.6 兼容性与常见问题排查

在ONNX导出过程中,最常见的两类问题是算子不支持和动态控制流导致的导出失败。针对这些问题,需要采用针对性方案解决。

# 问题1:算子不支持的替代方案 # 使用torch.onnx.symbolic自定义算子映射 from torch.onnx import register_custom_op_symbolic def my_custom_op_symbolic(g, input, other): """将PyTorch自定义操作映射到ONNX支持的算子组合""" return g.op('Add', input, other) register_custom_op_symbolic('custom_namespace::my_op', my_custom_op_symbolic, 1) # 问题2:控制流导出 # 使用torch.jit.script预处理再导出 class ConditionalModel(torch.nn.Module): def forward(self, x, mask): # 使用torch.where替代if-else控制流 return torch.where(mask > 0.5, x, -x) # 先script再export scripted_model = torch.jit.script(ConditionalModel()) torch.onnx.export( scripted_model, (torch.randn(3), torch.randn(3)), 'conditional_model.onnx' ) # 问题3:动态shape导致导出的模型包含不必要的reshape # 使用torch._assert和torch._shape_as_tensor def safe_reshape(x, target_shape): """安全的动态reshape操作""" if isinstance(target_shape, torch.Tensor): return x.reshape(target_shape.tolist()) return x.reshape(target_shape)

ONNX导出最佳实践总结

  1. 固定输入尺寸: 如果业务场景允许,优先使用固定尺寸,避免动态轴带来的兼容性问题
  2. 选择合适的opset: 推荐opset版本17,在算子丰富度和兼容性之间取得平衡
  3. 逐层验证: 导出后先验证结构完整性,再验证数值一致性,最后测试推理性能
  4. 模型简化: 使用onnx-simplifier消除冗余算子,减小模型文件大小
  5. 兼容性问题: 遇到不支持的算子时,考虑用等效算子组合替代,或修改模型结构绕开

五、TensorRT优化部署

NVIDIA TensorRT是一个高性能深度学习推理优化器,通过对模型进行层融合、精度校准、内存优化等操作,可以显著提升在NVIDIA GPU上的推理速度。TensorRT的优化流程通常包括:模型解析、图优化、精度校准和Engine生成四个阶段。

5.1 从ONNX到TensorRT Engine

最推荐的TensorRT工作流是先导出ONNX模型,再通过TensorRT的ONNX解析器转换为TensorRT Engine。这种方式将格式转换与优化分离,便于调试和迭代。

import tensorrt as trt import numpy as np # TensorRT日志记录器 TRT_LOGGER = trt.Logger(trt.Logger.WARNING) def build_engine(onnx_file_path, engine_file_path, precision='fp32'): """从ONNX文件构建TensorRT engine""" with trt.Builder(TRT_LOGGER) as builder, \ builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, \ trt.OnnxParser(network, TRT_LOGGER) as parser: # 读取ONNX模型 with open(onnx_file_path, 'rb') as f: if not parser.parse(f.read()): for error in range(parser.num_errors): print(parser.get_error(error)) return None # 配置构建选项 config = builder.create_builder_config() # 设置工作空间大小(这里设为1GB) config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 设置精度 if precision == 'fp16': if builder.platform_has_fast_fp16: config.set_flag(trt.BuilderFlag.FP16) print("启用FP16推理") elif precision == 'int8': if builder.platform_has_fast_int8: config.set_flag(trt.BuilderFlag.INT8) print("启用INT8推理") # 构建序列化engine serialized_engine = builder.build_serialized_network(network, config) with open(engine_file_path, 'wb') as f: f.write(serialized_engine) return serialized_engine # 构建FP16 engine build_engine('resnet18.onnx', 'resnet18_fp16.engine', precision='fp16')

5.2 INT8量化与精度校准

INT8量化是减少模型体积和加速推理的最有效手段之一,通常可以带来约4倍的理论加速和4倍的内存减少。但INT8量化需要使用校准数据集来最小化精度损失。

import tensorrt as trt import torch from torch.utils.data import DataLoader class Calibrator(trt.IInt8EntropyCalibrator2): """INT8校准器实现""" def __init__(self, dataloader, cache_file='calibration.cache'): super().__init__() self.dataloader = dataloader self.cache_file = cache_file self.batch_size = dataloader.batch_size self.data_iter = iter(dataloader) self.device = torch.device('cuda:0') def get_batch_size(self): return self.batch_size def get_batch(self, names): """返回校准批次数据""" try: data = next(self.data_iter) if isinstance(data, (list, tuple)): data = data[0] # 取输入数据 # 转移到GPU并转换为float32 data = data.to(self.device).float().contiguous() # 返回数据指针 return [int(data.data_ptr())] except StopIteration: return None def read_calibration_cache(self): """读取校准缓存""" import os if os.path.exists(self.cache_file): with open(self.cache_file, 'rb') as f: return f.read() return None def write_calibration_cache(self, cache): """写入校准缓存""" with open(self.cache_file, 'wb') as f: f.write(cache) # 使用校准器构建INT8 engine def build_int8_engine(onnx_path, engine_path, calib_dataloader): TRT_LOGGER = trt.Logger(trt.Logger.WARNING) calibrator = Calibrator(calib_dataloader) with trt.Builder(TRT_LOGGER) as builder, \ builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, \ trt.OnnxParser(network, TRT_LOGGER) as parser: with open(onnx_path, 'rb') as f: parser.parse(f.read()) config = builder.create_builder_config() config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) if builder.platform_has_fast_int8: config.set_flag(trt.BuilderFlag.INT8) config.int8_calibrator = calibrator print("使用INT8精度,已配置校准器") serialized_engine = builder.build_serialized_network(network, config) with open(engine_path, 'wb') as f: f.write(serialized_engine) return serialized_engine

5.3 层融合与图优化

TensorRT的核心优化技术之一是层融合(Layer Fusion)。通过将连续的算子(如Conv + BN + ReLU)合并为一个名为CBR的融合算子,减少kernel launch开销和全局内存访问,从而实现显著的加速效果。典型的融合模式包括:

# TensorRT自动层融合模式示例 # 原始图: Conv(3x3) -> BN -> ReLU -> Conv(1x1) -> BN -> ReLU # 融合后: CBR(3x3) -> CBR(1x1) # 减少kernel launch次数: 6次 -> 2次 # 其他常见融合模式: # Conv + BN + ReLU -> CBR # Conv + Add + ReLU -> CBR (residual) # FC + ReLU -> FCR # Concat + BN -> merge # LayerNorm + Scale -> merge # 可以通过trtexec工具查看优化细节 # trtexec --onnx=resnet18.onnx --fp16 --verbose --dumpLayerInfo

5.4 动态形状处理

生产环境中推理请求的输入尺寸往往是可变的。TensorRT通过优化形状(Optimization Profiles)支持动态输入形状,需要在构建engine时指定输入尺寸范围。

def build_dynamic_engine(onnx_path, engine_path): TRT_LOGGER = trt.Logger(trt.Logger.WARNING) with trt.Builder(TRT_LOGGER) as builder, \ builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) as network, \ trt.OnnxParser(network, TRT_LOGGER) as parser: with open(onnx_path, 'rb') as f: parser.parse(f.read()) config = builder.create_builder_config() config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 创建优化profile,定义输入尺寸范围 profile = builder.create_optimization_profile() # 输入名称必须与ONNX导出的input_names一致 input_name = 'input' # 定义动态形状谱:最小、最优、最大 min_shape = (1, 3, 224, 224) opt_shape = (4, 3, 224, 224) # 最常使用的尺寸 max_shape = (16, 3, 224, 224) profile.set_shape(input_name, min_shape, opt_shape, max_shape) config.add_optimization_profile(profile) serialized_engine = builder.build_serialized_network(network, config) with open(engine_path, 'wb') as f: f.write(serialized_engine) return serialized_engine # 推理时指定实际输入尺寸 def infer_with_dynamic_shape(engine_path, input_data): with open(engine_path, 'rb') as f: serialized_engine = f.read() runtime = trt.Runtime(TRT_LOGGER) engine = runtime.deserialize_cuda_engine(serialized_engine) # 创建执行上下文 context = engine.create_execution_context() context.set_input_shape('input', input_data.shape) # 分配设备内存并执行推理 # ... 标准TensorRT推理流程 pass

5.5 Engine序列化与反序列化

构建TensorRT engine是一个耗时的过程(可能需要几分钟到几十分钟)。因此,必须将构建好的engine序列化到磁盘,在部署时直接反序列化加载,避免每次启动时重新构建。

import tensorrt as trt import numpy as np import pycuda.driver as cuda import pycuda.autoinit class TensorRTInference: """封装TensorRT推理的类""" def __init__(self, engine_path): TRT_LOGGER = trt.Logger(trt.Logger.WARNING) self.runtime = trt.Runtime(TRT_LOGGER) # 反序列化engine with open(engine_path, 'rb') as f: self.engine = self.runtime.deserialize_cuda_engine(f.read()) self.context = self.engine.create_execution_context() # 分配设备内存 self.inputs = [] self.outputs = [] self.bindings = [] for binding in self.engine: shape = self.engine.get_binding_shape(binding) size = trt.volume(shape) dtype = trt.nptype(self.engine.get_binding_dtype(binding)) # 分配主机和设备内存 host_mem = cuda.pagelocked_empty(size, dtype) device_mem = cuda.mem_alloc(host_mem.nbytes) self.bindings.append(int(device_mem)) if self.engine.binding_is_input(binding): self.inputs.append({'host': host_mem, 'device': device_mem, 'name': binding, 'shape': shape, 'size': size}) else: self.outputs.append({'host': host_mem, 'device': device_mem, 'name': binding, 'shape': shape, 'size': size}) def infer(self, input_np): """执行推理""" # 复制输入数据到设备 cuda.memcpy_htod(self.inputs[0]['device'], input_np.ravel()) # 执行推理 self.context.execute_v2(bindings=self.bindings) # 复制输出回主机 cuda.memcpy_dtoh(self.outputs[0]['host'], self.outputs[0]['device']) # 重塑输出形状 return np.reshape(self.outputs[0]['host'], self.outputs[0]['shape']) def __del__(self): # 清理资源 for binding in self.inputs + self.outputs: binding['device'].free() # 使用示例 trt_infer = TensorRTInference('resnet18_fp16.engine') result = trt_infer.infer(np.random.randn(1, 3, 224, 224).astype(np.float32))

TensorRT部署关键指标

  • FP16推理: 相比FP32通常获得2-3倍加速,精度损失极小(通常小于0.5%)
  • INT8推理: 相比FP32获得4-6倍加速,需要校准数据集,精度损失通常在1-3%
  • 层融合: 减少kernel launch次数,降低推理延迟,对batch size较小时效果显著
  • 动态形状: 灵活支持可变batch size,但可能引入少量性能开销
  • Engine序列化: 建议在CI/CD流水线中预构建engine,部署时直接加载

六、模型服务部署

将优化后的模型部署为可访问的在线推理服务,是模型投入生产的最后一步。当前业界主流的模型服务框架包括TensorFlow Serving、NVIDIA Triton Inference Server以及TorchServe等。本节重点介绍最广泛使用的两种方案。

6.1 TensorFlow Serving

TensorFlow Serving是一个专为TensorFlow模型设计的灵活、高性能的模型服务系统。它天然支持SavedModel格式,提供REST API和gRPC两种接口,支持模型版本管理、优雅加载与卸载等生产级特性。

# 训练并导出SavedModel供TF Serving使用 import tensorflow as tf import tempfile import os model = tf.keras.Sequential([ tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)), tf.keras.layers.Dense(10, activation='softmax') ]) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy') # 使用model_id创建版本化目录结构 model_version = 1 export_path = os.path.join('models/mnist_model', str(model_version)) model.save(export_path, save_format='tf') # 目录结构如下: # models/ # mnist_model/ # 1/ <-- 版本号目录 # saved_model.pb # variables/ # variables.data-00000-of-00001 # variables.index print(f"模型已导出到 {export_path}") print(f"模型签名: {model.signatures['serving_default']}") # 可以通过命令行启动TF Serving # docker run -p 8501:8501 -p 8500:8500 \ # --mount type=bind,source=$(pwd)/models,target=/models \ # -e MODEL_NAME=mnist_model -t tensorflow/serving

6.2 REST API 调用

TF Serving的REST API基于HTTP/JSON,接口直观易用,适合原型验证和低并发场景。每个请求携带序列化的JSON数据,响应也以JSON格式返回预测结果。

import requests import json import numpy as np # 构建预测请求 def predict_rest(input_data, server_url='http://localhost:8501'): """通过REST API调用TF Serving""" model_name = 'mnist_model' url = f'{server_url}/v1/models/{model_name}:predict' # 数据预处理:转为列表格式 if isinstance(input_data, np.ndarray): input_data = input_data.tolist() payload = { 'instances': input_data # 注意key为instances } response = requests.post(url, json=payload) if response.status_code == 200: result = response.json() return result['predictions'] else: raise Exception(f"预测请求失败: {response.status_code}, {response.text}") # 使用示例 sample = np.random.randn(1, 784).astype(np.float32) predictions = predict_rest(sample) print(f"预测结果: {np.argmax(predictions[0])}") # REST API也支持获取模型元信息 metadata_url = 'http://localhost:8501/v1/models/mnist_model/metadata' metadata = requests.get(metadata_url).json() print(f"模型元数据: {json.dumps(metadata, indent=2)}")

6.3 gRPC 调用(高性能)

gRPC基于Protocol Buffers二进制序列化协议,相比REST API具有更高的传输效率和更低的延迟,是生产环境中的推荐接口方式。

import grpc import tensorflow as tf from tensorflow_serving.apis import predict_pb2 from tensorflow_serving.apis import prediction_service_pb2_grpc import numpy as np def predict_grpc(input_data, server_url='localhost:8500'): """通过gRPC调用TF Serving""" # 创建gRPC通道(使用不安全的通道) channel = grpc.insecure_channel(server_url) stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) # 构建请求 request = predict_pb2.PredictRequest() request.model_spec.name = 'mnist_model' request.model_spec.signature_name = 'serving_default' # 将numpy数组转换为TensorProto if isinstance(input_data, np.ndarray): request.inputs['input'].CopyFrom( tf.make_tensor_proto(input_data) ) # 发送请求并获取响应(设置超时时间为10秒) response = stub.Predict(request, timeout=10.0) # 解析响应 output_key = list(response.outputs.keys())[0] result = tf.make_ndarray(response.outputs[output_key]) return result # 性能对比:gRPC vs REST # gRPC: ~2-5ms per request (small payload) # REST: ~5-15ms per request (JSON serialization overhead) sample = np.random.randn(1, 784).astype(np.float32) predictions = predict_grpc(sample) print(f"gRPC预测结果: {np.argmax(predictions[0])}")

6.4 Docker部署TF Serving

Docker是部署TF Serving的推荐方式,它提供了环境隔离、资源限制和快速回滚等生产级特性。

# 拉取TensorFlow Serving镜像 docker pull tensorflow/serving:2.13.0-gpu # 启动TF Serving容器(GPU版) docker run --gpus all -d --name tf_serving \ -p 8500:8500 \ # gRPC端口 -p 8501:8501 \ # REST API端口 --mount type=bind,source=/absolute/path/to/models,target=/models \ -e MODEL_NAME=mnist_model \ tensorflow/serving:2.13.0-gpu # 多模型服务配置(models.config) # cat > models.config << EOF # model_config_list: { # config: { # name: "mnist_model", # base_path: "/models/mnist_model", # model_platform: "tensorflow", # model_version_policy: {latest: {num_versions: 3}} # }, # config: { # name: "resnet_model", # base_path: "/models/resnet_model", # model_platform: "tensorflow" # } # } # EOF # 启动多模型服务 docker run -d --name tf_serving_multi \ -p 8500:8500 -p 8501:8501 \ --mount type=bind,source=/path/to/models,target=/models \ --mount type=bind,source=/path/to/models.config,target=/models/models.config \ -e MODEL_NAME=mnist_model \ tensorflow/serving --model_config_file=/models/models.config

6.5 NVIDIA Triton Inference Server

Triton Inference Server是NVIDIA推出的高性能推理服务器,相比TF Serving,它支持更多后端(TensorRT、ONNX Runtime、PyTorch、TensorFlow、自定义Python后端),提供更丰富的调度策略(动态批处理、并发模型执行)和更完善的性能分析工具。

# Triton模型仓库目录结构 # model_repository/ # resnet50/ # 1/ <-- 版本号 # model.onnx <-- ONNX模型文件 # config.pbtxt <-- 模型配置(可选) # sentiment_model/ # 1/ # model.py <-- Python后端自定义逻辑 # 2/ # model.py # 模型配置文件示例 (config.pbtxt) # name: "resnet50" # platform: "onnxruntime_onnx" # max_batch_size: 32 # input [ # { # name: "input" # data_type: TYPE_FP32 # dims: [3, 224, 224] # } # ] # output [ # { # name: "output" # data_type: TYPE_FP32 # dims: [1000] # } # ] # dynamic_batching { # preferred_batch_size: [4, 8, 16] # max_queue_delay_microseconds: 100 # } # Python客户端调用 import tritonclient.http as httpclient import numpy as np client = httpclient.InferenceServerClient(url='localhost:8000') # 检查服务器状态 print(client.is_server_live()) print(client.get_model_repository_index()) # 构建推理请求 input_data = np.random.randn(1, 3, 224, 224).astype(np.float32) inputs = [httpclient.InferInput('input', input_data.shape, 'FP32')] inputs[0].set_data_from_numpy(input_data) outputs = [httpclient.InferRequestedOutput('output')] # 执行推理 results = client.infer('resnet50', inputs, outputs=outputs) output = results.as_numpy('output') print(f"预测类别索引: {np.argmax(output)}")

TF Serving vs Triton 选型建议

  • 纯TensorFlow生态: 选择TF Serving,集成最简单,性能优秀
  • 多框架混合部署: 选择Triton,支持TensorRT/ONNX/PyTorch/TF多后端统一管理
  • 需要动态批处理: Triton的动态批处理功能更完善,支持preferred_batch_size和max_queue_delay
  • 团队学习成本: TF Serving上手简单,Triton配置稍复杂但功能更强大

七、移动端与边缘部署

随着边缘计算和移动AI的兴起,将深度学习模型部署到资源受限的设备(手机、IoT、嵌入式系统)成为一个重要场景。这要求模型在保持精度的同时尽可能小、快、省电。本节介绍主流的移动端和边缘端部署方案。

7.1 TensorFlow Lite(TFLite)

TFLite是TensorFlow针对移动和嵌入式设备的轻量级解决方案,通过量化、委派加速等技术在端侧实现高效推理。

import tensorflow as tf import numpy as np # 加载SavedModel并转换为TFLite converter = tf.lite.TFLiteConverter.from_saved_model('my_model_savedmodel/') # 默认使用Float32 tflite_model = converter.convert() with open('model_fp32.tflite', 'wb') as f: f.write(tflite_model) # === Float16量化 === converter = tf.lite.TFLiteConverter.from_saved_model('my_model_savedmodel/') converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.target_spec.supported_types = [tf.float16] tflite_fp16_model = converter.convert() with open('model_fp16.tflite', 'wb') as f: f.write(tflite_fp16_model) # === INT8量化(需要校准数据)=== def representative_dataset(): """提供校准数据生成器""" for _ in range(100): data = np.random.randn(1, 784).astype(np.float32) yield [data] converter = tf.lite.TFLiteConverter.from_saved_model('my_model_savedmodel/') converter.optimizations = [tf.lite.Optimize.DEFAULT] converter.representative_dataset = representative_dataset converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] converter.inference_input_type = tf.int8 converter.inference_output_type = tf.int8 tflite_int8_model = converter.convert() with open('model_int8.tflite', 'wb') as f: f.write(tflite_int8_model) # 比较模型大小 import os for name in ['model_fp32.tflite', 'model_fp16.tflite', 'model_int8.tflite']: size_kb = os.path.getsize(name) / 1024 print(f"{name}: {size_kb:.1f} KB")

7.2 TFLite推理与委派

TFLite支持GPU委派(Delegate)和NNAPI委派,利用设备上的硬件加速器提升推理性能。

import tensorflow as tf import numpy as np # 加载TFLite模型 interpreter = tf.lite.Interpreter(model_path='model_int8.tflite') interpreter.allocate_tensors() # 获取输入输出信息 input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() print(f"输入信息: {input_details}") print(f"输出信息: {output_details}") # 执行推理 input_data = np.random.randn(1, 784).astype(np.int8) interpreter.set_tensor(input_details[0]['index'], input_data) interpreter.invoke() output_data = interpreter.get_tensor(output_details[0]['index']) print(f"预测结果: {output_data}") # === GPU委派(Android/iOS)=== from tensorflow.lite.python import interpreter as tflite_interpreter # 在Android上使用GPU委派 gpu_delegate = tf.lite.experimental.GpuDelegate( precision_loss_allowed=1, # 允许FP16精度损失以获得加速 inference_preference=tf.lite.experimental.GpuDelegate.Options.INFERENCE_PREFERENCE_SUSTAINED_SPEED ) interpreter = tf.lite.Interpreter( model_path='model_fp16.tflite', experimental_delegates=[gpu_delegate] ) interpreter.allocate_tensors() # === NNAPI委派(Android)=== # NNAPI利用Android设备的DSP/NPU等硬件加速器 interpreter = tf.lite.Interpreter( model_path='model_int8.tflite', experimental_delegates=[tf.lite.experimental.NnApiDelegate()] ) interpreter.allocate_tensors()

7.3 Edge TPU部署

Google Edge TPU(如Coral系列设备)是专为TFLite模型设计的硬件加速器,提供低功耗、高性能的边缘推理能力。在将模型部署到Edge TPU之前,需要将模型编译为Edge TPU兼容格式。

# 步骤1:将TFLite模型编译为Edge TPU模型 # 使用edgetpu_compiler工具(需要安装Edge TPU编译器) edgetpu_compiler model_int8.tflite # 输出: model_int8_edgetpu.tflite # 步骤2:在Python中加载Edge TPU模型 # 使用pycoral库进行推理 import numpy as np from pycoral.utils.edgetpu import make_interpreter from pycoral.adapters import common # 加载Edge TPU编译后的模型 interpreter = make_interpreter('model_int8_edgetpu.tflite') interpreter.allocate_tensors() # 获取输入输出张量详情 input_details = interpreter.get_input_details() output_details = interpreter.get_output_details() # 运行推理 input_data = np.random.randn(1, 784).astype(np.uint8) common.set_input(interpreter, input_data) interpreter.invoke() output = common.output_tensor(interpreter, 0) print(f"Edge TPU推理结果: {output}") # 注意事项: # 1. Edge TPU只支持INT8量化的模型 # 2. 部分算子可能不被Edge TPU支持,会回退到CPU执行 # 3. 编译时需要指定输入输出尺寸,不支持动态形状

7.4 CoreML部署(Apple设备)

CoreML是Apple的机器学习框架,支持将PyTorch和TensorFlow模型转换为Apple设备可以高效运行的格式。CoreML可以通过 coremltools 进行转换。

# 安装coremltools # pip install coremltools import coremltools as ct import torch import torchvision.models as models # 加载PyTorch模型 model = models.resnet18(pretrained=True) model.eval() # 创建示例输入 example_input = torch.randn(1, 3, 224, 224) # 转换为CoreML格式 # 方法1: 从TorchScript转换 traced_model = torch.jit.trace(model, example_input) mlmodel = ct.convert( traced_model, inputs=[ct.TensorType(shape=(1, 3, 224, 224), name='input')], outputs=[ct.TensorType(name='output')], minimum_deployment_target=ct.target.iOS15 # 最低目标iOS版本 ) mlmodel.save('ResNet18.mlpackage') # 方法2: 直接从PyTorch模型转换(coremltools >= 5.0) mlmodel_direct = ct.convert( model, inputs=[ct.TensorType(shape=(1, 3, 224, 224), name='input')], convert_to='mlprogram' # 使用MIL程序格式,性能更好 ) mlmodel_direct.save('ResNet18_direct.mlpackage') # 方法3: 量化CoreML模型(减少体积) mlmodel_fp16 = ct.convert( model, inputs=[ct.TensorType(shape=(1, 3, 224, 224), name='input')], compute_precision=ct.precision.FLOAT16 # FP16精度 ) mlmodel_fp16.save('ResNet18_fp16.mlpackage') # 在iOS中使用CoreML模型的Swift代码: # import CoreML # let model = try MLModel(contentsOf: ResNet18.mlmodel) # let input = try MLMultiArray(shape: [1,3,224,224], dataType: .float32) # let prediction = try model.prediction(from: MLFeatureProvider)

移动端部署选型指南

  1. TFLite (INT8): 通用性最强,支持Android/iOS/Linux嵌入式。优先选择。
  2. TFLite + GPU委派: 当需要更高帧率时使用,支持FP16加速。
  3. Edge TPU: 需要极低功耗、极高能效比时使用。仅支持INT8且算子受限。
  4. CoreML: 纯iOS/macOS生态推荐,与Apple硬件深度优化集成。
  5. TFLite + NNAPI: Android平台利用NPU/DSP加速的推荐方式。

部署方案综合对比

  • 服务器GPU推理: TensorRT(FP16/INT8)提供最高吞吐量和最低延迟,适合高并发场景
  • 服务器CPU推理: ONNX Runtime(INT8量化)或OpenVINO,适合无GPU环境
  • 手机端推理: TFLite INT8量化(通用)或CoreML(iOS专用),模型大小可压缩至原始1/4
  • 边缘设备推理: TFLite + Edge TPU,功耗极低但算子支持受限
  • Web端推理: TensorFlow.js或ONNX Runtime Web,适合浏览器端应用

八、核心要点总结

九、进一步思考

模型部署领域正在快速发展,有几个值得关注的技术趋势:

1. 自适应推理与模型压缩

知识蒸馏(Knowledge Distillation)通过让轻量学生模型模仿大模型的输出,实现模型瘦身,在保持90%以上精度的同时将模型体积缩小10倍以上。剪枝(Pruning)技术移除对最终结果影响较小的连接或通道,进一步压缩模型。结构搜索(NAS)自动化寻找最优的轻量级网络结构。这些技术与量化技术结合,可以使模型在移动端达到接近服务器端的精度水平。

2. 模型版本管理与A/B测试

生产环境中模型的频繁更新需要完善的版本管理机制。通过蓝绿部署(Blue-Green Deployment)和金丝雀发布(Canary Release)策略,可以在不影响在线服务的情况下平滑升级模型。影子测试(Shadow Testing)允许将新模型的推理结果与生产模型进行离线对比,确保新模型不会引入回归问题。MLflow和Kubeflow等MLOps平台提供了模型注册表和版本管理能力。

3. 异构计算与边缘协同

在5G和边缘计算的大背景下,云端协同推理成为趋势。模型的部分层在端侧执行(特征提取),部分层在云端执行(复杂分类),通过网络传输中间特征表示。这种方案在保护数据隐私(端侧处理敏感数据)和降低延迟之间取得平衡。同时,Intel OpenVINO、Qualcomm SNPE和Apple ANE等异构计算框架利用CPU/GPU/NPU/DSP的协同工作,使边缘设备能够运行越来越复杂的大模型。

4. 大模型部署挑战

随着LLM(大语言模型)的兴起,百亿甚至千亿参数模型的部署成为新的挑战。模型并行(Model Parallelism)、流水线并行(Pipeline Parallelism)、张量并行(Tensor Parallelism)等分布式推理技术成为必要。vLLM、DeepSpeed MII和TensorRT-LLM等大模型推理引擎通过PagedAttention、Continuous Batching和KVCache管理等技术创新,大幅降低了大模型推理的显存占用和延迟,推动了LLM的商业化落地。

实践建议

  • 从小处着手: 先部署一个简单模型跑通端到端流程,再逐步增加复杂度
  • 性能基准测试: 每次优化后都要进行性能基准测试,量化加速效果和精度损失
  • 模型监控: 部署后持续监控模型推理延迟、吞吐量、内存占用和预测漂移
  • 自动化流水线: 建立自动化的模型训练 -> 转换 -> 验证 -> 部署的CI/CD流水线
  • 文档化: 记录每次模型部署的配置、参数和性能指标,便于问题排查和优化迭代