Pytorch 中怎么计算网络的参数量

简介: 本博文以 Dense Block 为例,Pytorch 为 DL 框架,最终计算模块参数量方法如下:

本博文以 Dense Block 为例,Pytorch 为 DL 框架,最终计算模块参数量方法如下:

import torch
import torch.nn as nn
class Norm_Conv(nn.Module):
    def __init__(self,in_channel):
        super(Norm_Conv,self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_channel,in_channel,3,1,1),
            nn.ReLU(True),
            nn.BatchNorm2d(in_channel),
            nn.Conv2d(in_channel,in_channel,3,1,1),
            nn.ReLU(True),
            nn.BatchNorm2d(in_channel),
            nn.Conv2d(in_channel,in_channel,3,1,1),
            nn.ReLU(True),
            nn.BatchNorm2d(in_channel))
    def forward(self,input):
        out = self.layers(input)
        return out
class DenseBlock_Norm(nn.Module):
    def __init__(self,in_channel):
        super(DenseBlock_Norm,self).__init__()
        self.first_layer = nn.Sequential(nn.Conv2d(in_channel,in_channel,3,1,1),
                                        nn.ReLU(True),
                                        nn.BatchNorm2d(in_channel))
        self.second_layer = nn.Sequential(nn.Conv2d(in_channel*2,in_channel,3,1,1),
                                          nn.ReLU(True),
                                          nn.BatchNorm2d(in_channel))
        self.third_layer = nn.Sequential(
            nn.Conv2d(in_channel*3,in_channel,3,1,1),
            nn.ReLU(True),
            nn.BatchNorm2d(in_channel))
    def forward(self,input):
        output1 = self.first_layer(input)
        output2 = self.second_layer(torch.cat((output1,input),dim=1))
        output3 = self.third_layer(torch.cat((input,output1,output2),dim=1))
        return output3
def count_param(model):
    param_count = 0
    for param in model.parameters():
        param_count += param.view(-1).size()[0]
    return param_count
# Get Parameter number of Network
in_channel = 128
net1 = Norm_Conv(in_channel)
print('Norm Conv parameter count is {}'.format(count_param(net1)))
net2 = DenseBlock_Norm(in_channel)
print('DenseBlock Norm parameter count is {}'.format(count_param(net2)))

最终结果如下

Norm Conv parameter count is 443520
DenseBlock Norm parameter count is 885888
相关文章
|
22天前
|
机器学习/深度学习
神经网络各种层的输入输出尺寸计算
神经网络各种层的输入输出尺寸计算
31 1
|
24天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 中的动态计算图:实现灵活的神经网络架构
【8月更文第27天】PyTorch 是一款流行的深度学习框架,它以其灵活性和易用性而闻名。与 TensorFlow 等其他框架相比,PyTorch 最大的特点之一是支持动态计算图。这意味着开发者可以在运行时定义网络结构,这为构建复杂的模型提供了极大的便利。本文将深入探讨 PyTorch 中动态计算图的工作原理,并通过一些示例代码展示如何利用这一特性来构建灵活的神经网络架构。
49 1
|
1月前
|
机器学习/深度学习 人工智能 PyTorch
【深度学习】使用PyTorch构建神经网络:深度学习实战指南
PyTorch是一个开源的Python机器学习库,特别专注于深度学习领域。它由Facebook的AI研究团队开发并维护,因其灵活的架构、动态计算图以及在科研和工业界的广泛支持而受到青睐。PyTorch提供了强大的GPU加速能力,使得在处理大规模数据集和复杂模型时效率极高。
160 59
|
8天前
|
机器学习/深度学习
小土堆-pytorch-神经网络-损失函数与反向传播_笔记
在使用损失函数时,关键在于匹配输入和输出形状。例如,在L1Loss中,输入形状中的N代表批量大小。以下是具体示例:对于相同形状的输入和目标张量,L1Loss默认计算差值并求平均;此外,均方误差(MSE)也是常用损失函数。实战中,损失函数用于计算模型输出与真实标签间的差距,并通过反向传播更新模型参数。
|
1月前
|
云安全 安全 网络安全
云端防御战线:融合云计算与网络安全的未来策略
【7月更文挑战第47天】 在数字化时代,云计算已成为企业运营不可或缺的部分,而网络安全则是维护这些服务正常运行的基石。随着技术不断进步,传统的安全措施已不足以应对新兴的威胁。本文将探讨云计算环境中的安全挑战,并提出一种融合云服务与网络安全的综合防御策略。我们将分析云服务模式、网络威胁类型以及信息安全实践,并讨论如何构建一个既灵活又强大的安全体系,确保数据和服务的完整性、可用性与机密性。
|
1月前
|
机器学习/深度学习 PyTorch TensorFlow
【PyTorch】PyTorch深度学习框架实战(一):实现你的第一个DNN网络
【PyTorch】PyTorch深度学习框架实战(一):实现你的第一个DNN网络
76 1
|
1月前
|
机器学习/深度学习 人工智能 PyTorch
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
57 1
|
22天前
|
机器学习/深度学习 PyTorch 测试技术
深度学习入门:使用 PyTorch 构建和训练你的第一个神经网络
【8月更文第29天】深度学习是机器学习的一个分支,它利用多层非线性处理单元(即神经网络)来解决复杂的模式识别问题。PyTorch 是一个强大的深度学习框架,它提供了灵活的 API 和动态计算图,非常适合初学者和研究者使用。
33 0
|
27天前
|
网络协议 算法 网络架构
OSPF 如何计算到目标网络的最佳路径
【8月更文挑战第24天】
28 0
|
1天前
|
SQL 安全 算法
网络安全与信息安全:保护你的数字世界
【9月更文挑战第18天】在这个数字信息时代,网络安全和信息安全的重要性不言而喻。从网络漏洞的发现到加密技术的应用,再到安全意识的提升,每一个环节都至关重要。本文将深入探讨这些主题,并提供实用的建议和代码示例,以帮助读者更好地保护自己的数字世界。
20 11