tebsorflow2.0 多输出模型实例

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
模型训练 PAI-DLC,5000CU*H 3个月
交互式建模 PAI-DSW,每月250计算时 3个月
简介: 1. 简单介绍2. 加载相关数据包2.1 图片的路径的配置2.2 读取图片3. 图片预处理4. 训练阶段4.1 设置验证集与数据集4.2 构建模型并训练5. 模型评估

1. 简单介绍

本文的应用场景多输入问题,采用的数据集有,'black_jeans','black_shoes','blue_dress','blue_jeans','blue_shirt','red_dress','red_shirt'七个类别,我们将根据颜色和衣服类型进行两类输出。

2. 加载相关数据包

import tensorflow as tf
print('Tensorflow version: {}'.format(tf.__version__))
from tensorflow import keras
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pathlib
import os
import random
Tensorflow version: 2.0.0
• 1

2.1 图片的路径的配置

#配置数据集路径
path_root = os.path.realpath(".")
data_dir = pathlib.Path(path_root)
#数量构成
image_count = len(list(data_dir.glob('dataset/*/*.jpg')))
label_names = sorted(item.name for item in data_dir.glob("dataset/*"))
#图片路径
all_image_path = list(list(data_dir.glob('dataset/*/*.jpg')))    
#将路径打乱
all_image_path = [str(path) for path in all_image_path]
random.shuffle(all_image_path)
color_label_names = set(name.split('_')[0] for name in label_names)
item_label_names = set(name.split('_')[1] for name in label_names)
color_lable_to_index = dict((name,index) for index,name in enumerate(color_label_names))
item_lable_to_index = dict((name,index) for index,name in enumerate(item_label_names))
color_label = [color_lable_to_index[pathlib.Path(path).parent.name.split("_")[0]] for path in all_image_path]
item_label = [item_lable_to_index[pathlib.Path(path).parent.name.split("_")[1]] for path in all_image_path]

2.2 读取图片

def preprocess_image(image):
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [224, 224])
    image = tf.cast(image, tf.float32)
    image /= 255.0  # normalize to [0,1] range
    image = 2*image-1
    return image
#加载图片
def load_and_preprocess_image(path):
    image = tf.io.read_file(path)
    return preprocess_image(image)
#举例
image_path = all_image_path[11]
color = color_label[11]
item = item_label[11]
plt.imshow((load_and_preprocess_image(image_path)+1)/2)
plt.grid(False)
##plt.xlabel(caption_image(image_path))
plt.title(str(list(color_lable_to_index.keys())[color])+"_"+str(list(item_lable_to_index.keys())[item]).title())
plt.axis("off")
plt.show()

3. 图片预处理

在这一部分我们采用from_tensor_slices的方法对图片数据集进行构建,对比tf1.x版本采用队列形式读取数据,这一种方法比较简单切易于理解。并构建(图片,标签)对数据集。

#%%构建一个tf.data.Dataset
#一个图片数据集构建 tf.data.Dataset 最简单的方法就是使用 from_tensor_slices 方法。
#将字符串数组切片,得到一个字符串数据集:
path_ds =  tf.data.Dataset.from_tensor_slices(all_image_path)
print(path_ds)
#现在创建一个新的数据集,通过在路径数据集上映射 preprocess_image 来动态加载和格式化图片。
AUTOTUNE = tf.data.experimental.AUTOTUNE
image_ds = path_ds.map(load_and_preprocess_image,num_parallel_calls=AUTOTUNE)
lable_ds = tf.data.Dataset.from_tensor_slices((color_label, item_label))
for color,item in lable_ds.take(5):
    print(str(list(color_lable_to_index.keys())[color])+"_"+str(list(item_lable_to_index.keys())[item]))
#%%构建一个(图片,标签)对数据集
#因为这些数据集顺序相同,可以将他们打包起来
image_label_ds = tf.data.Dataset.zip((image_ds,lable_ds))
print(image_label_ds)
<TensorSliceDataset shapes: (), types: tf.string>
blue_dress
blue_dress
blue_shirt
red_dress
blue_shirt
<ZipDataset shapes: ((224, 224, 3), ((), ())), types: (tf.float32, (tf.int32, tf.int32))>

4. 训练阶段

4.1 设置验证集与数据集

#%%设置训练数据和验证集数据的大小
test_count = int(image_count*0.2)
train_count = image_count - test_count
print(test_count,train_count)
#跳过test_count个
train_dataset = image_label_ds.skip(test_count)
test_dataset = image_label_ds.take(test_count)
batch_size = 16
# 设置一个和数据集大小一致的 shuffle buffer size(随机缓冲区大小)以保证数据被充分打乱。
train_ds = train_dataset.shuffle(buffer_size=train_count).repeat().batch(batch_size)
test_ds = test_dataset.batch(batch_size)
419 1677
• 1

4.2 构建模型并训练

mobile_net = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3), 
                                               include_top=False,
                                               weights='imagenet')
mobile_net.trianable = False
inputs = tf.keras.Input(shape=(224, 224, 3))
x = mobile_net(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x1 = tf.keras.layers.Dense(1024, activation='relu')(x)
out_color = tf.keras.layers.Dense(len(color_label_names), 
                                  activation='softmax',
                                  name='out_color')(x1)
x2 = tf.keras.layers.Dense(1024, activation='relu')(x)
out_item = tf.keras.layers.Dense(len(item_label_names), 
                                 activation='softmax',
                                 name='out_item')(x2)
model = tf.keras.Model(inputs=inputs,
                       outputs=[out_color, out_item])
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
mobilenetv2_1.00_224 (Model)    (None, 7, 7, 1280)   2257984     input_2[0][0]                    
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 1280)         0           mobilenetv2_1.00_224[1][0]       
__________________________________________________________________________________________________
dense (Dense)                   (None, 1024)         1311744     global_average_pooling2d[0][0]   
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1024)         1311744     global_average_pooling2d[0][0]   
__________________________________________________________________________________________________
out_color (Dense)               (None, 3)            3075        dense[0][0]                      
__________________________________________________________________________________________________
out_item (Dense)                (None, 4)            4100        dense_1[0][0]                    
==================================================================================================
Total params: 4,888,647
Trainable params: 4,854,535
Non-trainable params: 34,112
__________________________________________________________________________________________________
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss={'out_color':'sparse_categorical_crossentropy',
                    'out_item':'sparse_categorical_crossentropy'},
              metrics=['acc']
)
steps_per_eooch = train_count//batch_size
validation_steps = test_count//batch_size
history = model.fit(train_ds,epochs=3,steps_per_epoch=steps_per_eooch,validation_data=test_ds,validation_steps=validation_steps)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss={'out_color':'sparse_categorical_crossentropy',
                    'out_item':'sparse_categorical_crossentropy'},
              metrics=['acc']
)
steps_per_eooch = train_count//batch_size
validation_steps = test_count//batch_size
history = model.fit(train_ds,epochs=3,steps_per_epoch=steps_per_eooch,validation_data=test_ds,validation_steps=validation_steps)
Train for 104 steps, validate for 26 steps
Epoch 1/3
104/104 [==============================] - 103s 991ms/step - loss: 0.4240 - out_color_loss: 0.2223 - out_item_loss: 0.2017 - out_color_acc: 0.9123 - out_item_acc: 0.9399 - val_loss: 0.1177 - val_out_color_loss: 0.0892 - val_out_item_loss: 0.0284 - val_out_color_acc: 0.9688 - val_out_item_acc: 0.9880
Epoch 2/3
104/104 [==============================] - 86s 825ms/step - loss: 0.0716 - out_color_loss: 0.0404 - out_item_loss: 0.0312 - out_color_acc: 0.9838 - out_item_acc: 0.9916 - val_loss: 0.0842 - val_out_color_loss: 0.0412 - val_out_item_loss: 0.0430 - val_out_color_acc: 0.9856 - val_out_item_acc: 0.9880
Epoch 3/3
104/104 [==============================] - 87s 833ms/step - loss: 0.0416 - out_color_loss: 0.0170 - out_item_loss: 0.0246 - out_color_acc: 0.9952 - out_item_acc: 0.9922 - val_loss: 0.0546 - val_out_color_loss: 0.0309 - val_out_item_loss: 0.0238 - val_out_color_acc: 0.9904 - val_out_item_acc: 0.9904

5. 模型评估

model.evaluate(test_ds)
27/27 [==============================] - 9s 347ms/step - loss: 0.0526 - out_color_loss: 0.0297 - out_item_loss: 0.0229 - out_color_acc: 0.9905 - out_item_acc: 0.9905
[0.052628928653171494, 0.02974791, 0.022881018, 0.9904535, 0.9904535]

我们在网上找一个图片,检验其准确性

str(list(data_dir.glob('dataset/*.jpg'))[0])
my_image = load_and_preprocess_image(str(list(data_dir.glob('dataset/*.jpg'))[0]))
my_image = tf.expand_dims(my_image, 0)
pred = model.predict(my_image)
plt.grid(False)
plt.imshow((tf.squeeze(my_image,axis=0)+1)/2)
plt.title(str(list(color_lable_to_index.keys())[np.argmax(pred[0])])+"_"+str(list(item_lable_to_index.keys())[np.argmax(pred[1])]).title())
plt.axis("off")
plt.show()

相关文章
|
7月前
|
自然语言处理
在ModelScope中,你可以通过设置模型的参数来控制输出的阈值
在ModelScope中,你可以通过设置模型的参数来控制输出的阈值
184 1
|
机器学习/深度学习 PyTorch TensorFlow
TensorRT 模型加速——输入、输出、部署流程
本文首先简要介绍 Tensor RT 的输入、输出以及部署流程,了解 Tensor RT 在部署模型中起到的作用。然后介绍 Tensor RT 模型导入流程,针对不同的深度学习框架,使用不同的方法导入模型。
1319 1
|
2月前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
302 2
|
4月前
|
JavaScript 前端开发 开发者
数据输出方法
【8月更文挑战第30天】
46 3
|
4月前
|
消息中间件 网络协议 JavaScript
函数计算产品使用问题之删除应用重建后,如何快速生成之前的模型和参数
函数计算产品作为一种事件驱动的全托管计算服务,让用户能够专注于业务逻辑的编写,而无需关心底层服务器的管理与运维。你可以有效地利用函数计算产品来支撑各类应用场景,从简单的数据处理到复杂的业务逻辑,实现快速、高效、低成本的云上部署与运维。以下是一些关于使用函数计算产品的合集和要点,帮助你更好地理解和应用这一服务。
|
5月前
|
算法
创建一个训练函数
【7月更文挑战第22天】创建一个训练函数。
35 4
|
5月前
|
人工智能 监控 Serverless
函数计算产品使用问题之sdXL 1.0模型启动无效,该怎么办
实时计算Flink版作为一种强大的流处理和批处理统一的计算框架,广泛应用于各种需要实时数据处理和分析的场景。实时计算Flink版通常结合SQL接口、DataStream API、以及与上下游数据源和存储系统的丰富连接器,提供了一套全面的解决方案,以应对各种实时计算需求。其低延迟、高吞吐、容错性强的特点,使其成为众多企业和组织实时数据处理首选的技术平台。以下是实时计算Flink版的一些典型使用合集。
YOLOv8打印模型结构配置信息并查看网络模型详细参数:参数量、计算量(GFLOPS)
YOLOv8打印模型结构配置信息并查看网络模型详细参数:参数量、计算量(GFLOPS)
|
7月前
|
机器学习/深度学习 存储 编解码
了解FastSam:一个通用分割模型(草记)(1)
一、FastSam下载与体验 1 问题记录 似乎从网页上下载压缩包,会比使用git clone要方便很多。 1 CLIP是什么?
402 0
|
7月前
|
机器学习/深度学习 并行计算 计算机视觉
了解FastSam:一个通用分割模型(草记)(2)
2 Sam相关项目 阅读:Segment Anything(sam)项目整理汇总 新鲜名词:点云分割, 有趣的项目:
299 0

热门文章

最新文章