tensorflow2.0图片分类实战---对fashion-mnist数据集分类

简介: tensorflow2.0图片分类实战---对fashion-mnist数据集分类

前言


其实写这篇博客的想法主要还是记载一些tf2.0常用api的用法以及如何简单快速的利用tf.keras搭建一个神经网络


1.首先讲讲tf.keras


有了它我们可以很轻松的搭建自己想搭建的网络模型,就像拼积木一样,一层一层的网络叠加起来。但是深层的网络会出现梯度消失等等问题,所以只是能搭建一个网络模型,对于模型的效果还需要一些其他知识方法来优化。对于fashion-mnist数据集的介绍可以看看下面的链接Github上fashion-mnist的介绍


2.再说说一般对于图像分类问题常用的优化方法


  • 1.图像数据的归一化(标准化):加快网络收敛,具体原理可以想象成同心圆沿着梯度到达圆心最快,而不正规的图形沿着梯度到达中心会很曲折


image.png


  • 2.数据特征增强:链接
  • 3.网络的超参数搜索:得到最好的模型参数,主要是网格搜索、随机搜索、遗传算法、启发式搜索
  • 4.dropout、earlystopping,正则化等方法的应用:通过添加遗忘层,正则化以及早停来防止模型过拟合


3.实现代码以及结果部分


#先导入一些常用库,后续用到再增加
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import sklearn
import os
import sys
#看一下版本,确认是2.0
print(tf.__version__)
复制代码


image.png


#使用keras自带的模块导入数据,并且切分训练集、验证集、测试集,对训练数据进行标准化处理
fashion_mnist=keras.datasets.fashion_mnist
(x_train_all,y_train_all),(x_test,y_test)=fashion_mnist.load_data()
print(x_train_all.shape)
print(y_train_all.shape)
print(x_test.shape)
print(y_test.shape)
#切分训练集和验证集
x_train,x_valid=x_train_all[5000:],x_train_all[:5000]
y_train,y_valid=y_train_all[5000:],y_train_all[:5000]
print(x_train.shape)
print(y_train.shape)
print(x_valid.shape)
print(y_valid.shape)
#标准化
from sklearn.preprocessing import StandardScaler
scaler=StandardScaler()
x_train_scaled=scaler.fit_transform(x_train.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
x_valid_scaled=scaler.fit_transform(x_valid.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
x_test_scaled=scaler.fit_transform(x_test.astype(np.float32).reshape(-1,1)).reshape(-1,28,28)
复制代码
#可视化一下图片以及对应的标签
#展示多张图片
def show_imgs(n_rows,n_cols,x_data,y_data,class_names):
    assert len(x_data)==len(y_data)#判断输入数据的信息是否对应一致
    assert n_rows*n_cols<=len(x_data)#保证不会出现数据量不够
    plt.figure(figsize=(n_cols*2,n_rows*1.6))
    for row in range(n_rows):
        for col in range(n_cols):
            index=n_cols*row+col   #得到当前展示图片的下标
            plt.subplot(n_rows,n_cols,index+1)
            plt.imshow(x_data[index],cmap="binary",interpolation="nearest")
            plt.axis("off")
            plt.title(class_names[y_data[index]])
    plt.show()
class_names=['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
show_imgs(5,5,x_train,y_train,class_names)
复制代码


image.png

#搭建网络模型
model=keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=[28,28]))
model.add(keras.layers.Dense(300,activation="relu"))
model.add(keras.layers.Dense(100,activation="relu"))
model.add(keras.layers.Dense(10,activation="softmax"))
model.compile(loss="sparse_categorical_crossentropy",optimizer="adam",metrics=["acc"])
model.summary()
复制代码


image.png


这里网络信息中params中的数字怎么来的呢? y=wx+b  然后根据矩阵相乘的规则从(None,784)到(None,300)中间的矩阵就是(784,300)然后偏置项b的大小是300,所以784300+300=235500,这是个小细节稍微提一下。


#训练,并且保存最好的模型、训练的记录以及使用早停防止过拟合
import datetime
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join('logs', current_time)
output_model=os.path.join(logdir,"fashionmnist_model.h5")
callbacks=[
    keras.callbacks.TensorBoard(log_dir=logdir),
    keras.callbacks.ModelCheckpoint(output_model,save_best_only=True),
    keras.callbacks.EarlyStopping(patience=5,min_delta=1e-3)
          ]
history=model.fit(x_train_scaled,y_train,epochs=30,validation_data=(x_valid_scaled,y_valid),callbacks=callbacks)
复制代码


image.png

之前我用自己命名的文件夹使用TensorBoard和ModelCheckpoint运行会出错,搜了一下好像是windows上的bug,上面的这是一种解决方法,然后打开tensorboard看一下。


image.png


最好的模型也保存为h5文件,方便调用


def plot_learning_curves(history):
    pd.DataFrame(history.history).plot(figsize=(8,5))
    plt.grid()
    plt.gca().set_ylim(0,1)
    plt.show()
plot_learning_curves(history)
复制代码


这是自己绘制每次训练的变化情况,和上面的差不多


image.png


#最后在测试集上的准确率
loss,acc=model.evaluate(x_test_scaled,y_test,verbose=0)
print("在测试集上的损失为:",loss)
print("在测试集上的准确率为:",acc)
复制代码


image.png

#得到测试集上的预测标签,可视化和真实标签的区别
y_pred=model.predict(x_test_scaled)
predict = np.argmax(y_pred,axis=1) 
show_imgs(3,5,x_test,predict,class_names)
show_imgs(3,5,x_test,y_test,class_names)
复制代码


预测的结果

image.png

真实的结果


image.png


4.总结:


看了上面的例子,使用tf.keras搭建模型写法就是


model=keras.models.Sequential()
model.add(...)
model.add(...)
...
model.compile(...)
model.fit(...)
#当然也可以写成
model=keras.models.Sequential([
  ...
  ...
  ...
])
#这两者差别不大
#还有函数式的写法
inputs=...
hidden1=...(inputs)
....
#子类的写法
class ...:
  ...
复制代码


不过对于模型中的参数,比如损失函数的选择("sparse_categorical_crossentropy"与"categorical_crossentropy" 或者"binary_crossentropy")什么时候需要用到哪种损失函数最适合、每一层网络中的激活函数的选择、优化器的选择……都需要了解其中的含义才能在适当的场合使用,这里我没有给出使用超参数搜索得到最优模型参数的例子,下次应该会写一个关于超参数搜索的例子。

目录
相关文章
|
3月前
|
数据采集 TensorFlow 算法框架/工具
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
本教程详细介绍了如何使用TensorFlow 2.3训练自定义图像分类数据集,涵盖数据集收集、整理、划分及模型训练与测试全过程。提供完整代码示例及图形界面应用开发指导,适合初学者快速上手。[教程链接](https://www.bilibili.com/video/BV1rX4y1A7N8/),配套视频更易理解。
78 0
【大作业-03】手把手教你用tensorflow2.3训练自己的分类数据集
|
2月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
利用Python和TensorFlow构建简单神经网络进行图像分类
利用Python和TensorFlow构建简单神经网络进行图像分类
75 3
|
2月前
|
机器学习/深度学习 TensorFlow API
机器学习实战:TensorFlow在图像识别中的应用探索
【10月更文挑战第28天】随着深度学习技术的发展,图像识别取得了显著进步。TensorFlow作为Google开源的机器学习框架,凭借其强大的功能和灵活的API,在图像识别任务中广泛应用。本文通过实战案例,探讨TensorFlow在图像识别中的优势与挑战,展示如何使用TensorFlow构建和训练卷积神经网络(CNN),并评估模型的性能。尽管面临学习曲线和资源消耗等挑战,TensorFlow仍展现出广阔的应用前景。
86 5
|
2月前
|
机器学习/深度学习 人工智能 TensorFlow
基于TensorFlow的深度学习模型训练与优化实战
基于TensorFlow的深度学习模型训练与优化实战
125 0
|
4月前
|
机器学习/深度学习 数据挖掘 TensorFlow
解锁Python数据分析新技能,TensorFlow&PyTorch双引擎驱动深度学习实战盛宴
在数据驱动时代,Python凭借简洁的语法和强大的库支持,成为数据分析与机器学习的首选语言。Pandas和NumPy是Python数据分析的基础,前者提供高效的数据处理工具,后者则支持科学计算。TensorFlow与PyTorch作为深度学习领域的两大框架,助力数据科学家构建复杂神经网络,挖掘数据深层价值。通过Python打下的坚实基础,结合TensorFlow和PyTorch的强大功能,我们能在数据科学领域探索无限可能,解决复杂问题并推动科研进步。
81 0
|
5月前
|
API UED 开发者
如何在Uno Platform中轻松实现流畅动画效果——从基础到优化,全方位打造用户友好的动态交互体验!
【8月更文挑战第31天】在开发跨平台应用时,确保用户界面流畅且具吸引力至关重要。Uno Platform 作为多端统一的开发框架,不仅支持跨系统应用开发,还能通过优化实现流畅动画,增强用户体验。本文探讨了Uno Platform中实现流畅动画的多个方面,包括动画基础、性能优化、实践技巧及问题排查,帮助开发者掌握具体优化策略,提升应用质量与用户满意度。通过合理利用故事板、减少布局复杂性、使用硬件加速等技术,结合异步方法与预设缓存技巧,开发者能够创建美观且流畅的动画效果。
100 0
|
5月前
|
UED 存储 数据管理
深度解析 Uno Platform 离线状态处理技巧:从网络检测到本地存储同步,全方位提升跨平台应用在无网环境下的用户体验与数据管理策略
【8月更文挑战第31天】处理离线状态下的用户体验是现代应用开发的关键。本文通过在线笔记应用案例,介绍如何使用 Uno Platform 优雅地应对离线状态。首先,利用 `NetworkInformation` 类检测网络状态;其次,使用 SQLite 实现离线存储;然后,在网络恢复时同步数据;最后,通过 UI 反馈提升用户体验。
127 0
|
5月前
|
安全 Apache 数据安全/隐私保护
你的Wicket应用安全吗?揭秘在Apache Wicket中实现坚不可摧的安全认证策略
【8月更文挑战第31天】在当前的网络环境中,安全性是任何应用程序的关键考量。Apache Wicket 是一个强大的 Java Web 框架,提供了丰富的工具和组件,帮助开发者构建安全的 Web 应用程序。本文介绍了如何在 Wicket 中实现安全认证,
56 0
|
5月前
|
机器学习/深度学习 数据采集 TensorFlow
从零到精通:TensorFlow与卷积神经网络(CNN)助你成为图像识别高手的终极指南——深入浅出教你搭建首个猫狗分类器,附带实战代码与训练技巧揭秘
【8月更文挑战第31天】本文通过杂文形式介绍了如何利用 TensorFlow 和卷积神经网络(CNN)构建图像识别系统,详细演示了从数据准备、模型构建到训练与评估的全过程。通过具体示例代码,展示了使用 Keras API 训练猫狗分类器的步骤,旨在帮助读者掌握图像识别的核心技术。此外,还探讨了图像识别在物体检测、语义分割等领域的广泛应用前景。
53 0
|
1月前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
286 55

热门文章

最新文章

相关实验场景

更多