【tensorflow】TF1.x保存与读取.pb模型写法介绍

简介: 由于TF里面的概念比较接地气,所以用tf1.x保存.pb模型时总是怕有什么操作漏掉了,会造成保存的模型是缺少变量数据或者没有保存图,所以先明确一下:用TF1.x保存模型时只需要保存模型的输入输出的变量(多输入就保存多个),不需要保存中间的变量;用TF1.x加载模型时只需要加载保存的模型,然后读一下输入输出变量(多输入就读多个),不需要初始化(反而会重置掉变量的值)。

由于TF里面的概念比较接地气,所以用tf1.x保存.pb模型时总是怕有什么操作漏掉了,会造成保存的模型是缺少变量数据或者没有保存图,所以先明确一下:用TF1.x保存模型时只需要保存模型的输入输出的变量(多输入就保存多个),不需要保存中间的变量;用TF1.x加载模型时只需要加载保存的模型,然后读一下输入输出变量(多输入就读多个),不需要初始化(反而会重置掉变量的值)。


举例:模型定义如下


# 定义模型
with tf.name_scope("Model"):
    """MLP"""
    # 13个连续特征数据(13列)
    x = tf.placeholder(tf.float32, [None,13], name='X') 
    # 正则化
    x_norm = tf.layers.batch_normalization(inputs=x)
    # 定义一层Dense
    dense_1 = tf.layers.Dense(64, activation="relu")(x_norm)
    """EMBED"""
    # 离散输入
    y = tf.placeholder(tf.int32, [None,2], name='Y')
    # 创建嵌入矩阵变量
    embedding_matrix = tf.Variable(tf.random_uniform([len(vocab_dict) + 1, 8], -1.0, 1.0))
    # 使用tf.nn.embedding_lookup函数获取嵌入向量
    embeddings = tf.nn.embedding_lookup(embedding_matrix, y)
    # 创建 LSTM 层
    lstm_cell = tf.nn.rnn_cell.LSTMCell(64)
    # 初始化 LSTM 单元状态
    initial_state = lstm_cell.zero_state(tf.shape(embeddings)[0], tf.float32)
    # 将输入数据传递给 LSTM 层
    lstm_out, _ = tf.nn.dynamic_rnn(lstm_cell, embeddings, initial_state=initial_state)
    # 定义一层Dense
    dense_2 = tf.layers.Dense(64, activation="relu")(lstm_out[:, -1, :])
    """MERGE"""
    combined = tf.concat([dense_1, dense_2], axis = -1)
    pred = tf.layers.Dense(2, activation="relu")(combined)
    pred = tf.layers.Dense(1, activation="linear", name='P')(pred)
    z = tf.placeholder(tf.float32, [None, 1], name='Z')


  虽然写这么多,但是上面模型的输入只有xyz,输出只有pred。所以我们保存、加载模型时,只用考虑这几个变量就可以。


模型保存代码


import tensorflow as tf
from tensorflow import saved_model as sm
# 创建 Saver 对象
saver = tf.train.Saver()
# 生成会话,训练STEPS轮
with tf.Session() as sess:
    # 初始化参数
    sess.run(tf.global_variables_initializer())
    ...... # 模型训练逻辑
    # 准备存储模型
    path = 'pb_model/'
    dense_model_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    pb_saver = tf.train.Saver(dense_model_var)
    builder = sm.builder.SavedModelBuilder(path)
    # 构建需要在新会话中恢复的变量的 TensorInfo protobuf
    # 自定义 根据自己的模型来写
    X = sm.utils.build_tensor_info(x)
    Y = sm.utils.build_tensor_info(y)
    Z = sm.utils.build_tensor_info(z)
    P = sm.utils.build_tensor_info(pred)
    # 构建 SignatureDef protobuf
    # inputs outputs 自定义 根据自己的模型来写
    SignatureDef = sm.signature_def_utils.build_signature_def(
                                inputs={'X': X, 'Y': Y, 'Z': Z},  # 可用sm.signature_constants.PREDICT_INPUTS
                                outputs={'P': P},  # 可用sm.signature_constants.PREDICT_OUTPUTS
                                method_name="tensorflow/serving/predict"
    )
    # 将 graph 和变量等信息写入 MetaGraphDef protobuf
    # 这里的 tags 里面的参数和 signature_def_map 字典里面的键都可以是自定义字符串,也可用tf里预设好的方便统一
    builder.add_meta_graph_and_variables(sess, tags=['serve'],
                                             signature_def_map={
                                                 sm.signature_constants.PREDICT_METHOD_NAME: SignatureDef},
                                             saver=pb_saver,
                                             main_op=tf.local_variables_initializer())
    # 将 MetaGraphDef 写入磁盘
    builder.save()


 最重要的是这一句:dense_model_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES),意思是保存当前作用域下的所有可训练的变量。


 我之前写的是dense_model_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, name_scope="Model"),这样读不了所有的可训练变量,只能读到embedding_matrix 一个,虽然也能保存模型,但是没保存模型的其他变量值,就会出错。


模型加载代码


import tensorflow as tf
from tensorflow import saved_model as sm
tf.reset_default_graph()
# 创建一个新的默认图
graph = tf.Graph()
# 需要建立一个会话对象,将模型恢复到其中
with tf.Session(graph=graph) as sess:
    path = 'pb_model/'
    MetaGraphDef = sm.loader.load(sess, tags=['serve'], export_dir=path)
    # 解析得到 SignatureDef protobuf
    SignatureDef_map = MetaGraphDef.signature_def
    SignatureDef = SignatureDef_map[sm.signature_constants.PREDICT_METHOD_NAME]
    # 解析得到 3 个变量对应的 TensorInfo protobuf
    X = SignatureDef.inputs['X']
    Y = SignatureDef.inputs['Y']
    Z = SignatureDef.inputs['Z']
    P = SignatureDef.outputs['P']
    # 解析得到具体 Tensor
    # .get_tensor_from_tensor_info() 函数中可以不传入 graph 参数,TensorFlow 自动使用默认图
    # x = sm.utils.get_tensor_from_tensor_info(X)
    # y = sm.utils.get_tensor_from_tensor_info(Y)
    # z = sm.utils.get_tensor_from_tensor_info(Z)
    x = sess.graph.get_tensor_by_name(X.name)
    y = sess.graph.get_tensor_by_name(Y.name)
    z = sess.graph.get_tensor_by_name(Z.name)
    p = sess.graph.get_tensor_by_name(P.name)
    # 这里就可以开始进行预测或者继续训练了 TODO
    total_loss = sess.run(loss_function, feed_dict={x: dense_ch_val, y: sparse_ch_val, z: score_val})
    print(total_loss)



相关文章
|
机器学习/深度学习 TensorFlow 算法框架/工具
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
将Keras训练好的.hdf5模型转换为TensorFlow的.pb模型,然后再转换为TensorRT支持的.uff格式,并提供了转换代码和测试步骤。
490 3
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
843 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
1151 5
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
875 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
923 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
559 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
机器学习/深度学习 API TensorFlow
深入解析TensorFlow 2.x中的Keras API:快速搭建深度学习模型的实战指南
【8月更文挑战第31天】本文通过搭建手写数字识别模型的实例,详细介绍了如何利用TensorFlow 2.x中的Keras API简化深度学习模型构建流程。从环境搭建到数据准备,再到模型训练与评估,展示了Keras API的强大功能与易用性,适合初学者快速上手。通过简单的代码,即可完成卷积神经网络的构建与训练,显著降低了深度学习的技术门槛。无论是新手还是专业人士,都能从中受益,高效实现模型开发。
518 1
|
机器学习/深度学习 TensorFlow 算法框架/工具
全面解析TensorFlow Lite:从模型转换到Android应用集成,教你如何在移动设备上轻松部署轻量级机器学习模型,实现高效本地推理
【8月更文挑战第31天】本文通过技术综述介绍了如何使用TensorFlow Lite将机器学习模型部署至移动设备。从创建、训练模型开始,详细演示了模型向TensorFlow Lite格式的转换过程,并指导如何在Android应用中集成该模型以实现预测功能,突显了TensorFlow Lite在资源受限环境中的优势及灵活性。
2069 1
|
机器学习/深度学习 TensorFlow 算法框架/工具
TensorFlow Serving 部署指南超赞!让机器学习模型上线不再困难,轻松开启高效服务之旅!
【8月更文挑战第31天】TensorFlow Serving是一款高性能开源服务系统,专为部署机器学习模型设计。本文通过代码示例详细介绍其部署流程:从安装TensorFlow Serving、训练模型到配置模型服务器与使用gRPC客户端调用模型,展示了一站式模型上线解决方案,使过程变得简单高效。借助该工具,你可以轻松实现模型的实际应用。
765 1
|
机器学习/深度学习 存储 前端开发
实战揭秘:如何借助TensorFlow.js的强大力量,轻松将高效能的机器学习模型无缝集成到Web浏览器中,从而打造智能化的前端应用并优化用户体验
【8月更文挑战第31天】将机器学习模型集成到Web应用中,可让用户在浏览器内体验智能化功能。TensorFlow.js作为在客户端浏览器中运行的库,提供了强大支持。本文通过问答形式详细介绍如何使用TensorFlow.js将机器学习模型带入Web浏览器,并通过具体示例代码展示最佳实践。首先,需在HTML文件中引入TensorFlow.js库;接着,可通过加载预训练模型如MobileNet实现图像分类;然后,编写代码处理图像识别并显示结果;此外,还介绍了如何训练自定义模型及优化模型性能的方法,包括模型量化、剪枝和压缩等。
1115 1