深度学习之格式转换笔记(二):CKPT 转换成 PB格式文件

简介: 将TensorFlow的CKPT模型格式转换为PB格式文件,包括保存模型的代码示例和将ckpt固化为pb模型的详细步骤。

我们使用tf.train.saver()保存模型时会产生多个文件,也就是说把计算图的结构和图上参数取值分成了不同的文件存储。这也是在tensorflow中常用的保存方式。

保存文件的代码:

import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
    sess.run(init_op)
    print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
    print("v2:", sess.run(v2))
    saver_path = saver.save(sess, "save/model.ckpt-510")  # 将模型保存到save/model.ckpt-510文件
    print("Model saved in file:", saver_path)

这时候我们就可以看到结果
在这里插入图片描述
其中

  • checkpoint:检查点文件,文件保存了一个目录下所有的模型文件列表;
  • model.ckpt-510.meta:保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被
    tf.train.import_meta_graph 加载到当前默认的图来使用。
  • ckpt-510.data : 保存模型中每个变量的取值
  • ckpt-510.index:可能是内部需要的某种索引来正确映射前两个文件,它通常不是必需的

真正部署的时候,一般人家不会给你ckpt模型的,而是固化成pb模型以后再给你用,现在我们就来看看怎么将ckpt固化成pb模型。

实际完整代码:

# -*-coding: utf-8 -*-
import os
import tensorflow as tf
from create_tf_record import *
from tensorflow.python.framework import graph_util

resize_height = 299  # 指定图片高度
resize_width = 299  # 指定图片宽度
depths = 3

def freeze_graph_test(pb_path, image_path):
    '''
    :param pb_path:pb文件的路径
    :param image_path:测试图片的路径
    :return:
    '''
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
            tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            # 定义输入的张量名称,对应网络结构的输入张量,往往是通过tf.placeholder调用的。
            # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
            input_image_tensor = sess.graph.get_tensor_by_name("input:0")
            input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
            input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")

            # 定义输出的张量名称
            output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")

            # 读取测试图片
            im = read_image(image_path, resize_height, resize_width, normalization=True)
            im = im[np.newaxis, :]
            # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
            # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
            out = sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
                                                          input_keep_prob_tensor: 1.0,
                                                          input_is_training_tensor: False})
            print("out:{}".format(out))
            score = tf.nn.softmax(out, name='pre')
            class_id = tf.argmax(score, 1)
            print(
            "pre class_id:{}".format(sess.run(class_id)))

def freeze_graph(input_checkpoint, output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路径
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
    # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径

    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
    output_node_names = "InceptionV3/Logits/SpatialSqueeze"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # 恢复图并得到数据
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=sess.graph_def,  # 等于:sess.graph_def
            output_node_names=output_node_names.split(","))  # 如果有多个输出节点,以逗号隔开

        with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
            f.write(output_graph_def.SerializeToString())  # 序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node))  # 得到当前图有几个操作节点

        # for op in sess.graph.get_operations():
        #     print(op.name, op.values())

def freeze_graph2(input_checkpoint, output_graph):
    '''
    :param input_checkpoint:
    :param output_graph: PB模型保存路径
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
    # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径

    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
    output_node_names = "InceptionV3/Logits/SpatialSqueeze"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()  # 获得默认的图
    input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # 恢复图并得到数据
        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=input_graph_def,  # 等于:sess.graph_def
            output_node_names=output_node_names.split(","))  # 如果有多个输出节点,以逗号隔开

        with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
            f.write(output_graph_def.SerializeToString())  # 序列化输出
        print("%d ops in the final graph." % len(output_graph_def.node))  # 得到当前图有几个操作节点

        # for op in graph.get_operations():
        #     print(op.name, op.values())

if __name__ == '__main__':
    # 输入ckpt模型路径
    input_checkpoint = 'D:/pycharm/CarPlateIdentity-master/carIdentityData/model1/char_recongnize/model.ckpt-510'
    # 输出pb模型的路径
    out_dirpath = 'D:/pycharm/CarPlateIdentity-master/carIdentityData/model1/char_recongnize/pb/'
    os.makedirs(os.path.dirname(out_dirpath),exist_ok=True)
    out_pb_path = out_dirpath+"frozen_model.pb"
    # 调用freeze_graph将ckpt转为pb
    freeze_graph(input_checkpoint, out_pb_path)
    print("the success cover")
    # 测试pb模型
    # image_path = 'test_image/animal.jpg'
    # freeze_graph_test(pb_path=out_pb_path, image_path=image_path)

在将ckpt转换为pd过程中,会依据输出节点来丢弃那些与输出节点无关的参数,只保留与输出节点存在上下文关系的参数,这也就是生成pd文件的意义所在,即通过减少参数量降低模型的大小,所以在生成pd的过程中需要明确指定输出节点是谁,这样才能确定其依赖的需要固化的上下文参数。

目录
相关文章
|
2天前
|
机器学习/深度学习 vr&ar
深度学习笔记(十):深度学习评估指标
关于深度学习评估指标的全面介绍,涵盖了专业术语解释、一级和二级指标,以及各种深度学习模型的性能评估方法。
7 0
深度学习笔记(十):深度学习评估指标
|
2天前
|
机器学习/深度学习 Python
深度学习笔记(九):神经网络剪枝(Neural Network Pruning)详细介绍
神经网络剪枝是一种通过移除不重要的权重来减小模型大小并提高效率的技术,同时尽量保持模型性能。
8 0
深度学习笔记(九):神经网络剪枝(Neural Network Pruning)详细介绍
|
1天前
|
机器学习/深度学习 编解码 计算机视觉
深度学习笔记(十一):各种特征金字塔合集
这篇文章详细介绍了特征金字塔网络(FPN)及其变体PAN和BiFPN在深度学习目标检测中的应用,包括它们的结构、特点和代码实现。
5 0
|
2天前
|
机器学习/深度学习 数据可视化 Windows
深度学习笔记(七):如何用Mxnet来将神经网络可视化
这篇文章介绍了如何使用Mxnet框架来实现神经网络的可视化,包括环境依赖的安装、具体的代码实现以及运行结果的展示。
9 0
|
2天前
|
机器学习/深度学习 Python
深度学习笔记(六):如何运用梯度下降法来解决线性回归问题
这篇文章介绍了如何使用梯度下降法解决线性回归问题,包括梯度下降法的原理、线性回归的基本概念和具体的Python代码实现。
10 0
|
3天前
|
机器学习/深度学习 边缘计算 人工智能
探讨深度学习在图像识别中的应用及优化策略
【10月更文挑战第5天】探讨深度学习在图像识别中的应用及优化策略
14 1
|
8天前
|
机器学习/深度学习 人工智能 数据可视化
深度学习在图像识别中的应用与挑战
本文将深入探讨深度学习技术在图像识别领域的应用,并揭示其背后的原理和面临的挑战。我们将通过代码示例来展示如何利用深度学习进行图像识别,并讨论可能遇到的问题和解决方案。
31 3
|
3天前
|
机器学习/深度学习 存储 数据处理
深度学习在图像识别中的应用与挑战
【10月更文挑战第5天】 本文旨在探讨深度学习技术在图像识别领域的应用及其所面临的挑战。随着深度学习技术的飞速发展,其在图像识别中的应用日益广泛,不仅推动了相关技术的革新,也带来了新的挑战。本文首先介绍了深度学习的基本原理和常见模型,然后详细探讨了卷积神经网络(CNN)在图像识别中的具体应用,包括图像分类、目标检测等任务。接着,分析了当前深度学习在图像识别中面临的主要挑战,如数据标注问题、模型泛化能力、计算资源需求等。最后,提出了一些应对这些挑战的可能方向和策略。通过综合分析,本文希望为深度学习在图像识别领域的进一步研究和应用提供参考和启示。
|
3天前
|
机器学习/深度学习 人工智能 自然语言处理
深度学习在图像识别中的应用与挑战
【10月更文挑战第5天】本文将深入探讨深度学习技术在图像识别领域的应用和面临的挑战。我们将从基础的神经网络模型出发,逐步介绍卷积神经网络(CNN)的原理和结构,并通过代码示例展示其在图像分类任务中的实际应用。同时,我们也将讨论深度学习在图像识别中遇到的一些常见问题和解决方案,以及未来的发展方向。
14 4
|
1天前
|
机器学习/深度学习 人工智能 算法框架/工具
深度学习中的卷积神经网络(CNN)及其在图像识别中的应用
【10月更文挑战第7天】本文将深入探讨卷积神经网络(CNN)的基本原理,以及它如何在图像识别领域中大放异彩。我们将从CNN的核心组件出发,逐步解析其工作原理,并通过一个实际的代码示例,展示如何利用Python和深度学习框架实现一个简单的图像分类模型。文章旨在为初学者提供一个清晰的入门路径,同时为有经验的开发者提供一些深入理解的视角。