安装 torchsummary
pip install torchsummary
输出网络信息
summary 函数参数说明：
model：网络模型
input_size：网络输入的 shape（例如本例中的 (1, channels, samples)），不包含 batch_size 维度
batch_size：summary 输出中显示的 batch 维度大小，默认是 -1
device：在 GPU 还是 CPU 上运行，默认是 "cuda"（在 GPU 上运行）；如果想在 CPU 上执行，将参数设为 "cpu" 即可
import torch
import torch.nn as nn
from torchsummary import summary
class Shallow_ConvNet(nn.Module):
    """Shallow ConvNet classifier for ``(C, T)`` multichannel inputs.

    ``forward`` applies, in order: temporal convolution -> spatial
    convolution -> batch norm -> squaring -> average pooling along time ->
    safe log -> dropout -> classification convolution -> softmax over the
    class axis.

    Args:
        in_channel: Number of input feature maps (1 for raw ``(C, T)`` input).
        conv_channel_temp: Output channels of the temporal convolution.
        kernel_size_temp: Kernel length of the temporal convolution (last axis).
        conv_channel_spat: Output channels of the spatial convolution.
        kernel_size_spat: Kernel height of the spatial convolution (the ``C``
            axis); setting it to ``C`` collapses that axis to length 1.
        pooling_size: Average-pooling window length along time.
        pool_stride_size: Average-pooling stride along time.
        dropoutRate: Dropout probability.
        n_classes: Number of output classes.
        class_kernel_size: Time-axis kernel length of the classification
            convolution; choosing it equal to the pooled time length makes
            the output one value per class.
    """

    def __init__(self, in_channel, conv_channel_temp, kernel_size_temp, conv_channel_spat, kernel_size_spat,
                 pooling_size, pool_stride_size, dropoutRate, n_classes, class_kernel_size):
        super().__init__()
        # Convolution along the time axis only.
        self.temp_conv = nn.Conv2d(in_channels=in_channel,
                                   out_channels=conv_channel_temp,
                                   kernel_size=(1, kernel_size_temp),
                                   stride=1,
                                   bias=False)
        # Convolution along the channel (C) axis only.
        self.spat_conv = nn.Conv2d(in_channels=conv_channel_temp,
                                   out_channels=conv_channel_spat,
                                   kernel_size=(kernel_size_spat, 1),
                                   stride=1,
                                   bias=False)
        self.bn = nn.BatchNorm2d(num_features=conv_channel_spat)
        # The squaring non-linearity (x * x) is applied inline in forward().
        self.pooling = nn.AvgPool2d(kernel_size=(1, pooling_size),
                                    stride=(1, pool_stride_size))
        # The log non-linearity (log(max(x, eps))) is safe_log() in forward().
        self.dropout = nn.Dropout(p=dropoutRate)
        self.class_conv = nn.Conv2d(in_channels=conv_channel_spat,
                                    out_channels=n_classes,
                                    kernel_size=(1, class_kernel_size),
                                    bias=False)
        self.softmax = nn.Softmax(dim=1)

    def safe_log(self, x, eps=1e-6):
        """Return ``log(max(x, eps))`` elementwise, preventing ``log(0)``.

        Args:
            x: Input tensor.
            eps: Lower clamp bound applied before the log (default ``1e-6``).
        """
        return torch.log(torch.clamp(x, min=eps))

    def forward(self, x):
        """Return softmax class probabilities.

        ``x`` may be 3-D ``(batch_size, C, T)``, in which case a singleton
        feature-map axis is inserted at dim 1, or already 4-D
        ``(batch_size, in_channel, C, T)``. The result keeps the batch axis
        even when ``batch_size == 1``.
        """
        # Compare by value; `is not 4` compared object identity and only
        # worked thanks to CPython's small-int cache (SyntaxWarning on 3.8+).
        if x.dim() != 4:
            x = x.unsqueeze(1)  # (batch_size, C, T) -> (batch_size, 1, C, T)
        x = self.temp_conv(x)
        x = self.spat_conv(x)
        x = self.bn(x)
        x = x * x  # squaring non-linearity
        x = self.pooling(x)
        x = self.safe_log(x)  # log non-linearity after pooling
        x = self.dropout(x)
        x = self.class_conv(x)
        x = self.softmax(x)
        # Drop only the two trailing (C, time) axes: a bare squeeze() would
        # also remove the batch axis when batch_size == 1.
        return x.squeeze(3).squeeze(2)
# ============================ Initialization parameters ============================
# Input geometry: `channels` rows by `samples` time points per example.
channels = 44
samples = 534
in_channel = 1
conv_channel_temp = 40
kernel_size_temp = 25
conv_channel_spat = 40
# A spatial kernel as tall as the input collapses the channel axis to length 1.
kernel_size_spat = channels
pooling_size = 75
pool_stride_size = 15
dropoutRate = 0.3
n_classes = 4
# Must equal the pooled time length so the classifier collapses time to 1:
# 534 - 25 + 1 = 510 after the temporal conv; (510 - 75) // 15 + 1 = 30 after pooling.
class_kernel_size = 30
def main():
    """Smoke-test Shallow_ConvNet on random data and print a summary.

    Builds the model from the module-level hyper-parameters and runs one
    forward pass on a random batch of 32 inputs shaped
    ``(1, channels, samples)``. It then prints the output shape and the
    module tree, and calls ``torchsummary.summary`` on CPU.
    """
    # Named `dummy_input` rather than `input` to avoid shadowing the builtin.
    dummy_input = torch.randn(32, 1, channels, samples)
    model = Shallow_ConvNet(in_channel, conv_channel_temp, kernel_size_temp,
                            conv_channel_spat, kernel_size_spat,
                            pooling_size, pool_stride_size, dropoutRate,
                            n_classes, class_kernel_size)
    # Shape check only: no autograd graph is needed for this forward pass.
    with torch.no_grad():
        out = model(dummy_input)
    print('===============================================================')
    print('out', out.shape)
    print('model', model)
    summary(model=model, input_size=(1, channels, samples), batch_size=32, device="cpu")
# Run the smoke test only when executed as a script, not on import.
# (`__name__` was previously written as `name`, an undefined name that
# raised NameError as soon as the module was loaded.)
if __name__ == "__main__":
    main()