使用TensorFlow提供的slim模型来训练数据模型供iOS使用

简介: 使用slim模型来训练数据供移动端使用 1、  数据可以是slim提供的数据集或者是自己采集的图片 1.1、下载slim提供的数据集flowers 1.1.1、设置下载目录命令: DATA_DIR=/Users/javalong/Desktop/Test/output/flowers 1.


1、下载slim模型包

cd /Users/javalong/Download
git clone https://github.com/tensorflow/models/


2、  数据可以是slim提供的数据集或者是自己采集的图片


2.1、下载slim提供的数据集flowers

2.1.1、设置下载目录命令:

DATA_DIR=/Users/javalong/Desktop/Test/output/flowers


2.1.2、进入到slim模型目录命令:

cd /Users/javalong/Downloads/models-master/slim


2.1.3、下载数据集命令:

python3 download_and_convert_data.py \

    --dataset_name=flowers \

    --dataset_dir="${DATA_DIR}"


2.1.4、查看目录下的文件命令:

ls ${DATA_DIR}


得到:

flowers_train-00000-of-00005.tfrecord

...

flowers_train-00004-of-00005.tfrecord

flowers_validation-00000-of-00005.tfrecord

...

flowers_validation-00004-of-00005.tfrecord

labels.txt


2.2、我们可以看到下载slim提供的数据文件是tfrecord格式,所以我们要训练自己采集的图片,第一步先将图片转换成tfrecord格式。


2.2.1、将图片转换成TFRecord文件,需要安装的软件


pip3 install Pillow

pip3 install matplotlib


2.2.2、在/Users/javalong/Downloads/models-master/slim下创建一个fu_img_to_tfrecord.py文件。

如图:

a665c93bb752afbd2964b75354a376c14c42eea6


2.2.3、fu_img_to_tfrecord.py的内容为:


import os 
import os.path 
import tensorflow as tf 
from PIL import Image  
import matplotlib.pyplot as plt 
import sys
import pprint
pp = pprint.PrettyPrinter(indent = 2)

data_dir=sys.argv[1]
train_dir=sys.argv[2]
classes=[]
for dir in os.listdir(data_dir):
    path = os.path.join(data_dir, dir)
    if os.path.isdir(path):
        classes.append(dir)


train= tf.python_io.TFRecordWriter(train_dir+"/iss_train.tfrecord") 
test= tf.python_io.TFRecordWriter(train_dir+"/iss_test.tfrecord") 


def int64_feature(values):
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))

def image_to_tfexample(image_data, image_format, height, width, class_id):
    return tf.train.Example(features=tf.train.Features(feature={ 
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
    }))

def get_extension(path):
    return os.path.splitext(path)[1] 

class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self):
    # Initializes function that decodes RGB JPEG data.
    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)

  def read_image_dims(self, sess, image_data):
    image = self.decode_jpeg(sess, image_data)
    return image.shape[0], image.shape[1]
  
  def decode_jpeg(self, sess, image_data):
    image = sess.run(self._decode_jpeg,
                     feed_dict={self._decode_jpeg_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image

def write_label_file(labels_to_class_names, dataset_dir,
                     filename='lables.txt'):
  """Writes a file with the list of class names.

  Args:
    labels_to_class_names: A map of (integer) labels to class names.
    dataset_dir: The directory in which the labels file should be written.
    filename: The filename where the class names are written.
  """
  labels_filename = os.path.join(dataset_dir, filename)
  with tf.gfile.Open(labels_filename, 'w') as f:
    for label in labels_to_class_names:
      class_name = labels_to_class_names[label]
      f.write('%d:%s\n' % (label, class_name))

lable_file=train_dir+'/lable.txt'
lable_input=open(lable_file, 'w')

info_file=train_dir+'/meta_info.txt'
test_num=0;
train_num=0;

with tf.Graph().as_default():
    image_reader = ImageReader()
    with tf.Session('') as sess: 

        for index,name in enumerate(classes):
            lable_input.write('%d:%s\n' % (index, name))  
            class_path=data_dir+'/'+name+'/'
            for num, img_name in enumerate(os.listdir(class_path)): 
                img_path=class_path+img_name 
                
                format=get_extension(img_name)
                image_data = tf.gfile.FastGFile(img_path, 'rb').read()
                height, width = image_reader.read_image_dims(sess, image_data)
                example = image_to_tfexample(image_data, b'jpg', height, width, index)
                if num % 5 == 0:
                    test_num= test_num+1
                    #pass
                    #print img_path + " " + str(index) + " " + name
                    test.write(example.SerializeToString()) 
                else:
                    train_num=train_num+1
                    train.write(example.SerializeToString())
                    #print img_path + " " + str(index) + " " + name

train.close()
test.close()

info_input=open(info_file,'w')
info_input.write("train_num:"+str(train_num)+'\n')
info_input.write("test_num:"+str(test_num)+'\n')
info_input.close()

lable_input.close()



2.2.4、执行转换命令:

python3 /Users/javalong/Downloads/models-master/slim/fu_img_to_tfrecord.py /Users/javalong/Desktop/flowers /Users/javalong/Desktop/flower_record


注:

2.2.5/Users/javalong/Desktop/flowers是存放采集的图片,如图:

a9fb532a0a37402e2a2063e65e8518763175e0e4


2.2.6/Users/javalong/Desktop/flower_record是生成的tfrecord格式文件存放目录。最终生成的文件如图:

f2ff611499c9a63665c7f550c5804f95e8246afd


2.2.7使用/Users/javalong/Desktop/flowers目录的子目录名作为分类文本会存储到生成的label.txt中。如图:

e41d2c85d9b9be759fa8bca267c6ee00e0e272be


2.2.8fu_img_to_tfrecord.py功能实现参考/Users/javalong/Downloads/models-master/slim/datasets/download_and_convert_flowers.py文件


3、用预训练数据集inception_v3来训练数据集flowers

3.1、设置相应的目录:

DATASET_DIR=/Users/javalong/Desktop/Test/output/flowers

CHECKPOINT_PATH=/Users/javalong/Desktop/Test/output/inception/inception_v3.ckpt

TRAIN_DIR=/Users/javalong/Desktop/Test/output/tran


3.2、训练命令:

python3 train_image_classifier.py \

    --train_dir=${TRAIN_DIR} \

    --dataset_dir=${DATASET_DIR} \

    --dataset_name=flowers \

    --dataset_split_name=train \

    --model_name=inception_v3 \

    --checkpoint_path=${CHECKPOINT_PATH} \

    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \

    --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \

    --clone_on_cpu=true


4、生成.pb文件

4.1、在/Users/javalong/Downloads/models-master/slim下创建一个bbb.py文件。

如图:

c3fc804bf05dacac52d013da34c160cdeb47c057


4.2、bbb.py的内容为:


import os
import tensorflow as tf
import tensorflow.contrib.slim as slim

from nets import inception
from nets import inception_v1
from nets import inception_v3
from nets import nets_factory

from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from google.protobuf import text_format


checkpoint_path = tf.train.latest_checkpoint('/Users/javalong/Desktop/Test/output/tran')
with tf.Graph().as_default() as graph:
    input_tensor = tf.placeholder(tf.float32, shape=(None, 299, 299, 3), name='input_image')
    with tf.Session() as sess:
      #  with tf.variable_scope('model') as scope:
            with slim.arg_scope(inception.inception_v3_arg_scope()):
                logits, end_points = inception.inception_v3(input_tensor, num_classes=5, is_training=False)

    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    output_node_names = 'InceptionV3/Predictions/Reshape_1'
     
    input_graph_def = graph.as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names.split(","))
    with open('/Users/javalong/Desktop/Test/output/output_graph_nodes.txt', 'w') as f:
        f.write(text_format.MessageToString(output_graph_def)) 

    output_graph = '/Users/javalong/Desktop/Test/output/inception_v3_final.pb'
    with gfile.FastGFile(output_graph, 'wb') as f:
        f.write(output_graph_def.SerializeToString())




5优化模型并去掉iOS不支持的算子 


查考此篇文章


目录
相关文章
|
2月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
将Keras训练好的.hdf5模型转换为TensorFlow的.pb模型,然后再转换为TensorRT支持的.uff格式,并提供了转换代码和测试步骤。
100 3
深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式
|
2月前
|
并行计算 Shell TensorFlow
Tensorflow-GPU训练MTCNN出现错误-Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
在使用TensorFlow-GPU训练MTCNN时,如果遇到“Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED”错误,通常是由于TensorFlow、CUDA和cuDNN版本不兼容或显存分配问题导致的,可以通过安装匹配的版本或在代码中设置动态显存分配来解决。
53 1
Tensorflow-GPU训练MTCNN出现错误-Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
|
2月前
|
数据采集 TensorFlow 算法框架/工具
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
本教程详细介绍了如何使用TensorFlow 2.3训练自定义图像分类数据集,涵盖数据集收集、整理、划分及模型训练与测试全过程。提供完整代码示例及图形界面应用开发指导,适合初学者快速上手。[教程链接](https://www.bilibili.com/video/BV1rX4y1A7N8/),配套视频更易理解。
50 0
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
|
16天前
|
机器学习/深度学习 数据采集 数据可视化
TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤
本文介绍了 TensorFlow,一款由谷歌开发的开源深度学习框架,详细讲解了使用 TensorFlow 构建深度学习模型的步骤,包括数据准备、模型定义、损失函数与优化器选择、模型训练与评估、模型保存与部署,并展示了构建全连接神经网络的具体示例。此外,还探讨了 TensorFlow 的高级特性,如自动微分、模型可视化和分布式训练,以及其在未来的发展前景。
39 5
|
26天前
|
机器学习/深度学习 人工智能 算法
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
手写数字识别系统,使用Python作为主要开发语言,基于深度学习TensorFlow框架,搭建卷积神经网络算法。并通过对数据集进行训练,最后得到一个识别精度较高的模型。并基于Flask框架,开发网页端操作平台,实现用户上传一张图片识别其名称。
69 0
【手写数字识别】Python+深度学习+机器学习+人工智能+TensorFlow+算法模型
|
26天前
|
机器学习/深度学习 人工智能 算法
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
蔬菜识别系统,本系统使用Python作为主要编程语言,通过收集了8种常见的蔬菜图像数据集('土豆', '大白菜', '大葱', '莲藕', '菠菜', '西红柿', '韭菜', '黄瓜'),然后基于TensorFlow搭建卷积神经网络算法模型,通过多轮迭代训练最后得到一个识别精度较高的模型文件。在使用Django开发web网页端操作界面,实现用户上传一张蔬菜图片识别其名称。
73 0
基于深度学习的【蔬菜识别】系统实现~Python+人工智能+TensorFlow+算法模型
|
1月前
|
机器学习/深度学习 人工智能 算法
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
车辆车型识别,使用Python作为主要编程语言,通过收集多种车辆车型图像数据集,然后基于TensorFlow搭建卷积网络算法模型,并对数据集进行训练,最后得到一个识别精度较高的模型文件。再基于Django搭建web网页端操作界面,实现用户上传一张车辆图片识别其类型。
79 0
【车辆车型识别】Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+算法模型
|
3月前
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
114 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
|
2月前
|
机器学习/深度学习 移动开发 TensorFlow
深度学习之格式转换笔记(四):Keras(.h5)模型转化为TensorFlow(.pb)模型
本文介绍了如何使用Python脚本将Keras模型转换为TensorFlow的.pb格式模型,包括加载模型、重命名输出节点和量化等步骤,以便在TensorFlow中进行部署和推理。
107 0
|
4月前
|
API UED 开发者
如何在Uno Platform中轻松实现流畅动画效果——从基础到优化,全方位打造用户友好的动态交互体验!
【8月更文挑战第31天】在开发跨平台应用时,确保用户界面流畅且具吸引力至关重要。Uno Platform 作为多端统一的开发框架,不仅支持跨系统应用开发,还能通过优化实现流畅动画,增强用户体验。本文探讨了Uno Platform中实现流畅动画的多个方面,包括动画基础、性能优化、实践技巧及问题排查,帮助开发者掌握具体优化策略,提升应用质量与用户满意度。通过合理利用故事板、减少布局复杂性、使用硬件加速等技术,结合异步方法与预设缓存技巧,开发者能够创建美观且流畅的动画效果。
88 0