Paddle 环境中 使用LeNet在MNIST数据集实现图像分类

简介: 测试了在AI Stuio中 使用LeNet在MNIST数据集实现图像分类 示例。基于可以搭建其他网络程序。

f9fa8d8305144ac899d46210f3f815f7.png

简 介: 测试了在AI Stuio中 使用LeNet在MNIST数据集实现图像分类 示例。基于可以搭建其他网络程序。


关键词 MNISTPaddleLeNet
  • 作者: PaddlePaddle
  • 日期: 2021.12
  • 摘要: 本示例教程演示如何在MNIST数据集上用LeNet进行图像分类。

 

§

01 境配置


  教程基于Paddle 2.2 编写,如果你的环境不是本版本,请先参考官网安装 Paddle 2.2。

import paddle
print(paddle.__version__)
2.2.1

 

§

02 据加载


  写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为01。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist

一、加载mnist数据集合

  我们使用飞桨框架自带的 paddle.vision.datasets.MNIST 完成mnist数据集的加载。

from paddle.vision.transforms import Compose, Normalize

transform = Compose([Normalize(mean=[127.5],
                               std=[127.5],
                               data_format='CHW')])
print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
download training data and load training data
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-images-idx3-ubyte.gz 
Begin to download
item 8/8 [============================>.] - ETA: 0s - 4ms/item

Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-labels-idx1-ubyte.gz 
Begin to download

Download finished
item  95/403 [======>.......................] - ETA: 0s - 2ms/item
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz 
Begin to download
item 2/2 [===========================>..] - ETA: 0s - 2ms/item

Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-labels-idx1-ubyte.gz 
Begin to download

Download finished
load finished

二、查看数据图像

  取训练集中的一条数据看一下。

import numpy as np
import matplotlib.pyplot as plt
train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 label is: ' + str(train_label_0))
train_data0 label is: [5]
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/numpy/lib/type_check.py:546: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  'a.item() instead', DeprecationWarning, stacklevel=1)

23be60d56ea04162b1c1f091349e1d37.png

▲ 图2.2.1 训练结合中的图片

 

§

03 立网络


一、利用SubClass组网

  paddle.nn下的API,如Conv2DMaxPool2DLinear完成LeNet的构建。

import paddle
import paddle.nn.functional as F
class LeNet(paddle.nn.Layer):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
        self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2,  stride=2)
        self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
        self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
        self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
        self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.max_pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.max_pool2(x)
        x = paddle.flatten(x, start_axis=1,stop_axis=-1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        return x

二、网络结构可视化

print(model.summary())
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Conv2D-11     [[64, 1, 28, 28]]     [64, 6, 28, 28]          156      
 MaxPool2D-11    [[64, 6, 28, 28]]     [64, 6, 14, 14]           0       
   Conv2D-12     [[64, 6, 14, 14]]     [64, 16, 10, 10]        2,416     
 MaxPool2D-12    [[64, 16, 10, 10]]     [64, 16, 5, 5]           0       
   Linear-16        [[64, 400]]           [64, 120]           48,120     
   Linear-17        [[64, 120]]            [64, 84]           10,164     
   Linear-18         [[64, 84]]            [64, 10]             850      
===========================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
Input size (MB): 0.19
Forward/backward pass size (MB): 3.95
Params size (MB): 0.24
Estimated Total Size (MB): 4.38

{'total_params': 61706, 'trainable_params': 61706}

 

§

04 络训练


一、基于高层API

  通过paddle提供的Model 构建实例,使用封装好的训练与测试接口,快速完成模型训练与测试。

72bc6accba78d8e48927a258d83a2534.png

▲ 图4.1.1 Paddle 高层API

1、使用Model.fit完成模型训练

  方式1:基于高层API,完成模型的训练与预测

from paddle.metric import Accuracy
model = paddle.Model(LeNet())   # 用Model封装模型
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

model.prepare(
    optim,
    paddle.nn.CrossEntropyLoss(),
    Accuracy()
    )
model.fit(train_dataset,
        epochs=2,
        batch_size=64,
        verbose=1
        )

2af3dfd79c5d46f3b9293296508d05df.gif

▲ 图4.1.1 训练两个周期

  运行时长:42.963秒结束时间:2021-12-12 23:08:16

2、使用Model.evaluate进行模型预测

model.evaluate(test_dataset, batch_size=64, verbose=1)

49c2be504ffd4381abd390e9af1f5980.png

▲ 图4.1.2 预测模型

{'loss': [0.0013720806], 'acc': 0.9848}

  以上就是方式一,可以快速、高效的完成网络模型训练与预测。

  上述训练过程是在普通CPU的模式下进行的。

3、模型的存储与获取

(1)存储模型参数

model.save('./work/lenet')

(2)建立新的模型

newmodel = paddle.Model(LeNet())

newmodel.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())
newmodel.evaluate(test_dataset, batch_size=64, verbose=1)

  如果没有调用存储的信息,评估的结果为:

Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 4.2389 - acc: 0.0717 - 8ms/step          
Eval samples: 10000
{'loss': [4.238896], 'acc': 0.0717}

  调用参数之后,评估结果:

newmodel.load('./work/lenet')
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 157/157 [==============================] - loss: 2.5223e-04 - acc: 0.9883 - 8ms/step         
Eval samples: 10000
{'loss': [0.00025222846], 'acc': 0.9883}
运行时长:1.246秒结束时间:2021-12-13 23:25:45

二、基于基础API

  方式2:基于基础API,完成模型的训练与预测

1、模型训练

  组网后,开始对模型进行训练,先构建train_loader,加载训练数据,然后定义train函数,设置好损失函数后,按batch加载数据,完成模型的训练。

import paddle.nn.functional as F
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)
def train(model):
    model.train()
    epochs = 2
    optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
    # 用Adam作为优化函数
    for epoch in range(epochs):
        for batch_id, data in enumerate(train_loader()):
            x_data = data[0]
            y_data = data[1]
            predicts = model(x_data)
            loss = F.cross_entropy(predicts, y_data)
            # 计算损失
            acc = paddle.metric.accuracy(predicts, y_data)
            loss.backward()
            if batch_id % 300 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
            optim.step()
            optim.clear_grad()
model = LeNet()
train(model)

0d463aa5a72d46e48dbfa60669fd2c72.png

▲ 图4.2.1 训练过程和花费时间

  以上就是方式二,通过底层API,可以清楚的看到训练和测试中的每一步过程。但是,这种方式比较复杂。因此,我们提供了训练方式一,使用高层API来完成模型的训练与预测。对比底层API,高层API能够更加快速、高效的完成模型的训练与测试。

二、模型保存与重载

对于Layer型的模型进行保存,使用 paddle.save、paddle.load。进行

20210601234009436.png

▲ 图4.2.2 模型保存与重载

20210601234204836.png

▲ 图4.2.3 保存的方法

(1)save

参数存储时,先获取目标对象(Layer或者Optimzier)的state_dict,然后将state_dict存储至磁盘,示例如下(接前述示例):

paddle.save(layer.state_dict(), "linear_net.pdparams")
paddle.save(adam.state_dict(), "adam.pdopt")

(2)load

参数载入时,先从磁盘载入保存的state_dict,然后通过set_state_dict方法配置到目标对象中,示例如下(接前述示例):

layer_state_dict = paddle.load("linear_net.pdparams")
opt_state_dict = paddle.load("adam.pdopt")

layer.set_state_dict(layer_state_dict)
adam.set_state_dict(opt_state_dict)

三、切换环境训练

1、切换到至尊GPU环境

4041d5f65a614cd5824c9aeb00c2c63b.png

▲ 图4.2.2 切换到至尊版本环境训练

2、基于高级API训练

f4e37c93072d425ebbc0b64b98857b7a.png

▲ 图4.3.2 切换到至尊版本环境训练

运行时长:17.55秒结束时间:2021-12-12 23:18:21

3、基于基本API训练

a76c0ac9fdbe4f759aa03c1d4d58da6d.png

▲ 图4.3.3 训练过程

  运行时长:16.917秒结束时间:2021-12-12 23:19:40

e0d30f5f322240b880405e96d9f92c2c.png

▲ 图4.3.4 在高级版本下训练

9ab60177058049ce8ebeb8c9c9fcd5e8.png

▲ 图4.3.5 在高级版本下训练

 

  结 ※


  试了在AI Stuio中 使用LeNet在MNIST数据集实现图像分类 示例。基于可以搭建其他网络程序。

■ 相关文献链接:

使用LeNet在MNIST数据集实现图像分类

● 相关图表链接:

相关文章
【分享】Groovy时间戳转日期
在集成自动化中 通过Groovy处理时间戳,格式化日期输出。
1152 0
Mac 快速查找快捷键command+f失效解决办法
版权声明:本文为 testcs_dn(微wx笑) 原创文章,非商用自由转载-保持署名-注明出处,谢谢。 https://blog.csdn.net/testcs_dn/article/details/81214817 最近 Chrome 经常的遇到 Command + F 快捷键失效的问题,真是日了狗了。
3946 0
|
7月前
|
人工智能 自然语言处理 API
推荐几个常用免费的文本转语音工具
本文推荐了几款免费的文本转语音工具,包括功能全面的AI易视频、支持多语言的Google TTS、操作便捷的Natural Reader、离线使用的Balabolka以及轻量级的Speech2Go。其中AI易视频特别适合小说转语音,可智能分配角色音色,打造广播剧般的听觉体验。这些工具各具特色,能满足不同场景需求,助力内容创作更高效。
1803 5
|
6月前
|
Oracle 关系型数据库 数据库
Activiti 7建表语句及注释
本文提供了Activiti工作流引擎的数据库表结构,适用于Oracle和DM数据库。包含运行时与历史数据表的设计及字段注释,涵盖流程定义、实例、任务、变量、事件监听、附件、意见等核心功能模块。通过这些表结构,可以全面管理流程生命周期中的各类数据。
|
9月前
|
算法 Java
算法系列之数据结构-Huffman树
Huffman树(哈夫曼树)又称最优二叉树,是一种带权路径长度最短的二叉树,常用于信息传输、数据压缩等方面。它的构造基于字符出现的频率,通过将频率较低的字符组合在一起,最终形成一棵树。在Huffman树中,每个叶节点代表一个字符,而每个字符的编码则是从根节点到叶节点的路径所对应的二进制序列。
266 3
 算法系列之数据结构-Huffman树
|
算法 决策智能
基于禁忌搜索算法的VRP问题求解matlab仿真,带GUI界面,可设置参数
该程序基于禁忌搜索算法求解车辆路径问题(VRP),使用MATLAB2022a版本实现,并带有GUI界面。用户可通过界面设置参数并查看结果。禁忌搜索算法通过迭代改进当前解,并利用记忆机制避免陷入局部最优。程序包含初始化、定义邻域结构、设置禁忌列表等步骤,最终输出最优路径和相关数据图表。
|
存储 人工智能 供应链
光量子计算:计算速度的新突破
【9月更文挑战第17天】光量子计算利用光子的量子特性,突破传统计算瓶颈,展现强大信息处理能力。本文阐述了光量子计算原理,聚焦“九章三号”新进展:255光子高斯玻色取样,性能超越现有超级计算机亿亿倍。同时,展望其在优化问题解决、量子模拟、加密技术革新及人工智能加速上的应用前景,并讨论面临的挑战与未来技术发展的无限可能。
|
JavaScript
一文搞懂Vue3中slot插槽的使用!
前言 使用 Vue 的小伙伴相信你一定使用过插槽,如果你没有用过,那说明你的项目可能不是特别复杂。插槽(slot)可以说在一个 Vue 项目里面处处都有它的身影,比如我们使用一些 UI 组件库的时候,我们通常可以使用插槽来自定义我们的内容。 Vue3 已经推出很久了,也有越来越多的项目开始转向 Vue3 了,那么如果你对 Vue3 中的插槽还不熟悉,那么很有必要跟着本篇文章学习一下了!
2170 0
一文搞懂Vue3中slot插槽的使用!
|
XML 存储 网络协议
在Linux中,如何使用Wireshark进行网络协议分析?
在Linux中,如何使用Wireshark进行网络协议分析?
|
监控 负载均衡 应用服务中间件
Keepalive 解决nginx 的高可用问题
Keepalive 解决nginx 的高可用问题