AI-DSW 上编辑嵌套式模型实现Resnet手势识别
AI-DSW(Data science workshop)是专门为算法开发者准备的云端深度学习开发环境,
进入DSW,目前只有KerasCode和KerasGraph两个Kernel实现了FastNeuralNetwork功能。
- KerasCode:先写深度学习网络代码,然后将代码转成图
- KerasGraph:直接通过画布构建深度学习网络,并且将图转成代码
接下来我们通过实现Resnet18实现手势识别为例,体验AI-DSW的使用
我们的任务为,手语英文字母数据集中包含用手语表示的26个英文字母的信息,我们通过建立ResNet18模型进行手语英文字母识别
在AI-DSW 的官方文档中推荐我们采用序贯式(sequential)的方式构建模型,但是嵌套式封装来构建模型可以使结构更清晰,一些内容可以复用,我们来具体看下代码:
def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same'):
x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides)(x)
x = BatchNormalization(axis=3)(x)
x = Activation('relu')(x)
return x
首先我们将最常见的CNN模块封装,包括卷积,BN,激活函数;用于Resnet模型的复用;
def identity_Block(inpt, nb_filter, kernel_size, strides=(1, 1), with_conv_shortcut=False):
x = Conv2d_BN(inpt, nb_filter=nb_filter, kernel_size=kernel_size, strides=strides, padding='same')
x = Conv2d_BN(x, nb_filter=nb_filter, kernel_size=kernel_size, padding='same')
if with_conv_shortcut:#shortcut的含义是:将输入层x与最后的输出层y进行连接,如上图所示
shortcut = Conv2d_BN(inpt, nb_filter=nb_filter, strides=strides, kernel_size=kernel_size)
x = add([x, shortcut])
return x
else:
x = add([x, inpt])
return x
接下来我们实现Resnet用于Residual Block的模块,即残差块,基于残差块可以有效提升网络性能,提升模型泛化能力,如图所示:
有了核心模块后,我们可着手搭建模型的核心结构,包括输入,卷积,残差,池化,全连接,输出等一系列步骤
def resnet_18(width,height,channel,classes):
inpt = Input(shape=(width, height, channel))
# x = ZeroPadding2D((3, 3))(inpt)
#conv1
x = Conv2d_BN(inpt, nb_filter=64, kernel_size=(7, 7), strides=(2, 2), padding='valid')
x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
#conv2_x
x = identity_Block(x, nb_filter=64, kernel_size=(3, 3))
x = identity_Block(x, nb_filter=64, kernel_size=(3, 3))
#conv3_x
x = identity_Block(x, nb_filter=128, kernel_size=(3, 3), strides=(2, 2), with_conv_shortcut=True)
x = identity_Block(x, nb_filter=128, kernel_size=(3, 3))
#conv4_x
x = identity_Block(x, nb_filter=256, kernel_size=(3, 3), strides=(2, 2), with_conv_shortcut=True)
x = identity_Block(x, nb_filter=256, kernel_size=(3, 3))
#conv5_x
x = identity_Block(x, nb_filter=512, kernel_size=(3, 3), strides=(2, 2), with_conv_shortcut=True)
x = identity_Block(x, nb_filter=512, kernel_size=(3, 3))
x = GlobalAvgPool2D()(x)
x = Dense(classes, activation='softmax')(x)
model = Model(inputs=inpt, outputs=x)
return model
基于嵌套式策略同样可以做生成模型结构,如图所示:
同样的,我们按照官方文档介绍的,也可做模型可视化编辑,调整参数等
有了模型后,我们定义损失函数,加入训练集验证集来训练优化模型,最终得到结果。
综上,体验了KerasGraph后,个人感觉它代表了最新的ai开发环境演进方向——类似轻代码(low code)编辑器,可以快速构建模型结构并验证模型效果,提升了我们对模型结构的实现效率,避免纠结与在TF过于繁琐的源码,而是Focus在模型结构优化本身,总体来说还是不错的。
当然KerasGraph当前使用也存在一些问题:
- 暂不支持各类预训练模型,比如keras_bert,resnet这些,不过在支持了预训练模型,甚至支持对预训练模型最后几层做编辑,将大大提升实用性
- KerasGraph图形化界面前端占用过多内存,有的时候会导致页面卡塞
- KerasGraph对于各层参数编辑和定义易用性还需要提升,目前并不比查阅文档方便多少
当然这不妨碍KerasGraph已经是个较为出色的模型展示工具,我也相信假以时日KerasGraph在模型编辑上取得突破