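import torch.nn as nn
from torchsummary import summary

# Define the network structure
net = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=7),   # 1x28x28 -> 8x22x22
    nn.MaxPool2d(2, stride=2),        # -> 8x11x11
    nn.ReLU(inplace=True),
    nn.Conv2d(8, 10, kernel_size=5),  # -> 10x7x7
    nn.MaxPool2d(2, stride=2),        # -> 10x3x3
    nn.ReLU(inplace=True),
)

# Print each layer's output shape and parameter count
summary(net, (1, 28, 28), batch_size=1, device="cpu")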
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1            [1, 8, 22, 22]             400
         MaxPool2d-2            [1, 8, 11, 11]               0
              ReLU-3            [1, 8, 11, 11]               0
            Conv2d-4             [1, 10, 7, 7]           2,010
         MaxPool2d-5             [1, 10, 3, 3]               0
              ReLU-6             [1, 10, 3, 3]               0
================================================================
Total params: 2,410
Trainable params: 2,410
Non-trainable params: 0
Input size (MB): 0.00
Forward/backward pass size (MB): 0.05
Params size (MB): 0.01
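You can check the shapes and parameter counts in the table by hand. Assume a square kernel of size k, stride s, padding p and dilation 1. Under those conditions, PyTorch's Conv2d and MaxPool2d both map a side of length H to floor((H + 2p - k) / s) + 1. A Conv2d layer holds out_ch * in_ch * k * k weights plus out_ch biases. The plain-Python sketch below applies those two formulas layer by layer and needs no PyTorch. The `trace`, `conv_out` and `conv_params` helpers and the tuple-based layer list are hypothetical illustrations, not part of torchsummary.

```python
# Hand-check of the summary table: Conv2d (stride 1, no padding) maps H -> H - k + 1,
# MaxPool2d(2, stride=2) maps H -> floor((H - 2) / 2) + 1, and ReLU keeps the shape.

def conv_out(h, k, stride=1, padding=0):
    """Spatial size after a square conv/pool window (floor division, as PyTorch does)."""
    return (h + 2 * padding - k) // stride + 1

def conv_params(in_ch, out_ch, k):
    """Weights (out_ch * in_ch * k * k) plus one bias per output channel."""
    return out_ch * in_ch * k * k + out_ch

# Hypothetical description of `net`: (name, kind, args)
layers = [
    ("Conv2d-1", "conv", (1, 8, 7)),      # in_ch, out_ch, kernel_size
    ("MaxPool2d-2", "pool", (2, 2)),      # kernel_size, stride
    ("ReLU-3", "relu", ()),
    ("Conv2d-4", "conv", (8, 10, 5)),
    ("MaxPool2d-5", "pool", (2, 2)),
    ("ReLU-6", "relu", ()),
]

def trace(layers, shape=(1, 1, 28, 28)):
    """Walk the layers, returning (name, output_shape, param_count) for each one."""
    n, c, h, w = shape
    rows = []
    for name, kind, args in layers:
        params = 0
        if kind == "conv":
            in_ch, out_ch, k = args
            assert in_ch == c, f"{name}: expected {in_ch} input channels, got {c}"
            c, h, w = out_ch, conv_out(h, k), conv_out(w, k)
            params = conv_params(in_ch, out_ch, k)
        elif kind == "pool":
            k, s = args
            h, w = conv_out(h, k, s), conv_out(w, k, s)
        rows.append((name, [n, c, h, w], params))
    return rows

rows = trace(layers)
for name, shape, params in rows:
    print(f"{name:>12}  {str(shape):>16}  {params:>6,}")
print("Total params:", f"{sum(p for _, _, p in rows):,}")
```

Note that the second pooling layer turns 7x7 into 3x3 because the division floors: (7 - 2) // 2 + 1 = 3. This is why the last row drops a column and a row of the feature map.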
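The three size lines at the bottom of the summary are also simple arithmetic. The sketch below rests on three assumptions taken from our reading of the torchsummary source. Every value counts as 4 bytes (float32), and 1 MB is 1024 ** 2 bytes. Each layer output counts twice, once for the activation and once for its gradient. Under those assumptions it reproduces the 0.00, 0.05 and 0.01 figures. The output shapes are copied from the table, and the variable names are illustrative only.

```python
# Rough reproduction of the summary's size estimates, assuming 4 bytes per value
# and MB = 1024 ** 2 bytes (our reading of torchsummary's convention).
from math import prod

BYTES_PER_VALUE = 4
MB = 1024 ** 2

input_shape = (1, 1, 28, 28)   # batch_size=1 plus input_size (1, 28, 28)
output_shapes = [              # the "Output Shape" column, one entry per layer
    (1, 8, 22, 22), (1, 8, 11, 11), (1, 8, 11, 11),
    (1, 10, 7, 7), (1, 10, 3, 3), (1, 10, 3, 3),
]
total_params = 2410

input_mb = prod(input_shape) * BYTES_PER_VALUE / MB
# every layer output is counted twice: activation plus its gradient
fwd_bwd_mb = 2 * sum(prod(s) for s in output_shapes) * BYTES_PER_VALUE / MB
params_mb = total_params * BYTES_PER_VALUE / MB

print(f"Input size (MB): {input_mb:.2f}")
print(f"Forward/backward pass size (MB): {fwd_bwd_mb:.2f}")
print(f"Params size (MB): {params_mb:.2f}")
```

The summary also lists the in-place ReLU layers as separate outputs. This means the activation estimate is an upper bound rather than a measurement of actual memory use.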