使用pytorch自己构建网络模型实战

简介: 使用pytorch自己构建网络模型实战

写在前面

   前段时间在Git上下载了yolov5的代码,经过调试,最后运行成功。但是发现对网络训练的步骤其实很不熟悉,于是乎最近看了看基于pytorch的深度学习——通过学习,对pytorch的框架有了较清晰的认识,也可以自己来构建一些模型来进行训练。如果你也发现自己只知道在Git上克隆别人的代码,但是自己对程序的结构不了解,那么下面的内容可能会帮到你!!!


    这部分内容主要是根据B站视频总结而来,视频中给出了pytorch从安装到最后训练模型的完整教程,本篇文章主要总结神经网络的完整的模型训练套路,希望通过本篇文章可以让你对网络训练步骤有一个清晰的认识。


    本次内容用到的数据集是CIFAR10,使用这个数据的原因是这个数据比较轻量,基本上所有的电脑都可以跑。CIFAR10数据集里是一些32X32大小的图片,这些图片都有一个自己所属的类别(如airplane、cat等),如下图所示:image.png

    注意:这个数据集不需另外要从网页下载,程序中可以调整代码参数进行下载


 我们先来了解一下我们需要进行的工作及实现的功能:我们首先需要下载数据集,然后通过数据来训练模型,并在测试集上进行测试,这时候我们可以保存我们训练好的模型。最后通过我们训练的模型来判断一些图片的类别(从网络上下载一些图片,判断它是猫是狗或是其他的类型【当然这个数据集只有10种类型,如上图所示的10种】)


    下面我们就来一步步的介绍!!!【代码我分流程分部分介绍,完整代码放在文末自取】

完整网络模型训练步骤

1、准备数据集

   很显然,没有数据一切都是空谈,那么第一步就是准备我们需要的数据集CIFAR10。

#1、准备数据集train_dataset=torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(), download=Ture)
test_dataset=torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(), download=Ture)

    第一个参数“./data”是指定下载数据集保存的位置,第二个参数train=True/Flase是指下载的数据是训练集数据还是测试集数据【True表示训练集,Flase表示测试集】,第三个参数是图片的一个转化,要将图片格式转化为tensor类型,第四个参数download为True表示你没有这个数据,这时候会自动下载数据,为Flase表示有这个数据,不会再进行下载【注意:这个参数设置成True且你有数据集,那同样不会进行数据下载,故这个参数一直设置成True就好了】。


    我们可以打印数据集的长度来看一下这个数据集的大小,可以发现训练集有5000张图片,测试集有1000张图片。

train_dataset_size=len(train_dataset)
test_dataset_size=len(test_dataset)
print("train_dataset_size:{}".format(train_dataset_size))
print("test_dataset_size:{}".format(test_dataset_size))

                                   image.png

2、加载数据集

#2、加载数据集train_dataset_loader=DataLoader(dataset=train_dataset, batch_size=64)
test_dataset_loader=DataLoader(dataset=test_dataset, batch_size=64)

  在得到数据集后,我们还要对数据集进行加载,加载数据集就类似于打包,比如这里的第二个参数设置的是batch_size=64,则表示把dataset中的64个数据打包一起放入dataloader中。

487e832c1ddc4c5f82b879c7832f04bb.png

3、搭建神经网络✨✨✨

   加载好数据后,就可以搭建神经网络了,我们可以百度CIFAR10 model,可以出现很多CIFAR10的网络模型,如图所示:

image.png

   我们可以根据上图来搭建网络模型,如下:

#3、搭建神经网络classNet(nn.Module):
def__init__(self):
super(Net, self).__init__()
self.model1=nn.Sequential(
nn.Conv2d(3, 32, 5, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 32, 5, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, padding=2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(1024, 64),
nn.Linear(64, 10)
        )
defforward(self, input):
input=self.model1(input)
returninput

   这部分代码完全是根据上图中的模型一步步写的,具有一一对应的关系,只是在卷积中的padding需要我们根据前后输入输出的尺寸进行计算,最后发现三步卷积padding都为2,这里给出pytorch官网的相关计算公式:

6c17143286e5416d9bebc8a06f4ecf6b.png

4、创建网络模型

这步只要一行代码,其实就是实列化了一个对象。

#4、创建网络模型net=Net()

我们可以打印出来看一看我们自己创建的网络模型,如下图。可以看出和前文的结构是一致的。

6b3bb169f9514d11b4da85bbeddfd724.png

  到这里我们已经创建好了自己的模型,这个模型输入是3x32x32的图片【可以认为就是一个3x32x32的张量】,输出是1x10的向量。每当我们创建好一个模型后,应该检测一下模型的输入输出是否是我们所期待的,若不是则即使调整模型。我们可以用以下代码来检测输出是否符合要求。

net=Net()
input=torch.ones((64, 3, 32, 32))  #64为batch_size,3x32x32表示张量尺寸output=net(input)
print(output.shape)

image.png

可以看出输出是符合要求的,64是输入的batch_size,相当于输入64张图片。


5、设置损失函数、优化器

设置损失函数、优化器这些都是神经网络的一些基础知识,不知道的自行补充。当然这里的损失函数和优化器可以和我不同,感兴趣的也可以改变这些来看看我们最后训练的效果会不会发生变化【我测试了几个,对于本例效果差别不大】

#5、设置损失函数、优化器#损失函数loss_fun=nn.CrossEntropyLoss()   #交叉熵loss_fun=loss_fun.to(device)
#优化器learning_rate=1e-2optimizer=torch.optim.SGD(net.parameters(), learning_rate)   #SGD:梯度下降算法

6、设置网络训练中的一些参数

这部分主要是用来记录一些训练测试的次数及网络训练轮数。

#6、设置网络训练中的一些参数total_train_step=0#记录总计训练次数total_test_step=0#记录总计测试次数epoch=10#设计训练轮数

7、开始训练网络✨✨✨

   进行网络训练时,我们首先会通过自己构建的网络得到输出,然后比较输出和真实值,计算出损失,最后通过反向传播,调整网络中参数的值。对于反向传播不理解的可以参考我的这篇文章:BP神经网络

#7、开始进行训练foriinrange(epoch):
print("---第{}轮训练开始---".format(i+1))
net.train()     #开始训练,不是必须的,在网络中有BN,dropout时需要fordataintrain_dataset_loader:
imgs, targets=datatargets=targets.to(device)
outputs=net(imgs)
#比较输出与真实值,计算Lossloss=loss_fun(outputs, targets)
#反向传播,调整参数optimizer.zero_grad()    #每次让梯度重置loss.backward()
optimizer.step()
total_train_step+=1iftotal_train_step%100==0:
print("---第{}次训练结束, Loss:{})".format(total_train_step, loss.item()))

8、开始测试网络✨✨✨

对网络进行测试过程和训练是类似的,不同的是测试过程不需要通过反向传播来更新参数。

#8、开始进行测试,测试不需要进行反向传播net.eval()   #开始测试,不是必须的,在网络中有BN,dropout时需要withtorch.no_grad():    #这句表示测试不需要进行反向传播,即不需要梯度变化【可以不加】total_test_loss=0#测试损失total_test_accuracy=0#测试集准确率fordataintest_dataset_loader:
imgs, targets=dataoutputs=net(imgs)
#计算测试损失loss=loss_fun(outputs, targets)
total_test_loss=total_test_loss+loss.item()
accuracy= (outputs.argmax(1) ==targets).sum()
total_test_accuracy=total_test_accuracy+accuracyprint("第{}轮测试的总损失为:{}".format(i+1, total_test_loss))
print("第{}轮测试的准确率为:{}".format(i+1, total_test_accuracy/test_dataset_size))

9、保存模型

将每一个epoch的模型都保存下来,为后面物体识别准备模型。

#9、保存模型torch.save(net, "./self_model_{}".pth.format(i+1))
print("模型已保存")




检测训练模型的效果

   介绍到这里,完整的自建网络模型训练步骤我们就讲完了,接下来来看看我们用之前保存的模型来检测一些我们从网络上下载的图片,代码如下:

importtorchimporttorchvisionfromPILimportImagefromtorchimportnnimage_path="./imgs/airplane.png"#网络下载的图片放置地址image=Image.open(image_path)
image=image.convert('RGB')  #将图片转化为RGB三通道图片,有的图片有4个通道(多了个透明度)transform=torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),
torchvision.transforms.ToTensor()])
image=transform(image)
classNet(nn.Module):
def__init__(self):
super(Net, self).__init__()
self.model1=nn.Sequential(
nn.Conv2d(3, 32, 5, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 32, 5, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, padding=2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(1024, 64),
nn.Linear(64, 10)
        )
defforward(self, x):
x=self.model1(x)
returnxmodel=torch.load("net_29.pth", map_location=torch.device('cpu'))
print(model)
image=torch.reshape(image, (1, 3, 32, 32))
model.eval()
withtorch.no_grad():
output=model(image)
print(output.argmax(1))

网络下载图片如下:

image.png

输出结果如下:

image.png

0表示的就是airplane【可以从官网中10种类型顺序得出,从上到下是0-9】。

我们可以在来测试一张狗的图片,从官网可知,输出5为狗,原始图片和输出图片如下:

image.png


这里我们可以来看一下模型的检测损失和正确率(设置的epoch=20),准确率大概在65%左右。【这里是在Google Colab上用GPU训练的,单用CPU训练速度还是很慢】445d7747de31478781ca2efddbcd911f.png

模型的准确率似乎就停留在65%上下,我尝试增大epoch到30,但是准确率基本一致。同时我也用3x3的小卷积核代替5x5的卷积核、用卷积代替池化,用卷积代替全连接层等方式进行训练,但是效果都不显著,当然这里我只训练了30个epoch,增大epoch效果可能会好,但耗时会比较多,这部分主要是学习训练模型的思路,感兴趣可以尝试各种方式看能否改进模型效果。


   下图是用Tensorboard画的损失和准确率的曲线图,上文的代码中只关注模型的训练步骤,没有设计tensorboard的讲解,在文末源代码中会包含这部分内容。

                              3993d07ba5c74664a65952f56d963a04.png

完整代码

相关文章
|
1天前
|
负载均衡 网络协议 开发者
掌握 Docker 网络:构建复杂的容器通信
在 Docker 容器化环境中,容器间的通信至关重要。本文详细介绍了 Docker 网络的基本概念和类型,包括桥接网络、宿主网络、覆盖网络和 Macvlan 网络等,并提供了创建、管理和配置自定义网络的实用命令。通过掌握这些知识,开发者可以构建更健壮和灵活的容器化应用,提高应用的可扩展性和安全性。
|
3天前
|
人工智能 安全 算法
网络安全与信息安全:构建数字世界的防线
在数字化浪潮席卷全球的今天,网络安全与信息安全已成为维系社会秩序、保障个人隐私与企业机密的关键。本文旨在探讨网络安全漏洞的成因、加密技术的应用及安全意识的提升策略,以期为读者提供一个全面而深入的网络安全知识框架。
|
2天前
|
存储 安全 网络安全
网络安全与信息安全:构建安全防线的多维策略在当今数字化时代,网络安全已成为维护个人隐私、企业机密和国家安全的关键要素。本文旨在探讨网络安全漏洞的本质、加密技术的重要性以及提升公众安全意识的必要性,以期为构建更加坚固的网络环境提供参考。
本文聚焦于网络安全领域的核心议题,包括网络安全漏洞的现状与应对、加密技术的发展与应用,以及安全意识的培养与实践。通过分析真实案例,揭示网络安全威胁的多样性与复杂性,强调综合防护策略的重要性。不同于传统摘要,本文将直接深入核心内容,以简洁明了的方式概述各章节要点,旨在迅速吸引读者兴趣,引导其进一步探索全文。
|
2天前
|
安全 网络安全 云计算
云计算与网络安全:构建安全的数字未来
在数字化浪潮中,云计算已成为推动企业创新与发展的重要引擎。然而,随着云服务的普及,网络安全问题也日益凸显,成为制约云计算进一步发展的瓶颈。本文旨在深入探讨云计算环境下的网络安全挑战,分析云服务中的安全隐患,并提出相应的信息安全对策。通过构建安全的云计算环境,为企业数字化转型保驾护航,共同迈向安全的数字未来。
|
2天前
|
存储 安全 网络安全
网络安全与信息安全:构建防线的多维策略
在数字化浪潮中,网络安全已成为企业和个人不可忽视的重要议题。本文深入探讨了网络安全漏洞的本质、加密技术的核心作用以及提升安全意识的重要性。通过分析真实案例和最新研究成果,我们揭示了网络威胁的多样性和复杂性,同时提供了实用的防护措施和策略。无论你是技术专家还是普通用户,本文都将帮助你建立更全面的网络安全视角,共同守护数字世界的安全与和谐。
|
2天前
|
数据采集 存储 JSON
从零到一构建网络爬虫帝国:HTTP协议+Python requests库深度解析
在网络数据的海洋中,网络爬虫遵循HTTP协议,穿梭于互联网各处,收集宝贵信息。本文将从零开始,使用Python的requests库,深入解析HTTP协议,助你构建自己的网络爬虫帝国。首先介绍HTTP协议基础,包括请求与响应结构;然后详细介绍requests库的安装与使用,演示如何发送GET和POST请求并处理响应;最后概述爬虫构建流程及挑战,帮助你逐步掌握核心技术,畅游数据海洋。
17 3
|
2天前
|
存储 安全 网络安全
云计算与网络安全:构建安全的数字天空##
随着数字化时代的到来,云计算已经成为企业和个人不可或缺的基础设施。然而,伴随其便利性而来的是一系列网络安全风险和挑战。本文将探讨云计算的基本概念、云服务的类型、网络安全的重要性及常见威胁,并讨论如何通过技术手段和管理策略来确保信息安全,以期为读者提供全面的理解和实用的建议。 ##
|
2天前
|
数据采集 API 开发者
🚀告别网络爬虫小白!urllib与requests联手,Python网络请求实战全攻略
在网络的广阔世界里,Python凭借其简洁的语法和强大的库支持,成为开发网络爬虫的首选语言。本文将通过实战案例,带你探索urllib和requests两大神器的魅力。urllib作为Python内置库,虽API稍显繁琐,但有助于理解HTTP请求本质;requests则简化了请求流程,使开发者更专注于业务逻辑。从基本的网页内容抓取到处理Cookies与Session,我们将逐一剖析,助你从爬虫新手成长为高手。
17 1
|
5天前
|
安全 算法 网络安全
网络安全与信息安全:构建数字世界的防线在数字化浪潮席卷全球的今天,网络安全和信息安全已成为维系社会秩序、保障个人隐私与企业机密的基石。本文旨在深入探讨网络安全漏洞的本质、加密技术的前沿进展以及提升公众安全意识的重要性,共同绘制一幅维护网络空间安宁的蓝图。
本文聚焦网络安全与信息安全的核心议题,通过剖析网络安全漏洞的成因与影响,阐述加密技术在保护信息安全中的关键作用,强调了提升全社会安全意识的紧迫性。不同于常规摘要,本文采用叙述式摘要,以第一人称视角引领读者走进网络安全的世界,揭示问题本质,展望未来趋势。
|
2天前
|
Python
HTTP协议不再是迷!Python网络请求实战,带你走进网络世界的奥秘
本文介绍了HTTP协议,它是互联网信息传递的核心。作为客户端与服务器通信的基础,HTTP请求包括请求行、头和体三部分。通过Python的`requests`库,我们可以轻松实现HTTP请求。本文将指导你安装`requests`库,并通过实战示例演示如何发送GET和POST请求。无论你是想获取网页内容还是提交表单数据,都能通过简单的代码实现。希望本文能帮助你在Python网络请求的道路上迈出坚实的一步。
9 0