如何用Pytorch加载部分权重

简介: 在我做实验的过程中,由于卷积神经网络层数的更改,导致原始网络模型的权重加载失败,经过分析,是因为不匹配造成的,如下方式可以解决.

在我做实验的过程中,由于卷积神经网络层数的更改,导致原始网络模型的权重加载失败,经过分析,是因为不匹配造成的,如下方式可以解决.

import torch
import models
checkpoint = torch.load("./logs/01origial/model_best.pth")
model = models.__dict__["vgg"](dataset="Beans", depth=16) #提取网络结结构,分别是数据集,网络的深度和每层的输出通道数
model.load_state_dict(checkpoint['state_dict'])
model_10 = models.__dict__["vgg10"](dataset="Beans", depth=10)
model_dict = model.state_dict()
model_10_dict = model_10.state_dict()
pretrained_dict = {k: v for k, v in model_dict.items() if k in model_10_dict.keys()}
model_10_dict.update(pretrained_dict)
model_10.load_state_dict(model_10_dict)
相关文章
|
7月前
|
PyTorch API 算法框架/工具
SWA(随机权重平均) for Pytorch
SWA(随机权重平均) for Pytorch
240 0
|
6月前
|
机器学习/深度学习 人工智能 PyTorch
人工智能平台PAI产品使用合集之Alink是否加载预训练好的pytorch模型
阿里云人工智能平台PAI是一个功能强大、易于使用的AI开发平台,旨在降低AI开发门槛,加速创新,助力企业和开发者高效构建、部署和管理人工智能应用。其中包含了一系列相互协同的产品与服务,共同构成一个完整的人工智能开发与应用生态系统。以下是对PAI产品使用合集的概述,涵盖数据处理、模型开发、训练加速、模型部署及管理等多个环节。
|
7月前
|
PyTorch 算法框架/工具 异构计算
pytorch 模型保存与加载
pytorch 模型保存与加载
51 0
|
7月前
|
机器学习/深度学习 PyTorch 调度
PyTorch进阶:模型保存与加载,以及断点续训技巧
【4月更文挑战第17天】本文介绍了PyTorch中模型的保存与加载,以及断点续训技巧。使用`torch.save`和`torch.load`可保存和加载模型权重和状态字典。保存模型时,可选择仅保存轻量级的状态字典或整个模型对象。加载时,需确保模型结构与保存时一致。断点续训需保存训练状态,包括epoch、batch index、optimizer和scheduler状态。中断后,加载这些状态以恢复训练,节省时间和资源。
|
7月前
|
机器学习/深度学习 PyTorch 算法框架/工具
使用PyTorch加载数据集:简单指南
使用PyTorch加载数据集:简单指南
使用PyTorch加载数据集:简单指南
|
7月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于PyTorch实战权重衰减——L2范数正则化方法(附代码)
基于PyTorch实战权重衰减——L2范数正则化方法(附代码)
442 0
|
7月前
|
机器学习/深度学习 PyTorch 算法框架/工具
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
89 0
|
并行计算 PyTorch 算法框架/工具
Pytorch:模型的保存/加载、并行化、分布式
Pytorch:模型的保存/加载、并行化、分布式
169 0
|
存储 机器学习/深度学习 PyTorch
Pytorch学习笔记(9)模型的保存与加载、模型微调、GPU使用
Pytorch学习笔记(9)模型的保存与加载、模型微调、GPU使用
682 0
Pytorch学习笔记(9)模型的保存与加载、模型微调、GPU使用
|
PyTorch 算法框架/工具
【PyTorch】初始化网络各层权重
【PyTorch】初始化网络各层权重
65 0