PyTorch深度学习实战 |基于Alexnet网络预训练模型完成训练花分类任务实战

本文涉及的产品
RDS DuckDB + QuickBI 企业套餐,8核32GB + QuickBI 专业版
简介: 本文介绍了使用AlexNet模型进行花卉图像分类的实战过程。首先讲解了数据集的准备方法,包括5类花卉数据(雏菊、蒲公英等)的8:2训练集/验证集划分。详细解析了AlexNet的网络结构(5个卷积层+3个全连接层)及其创新点,如ReLU激活函数和Dropout正则化。提供了完整的PyTorch实现代码,包括模型定义、数据增强和训练流程。实验结果表明,50轮训练后验证集准确率可达80%。文章还介绍了使用预训练模型进行迁移学习的方法,通过修改分类器层并微调参数,可以显著提升训练效率和分类效果。整个项目从数据准备到

 使用的数据集

花分类数据集:

百度云链接下载: https://pan.baidu.com/s/1QLCTA4sXnQAw_yvxPj9szg

提取码:58p0

下载好之后,解压到flower_data文件夹下,此时flower_data\flower_photos下就是放的我们的数据

集,我们看一下原始的数据是什么样子的:

分类类别:共包含 5 类花卉,对应 5 个文件夹: daisy(雏菊) dandelion(蒲公英) roses(玫

瑰) sunflowers(向日葵) tulips(郁金香)

image.gif

跑过一些项目的应该都有印象,比如YOLO等,他们的数据集的放置是有要求的一般情况下都是分

成两个,一个是train文件夹,train文件夹下是各种分类的文件夹(每个文件夹的名字是类报名)。

另外一个是val文件夹,val文件夹下是各种分类的文件夹(每个文件夹的名字是类报名)。一般是

按照8:2的比例去分这两个数据集的。这里的话可以用AI写代码整理,但是别忘记了检查一下。

训练集的路径:D:\vscode\shenduxvexishizhan\CNN\flower_data\train

验证集的路径是:D:\vscode\shenduxvexishizhan\CNN\flower_data\val


Alexnet

AlexNet创新点

(1)AlexNet首次成功使用了8层深度网络(5个卷积层 + 3个全连接层),比之前的网络深得多

(2)首次在深层网络中大规模且成功地使用了(ReLU)作为激活函数,取代了传统的 Sigmoid

或 Tanh 函数。

(3)引入GPU 加速训练 (GPU Acceleration),Dropout 正则化 (Dropout Regularization),数据增

强 (Data 局部响应归一化(LRN)和重叠池化(Overlapping Pooling)Augmentation)

网络结构

层类型 具体参数 输入尺寸 输出尺寸 作用
卷积层 1 Conv2d (3→48, 11×11, 步长 4, 填充 2) [3,224,224] [48,55,55] 提取基础纹理特征
池化层 1 MaxPool2d (3×3, 步长 2) [48,55,55] [48,27,27] 降维 + 增强鲁棒性
卷积层 2 Conv2d (48→128, 5×5, 填充 2) [48,27,27] [128,27,27] 提取更复杂特征
池化层 2 MaxPool2d (3×3, 步长 2) [128,27,27] [128,13,13] 继续降维
卷积层 3 Conv2d (128→192, 3×3, 填充 1) [128,13,13] [192,13,13] 特征细化
卷积层 4 Conv2d (192→192, 3×3, 填充 1) [192,13,13] [192,13,13] 特征细化
卷积层 5 Conv2d (192→128, 3×3, 填充 1) [192,13,13] [128,13,13] 特征压缩
池化层 3 MaxPool2d (3×3, 步长 2) [128,13,13] [128,6,6] 最终降维
全连接层 1 Linear(128×6×6 → 2048) 4608 2048 特征映射到高维空间
全连接层 2 Linear(2048 → 2048) 2048 2048 特征变换
全连接层 3 Linear(2048 → num_classes) 2048 num_classes 最终分类

image.gif

model.py

import torch.nn as nn
import torch
class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
# 测试模型前向传播
if __name__ == "__main__":
    # 创建模型实例
    model = AlexNet(num_classes=1000, init_weights=True)
    # 生成随机输入(batch_size=4, 3通道, 224×224)
    input_tensor = torch.randn(4, 3, 224, 224)
    # 前向传播
    output = model(input_tensor)
    # 打印输出形状
    print(f"输入形状: {input_tensor.shape}")
    print(f"输出形状: {output.shape}")  # 应输出 [4, 1000]
    # 打印模型参数量
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"总参数量: {total_params/1e6:.2f}M")
    print(f"可训练参数量: {trainable_params/1e6:.2f}M")

image.gif

模型的输入是【B,3,224,224】代表B张图片,输出是【B,5】代表B个样本,每个类别的概

率。在实际的项目中,这是最简单,也是最核心的部分。简单是因为,所有神经网络的本质都是为

了提取特征的,所有我们很多时候不需要知道,其是怎么实现的,只需要知道,网络的输入和输出

就行。最核心是因为,往往特征提取的好坏直接决定了训练效果的好坏。

image.gif

图解处理和图像增强:

在验证的时候是不需要数据增强的

data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

image.gif

train.py

import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.next()
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))
    net = AlexNet(num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(net.parameters(), lr=0.0002)
    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
    print('Finished Training')
if __name__ == '__main__':
    main()

image.gif

训练结果:

训练50代的效果如下:验证集的准确度大致可以稳定在0.8左右

image.gif


预训练模型完成训练

加载预训练权重: 导入 ImageNet 上训练好的 AlexNet 权重。

修改分类器: 由于你的任务是 5 类花卉分类,需要替换 AlexNet 原本的 1000 类输出层,以匹配你的 num_classes=5。

设置学习率/冻结层: 通常对预训练模型使用更小的学习率,或者冻结特征提取层(features)的参数,只训练分类器(classifier)的参数。

import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets, models # 导入 models
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet # 假设这是你自定义的 AlexNet 类
def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))
    # ... (数据加载和处理部分保持不变) ...
    # ------------------ 【修改开始】模型加载与设置 ------------------
    # 1. 实例化 AlexNet 模型
    # 如果使用你自定义的 AlexNet 类,这里加载预训练权重(需要文件支持)
    net = AlexNet(num_classes=1000, init_weights=False) # 实例化1000类模型
    
    # 假设你的预训练权重是 'alexnet_imagenet.pth'
    # weights_path = "./alexnet_imagenet.pth"
    # assert os.path.exists(weights_path), f"Pretrained weights file: '{weights_path}' not found."
    # net.load_state_dict(torch.load(weights_path), strict=False) # strict=False 如果你的类和权重不完全匹配
    # 或者:使用官方预训练模型(更简单)
    net = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
    # 2. 替换分类器 (迁移学习的关键)
    in_features = net.classifier[6].in_features # 官方 AlexNet 的最后一个 Linear 层是第 6 个模块 (索引从 0 开始)
    # 替换为 5 个类别的输出
    net.classifier[6] = nn.Linear(in_features, 5) 
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    
    # 3. 设置微调学习率(通常更小)
    # 只训练分类器参数,使用更高的学习率:
    # optimizer = optim.Adam(net.classifier.parameters(), lr=0.001) 
    
    # 或者,微调所有参数,使用更小的学习率:
    optimizer = optim.Adam(net.parameters(), lr=0.00005) # 降低学习率进行微调
    # ------------------ 【修改结束】模型加载与设置 ------------------
    epochs = 50
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
    print('Finished Training')
if __name__ == '__main__':
    main()

image.gif

训练效果:

我们应该可以直观的感受到,这种方式的训练速度更快,效果更好

image.gif


目录
相关文章
|
1小时前
|
机器学习/深度学习 数据挖掘 PyTorch
PyTorch深度学习实战 |手算​​FCN全卷积神经网络
本文介绍了FCN-8s语义分割网络的实现细节。首先解释了语义分割的概念及其与图像分类的区别,重点分析了FCN网络结构中的全卷积化、上采样和跳跃连接三个关键技术。全卷积化将传统CNN的全连接层改为卷积层,实现像素级分类;上采样通过双线性插值恢复特征图尺寸;跳跃连接则融合高低层特征以提升细节表现。文章详细推导了损失函数的计算过程,并提供了完整的PyTorch实现代码,包括双线性插值权重初始化、VGG16骨干网络和FCN-8s主体结构。最后通过测试验证了模型能正确输出与输入尺寸匹配的预测结果。
74 3
|
1小时前
|
机器学习/深度学习 编解码 算法
PyTorch深度学习实战 |手算​​U-net
本文详细解析了U-Net网络架构及其在医学图像分割中的应用。重点对比了U-Net与FCN的核心区别:U-Net采用特征拼接(Concat)保留所有层级信息,而FCN使用特征相加(Add)进行融合。文章深入剖析了U-Net的编码器-瓶颈-解码器结构,解释了其独特的裁剪拼接机制和Overlap-tile策略,并提供了完整的PyTorch实现代码。现代U-Net通过SamePadding实现了输入输出尺寸一致,显著提升了分割精度。文章还探讨了弹性形变数据增强和带空间权重的损失函数设计,为医学图像分析提供了实用解决
87 2
|
1小时前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch深度学习实战 |层归一化层和FeedForward
本文介绍了PyTorch深度学习中Add&Norm层和FeedForward层的实现原理。Add&Norm层由残差连接(Add)和层归一化(Norm)组成,能加速模型收敛并稳定训练。层归一化会对神经网络每层的输出进行归一化处理,文中详细展示了其计算方法和PyTorch实现代码。FeedForward层是一个两层的全连接网络,通过线性变换提取更深层次特征。文章还分析了Transformer模型中使用层归一化的原因,并提供了完整的代码实现,包括参数初始化和前向传播过程。
64 0
|
1小时前
|
机器学习/深度学习 数据可视化 PyTorch
PyTorch深度学习实战 | 基于LSTM的时间序列预测任务
本文介绍了使用LSTM模型预测印度德里市平均温度的两个项目。项目1对温度数据进行归一化处理,采用滑动窗口法构建监督学习样本,设计5层LSTM网络结构,并详细说明了模型训练过程及评估方法。项目2在数据处理上增加了标准化和周期性特征,改进了网络架构,引入了学习率调整和早停机制优化训练过程。两个项目均通过可视化对比预测值和真实值,验证了LSTM模型在时间序列预测中的有效性。文章从数据处理、模型构建到训练优化,完整呈现了温度预测的实现流程,为时序预测任务提供了实用参考。
59 0
|
1小时前
|
机器学习/深度学习 人工智能 自然语言处理
图解人工智能的数学基础(最优化)
本文深入解析人工智能中最优化问题的核心:通过最小化损失函数来训练模型。涵盖回归(MSE)与分类(交叉熵)任务的典型损失函数,详解梯度下降原理及BGD、SGD、Mini-batch等算法差异,并介绍Momentum、Adam等现代优化技巧,辅以PyTorch代码实现。
62 0
|
1小时前
|
人工智能 自然语言处理 Python
人工智能|BERT的简单介绍
BERT(2018年谷歌提出)是基于Transformer编码器的双向预训练语言模型,通过掩码语言建模(MLM)和下一句预测(NSP)任务学习深度上下文语义,在文本分类、问答、NER等理解型任务中表现卓越。
119 1
|
1小时前
|
数据采集 人工智能 数据可视化
人工智能|YOLOv5必须了解的知识
本文详解YOLOv5网络结构(Input/Backbone/Neck/Head)及train.py核心实现:包括模型加载(预训练权重适配)、yaml配置解析、数据集读取与增强、标签格式说明、多尺度特征融合机制,以及推理阶段预处理、NMS过滤与结果可视化全流程。
171 2
|
1小时前
|
机器学习/深度学习 数据可视化 机器人
PyTorch深度学习实战 |手算​​自编码Autoencoder
自编码器是一种无监督神经网络,通过编码器将数据压缩为低维潜在表示,再由解码器重建原始输入。其核心价值在于自动提取关键特征、实现降维与数据去噪,广泛应用于图像重建、特征学习和可视化分析等领域。
120 3
|
1小时前
|
机器学习/深度学习 存储 编解码
PyTorch深度学习实战 | 手算卷积网络(Resnet-18)
ResNet-18是解决深层网络梯度消失与退化问题的经典模型,核心在于残差连接(Shortcut):让输入X直接跳跃传递,与卷积学习的残差F(X)相加(F(X)+X),实现恒等映射。其含4个stage、18层可训练层,每个BasicBlock由两个3×3卷积+BN+ReLU构成,并通过1×1卷积适配尺寸/通道差异,显著提升深层网络训练稳定性与性能。(239字)
121 2
|
1小时前
|
机器学习/深度学习 人工智能 PyTorch
PyTorch深度学习实战 |基于ViT(Vision Transformer)网络花分类任务
本文介绍了基于PyTorch的ViT(Vision Transformer)模型在花卉分类任务中的实战应用。主要内容包括: 数据集准备:使用包含5类花卉(雏菊、蒲公英、玫瑰、向日葵、郁金香)的数据集,按8:2比例划分为训练集和验证集。 模型架构:实现了一个精简版ViT模型,包含Patch Embedding、CLS Token、位置嵌入和Transformer编码器等核心组件。 训练流程:详细展示了数据加载、模型训练、验证及测试的完整代码实现,包括损失函数、优化器和学习率调度等配置。 辅助功能:提供了设备选
52 0