如何在一夜之间成为模型微调大师?——从零开始的深度学习修炼之旅,让你的算法功力飙升!

简介: 【10月更文挑战第5天】在机器学习领域,预训练模型具有强大的泛化能力,但直接使用可能效果不佳,尤其在特定任务上。此时,模型微调显得尤为重要。本文通过图像分类任务,详细介绍如何利用PyTorch对ResNet-50模型进行微调,包括环境搭建、数据预处理、模型加载与训练等步骤,并提供完整Python代码。通过调整超参数和采用早停策略等技巧,可进一步优化模型性能。适合初学者快速上手模型微调。

通俗易懂理解模型微调全流程

机器学习特别是深度学习领域,预训练模型因其强大的泛化能力而被广泛应用。然而,直接使用预训练模型可能无法达到最佳效果,尤其是在特定领域或任务上。这时候就需要对模型进行微调(Fine-tuning),使其更好地适应新的数据集。本文将通过一个简单的例子带你一步步了解如何进行模型的微调,并附带Python代码实现。

假设我们有一个图像分类任务,想要利用已经训练好的ResNet-50模型来进行微调。首先确保你的环境中已安装PyTorch和相关依赖包。下面是一个从环境准备到模型训练、评估的完整流程示例。

第一步,导入必要的库:

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets, models
from torch.utils.data import DataLoader

第二步,定义数据预处理方法:

data_transforms = {
   
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, .456, .406], [0.229, .224, .225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, .456, .406], [0.229, .224, .225])
    ]),
}

第三步,加载数据集。这里假设我们使用的是ImageNet的一个子集作为训练数据:

data_dir = 'path/to/dataset'
image_datasets = {
   x: datasets.ImageFolder(root=data_dir + x, transform=data_transforms[x]) for x in ['train', 'val']}
dataloaders = {
   x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4) for x in ['train', 'val']}
dataset_sizes = {
   x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

第四步,加载预训练的ResNet-50模型:

model_ft = models.resnet50(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))

第五步,设置训练参数:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
optimizer_ft = torch.optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

第六步,定义训练函数:

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Training phase
        model.train()
        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in dataloaders['train']:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        if scheduler:
            scheduler.step()

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_corrects = 0

        for inputs, labels in dataloaders['val']:
            inputs = inputs.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            val_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / dataset_sizes['train']
        epoch_acc = running_corrects.double() / dataset_sizes['train']
        val_epoch_loss = val_loss / dataset_sizes['val']
        val_epoch_acc = val_corrects.double() / dataset_sizes['val']

        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        print(f'Val Loss: {val_epoch_loss:.4f} Acc: {val_epoch_acc:.4f}')

train_model(model_ft, criterion, optimizer_ft, scheduler, num_epochs=2)

以上步骤完成了模型微调的基本过程。值得注意的是,在实际应用中,根据具体情况调整超参数如学习率、批量大小等是常见的做法。此外,还可以加入更多的技巧如早停策略(Early Stopping)来防止过拟合。希望这篇指南能帮助你更好地理解和实践模型微调。

相关文章
|
1天前
|
机器学习/深度学习 数据采集 TensorFlow
使用Python实现智能食品加工优化的深度学习模型
使用Python实现智能食品加工优化的深度学习模型
90 59
|
2天前
|
机器学习/深度学习 数据采集 数据库
使用Python实现智能食品营养分析的深度学习模型
使用Python实现智能食品营养分析的深度学习模型
21 6
|
4天前
|
机器学习/深度学习 供应链 安全
使用Python实现智能食品供应链管理的深度学习模型
使用Python实现智能食品供应链管理的深度学习模型
22 3
|
8天前
|
机器学习/深度学习 数据采集 存储
使用Python实现智能农业灌溉系统的深度学习模型
使用Python实现智能农业灌溉系统的深度学习模型
50 6
|
6天前
|
机器学习/深度学习 PyTorch TensorFlow
使用Python实现智能食品质量检测的深度学习模型
使用Python实现智能食品质量检测的深度学习模型
37 1
|
6天前
|
机器学习/深度学习 数据采集 自然语言处理
如何使用深度学习模型来提高命名实体识别的准确率?
如何使用深度学习模型来提高命名实体识别的准确率?
|
8天前
|
机器学习/深度学习 人工智能 算法
青否数字人声音克隆算法升级,16个超真实直播声音模型免费送!
青否数字人的声音克隆算法全面升级,能够完美克隆真人的音调、语速、情感和呼吸。提供16种超真实的直播声音模型,支持3大AI直播类型和6大核心AIGC技术,60秒快速开播,助力商家轻松赚钱。AI讲品、互动和售卖功能强大,支持多平台直播,确保每场直播话术不重复,智能互动和真实感十足。新手小白也能轻松上手,有效规避违规风险。
|
9天前
|
分布式计算 Java 开发工具
阿里云MaxCompute-XGBoost on Spark 极限梯度提升算法的分布式训练与模型持久化oss的实现与代码浅析
本文介绍了XGBoost在MaxCompute+OSS架构下模型持久化遇到的问题及其解决方案。首先简要介绍了XGBoost的特点和应用场景,随后详细描述了客户在将XGBoost on Spark任务从HDFS迁移到OSS时遇到的异常情况。通过分析异常堆栈和源代码,发现使用的`nativeBooster.saveModel`方法不支持OSS路径,而使用`write.overwrite().save`方法则能成功保存模型。最后提供了完整的Scala代码示例、Maven配置和提交命令,帮助用户顺利迁移模型存储路径。
|
3天前
|
机器学习/深度学习 数据采集 TensorFlow
使用Python实现智能食品安全监测的深度学习模型
使用Python实现智能食品安全监测的深度学习模型
17 0
|
4天前
|
机器学习/深度学习 存储 自然语言处理
使用深度学习模型进行情感分析!!!
本文介绍了如何使用深度学习模型进行中文情感分析。首先导入了必要的库,包括`transformers`、`pandas`、`jieba`和`re`。然后定义了一个`SentimentAnalysis`类,用于处理数据、加载真实标签和评估模型准确性。在主函数中,使用预训练的情感分析模型对处理后的数据进行预测,并计算模型的准确性。
10 0