另类注意力机制之深度残差收缩网络(附代码)

简介: 深度残差收缩网络Deep Residual Shrinkage Network是一种较为新颖的深度神经网络,本质上是深度残差网络ResNet的一种改进版本,其目的是提高深度神经网络在强噪数据上的特征学习效果,其核心思想在于:在特征学习的过程中,剔除冗余信息也是很重要的。

深度残差收缩网络Deep Residual Shrinkage Network是一种较为新颖的深度神经网络,本质上是深度残差网络ResNet的一种改进版本,其目的是提高深度神经网络在强噪数据上的特征学习效果,其核心思想在于:在特征学习的过程中,剔除冗余信息也是很重要的

首先,我们来回顾一下深度残差网络。深度残差网络的基本模块如下图所示。相较于普通的卷积神经网络,深度残差网络引入了跨层的恒等连接,以降低模型训练的难度,提高准确率。
1

然后,相较于普通的深度残差网络,深度残差收缩网络引入了一个小型的子网络,用这个子网络学习得到一组阈值,继而对特征图的各个通道进行软阈值化。这个过程其实是一个可训练的特征选择的过程。具体而言,就是通过前面的两个卷积层Conv将重要的特征变换成绝对值较大的值,将冗余信息所对应的特征变换成绝对值较小的值;通过子网络学习得到二者之间的界限,并且通过软阈值化将冗余特征置为零,同时使重要的特征有着非零的输出;这样就实现了一个特征筛选的过程。
1

深度残差收缩网络其实是一种通用的深层特征学习方法,不仅可以用于含噪数据的特征学习,也可以用于不含噪声数据的特征学习。这是因为,深度残差收缩网络中的阈值是根据样本情况自适应确定的。换言之,如果样本中不含冗余信息、不需要软阈值化,那么阈值可以被训练得非常接近于零,从而软阈值化就相当于不存在了。

最后,堆叠许多基本模块,就可以得到完整的网络结构。
1

利用深度残差收缩网络进行MNIST手写数字的分类,可以看到,虽然没有添加噪声,效果还是挺好的。深度残差收缩网络的代码:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 26 07:46:00 2019

Implemented using TensorFlow 1.0 and TFLearn 0.3.2
 
M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis, 
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898

@author: me
"""

import tflearn
import tensorflow as tf
from tflearn.layers.conv import conv_2d

# Data loading
from tflearn.datasets import mnist
X, Y, testX, testY = mnist.load_data(one_hot=True)
X = X.reshape([-1,28,28,1])
testX = testX.reshape([-1,28,28,1])

def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                   downsample_strides=2, activation='relu', batch_norm=True,
                   bias=True, weights_init='variance_scaling',
                   bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                   trainable=True, restore=True, reuse=False, scope=None,
                   name="ResidualBlock"):
    
    # residual shrinkage blocks with channel-wise thresholds

    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]

    # Variable Scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                   reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)

    with vscope as scope:
        name = scope.name #TODO

        for i in range(nb_blocks):

            identity = residual

            if not downsample:
                downsample_strides = 1

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3,
                             downsample_strides, 'same', 'linear',
                             bias, weights_init, bias_init,
                             regularizer, weight_decay, trainable,
                             restore)

            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same',
                             'linear', bias, weights_init,
                             bias_init, regularizer, weight_decay,
                             trainable, restore)
            
            # get thresholds and apply thresholding
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True)
            scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1)
            thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales))
            residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0))
            

            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1,
                                               downsample_strides)

            # Projection to new dimension
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch+1]])
                in_channels = out_channels

            residual = residual + identity

    return residual


# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Construct a deep residual shrinkage network
net = tflearn.input_data(shape=[None, 28, 28, 1])
net = tflearn.conv_2d(net, 8, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1,  8, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=40000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_mnist',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)

model.fit(X, Y, n_epoch=200, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_mnist')
# Validation
training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]

接下来是深度残差网络ResNet的程序:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 26 07:46:00 2019

Implemented using TensorFlow 1.0 and TFLearn 0.3.2
K. He, X. Zhang, S. Ren, J. Sun, Deep Residual Learning for Image Recognition, CVPR, 2016.

@author: me
"""

import tflearn

# Data loading
from tflearn.datasets import mnist
X, Y, testX, testY = mnist.load_data(one_hot=True)
X = X.reshape([-1,28,28,1])
testX = testX.reshape([-1,28,28,1])

# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)

# Construct a deep residual network
net = tflearn.input_data(shape=[None, 28, 28, 1])
net = tflearn.conv_2d(net, 8, 3, regularizer='L2', weight_decay=0.0001)
net = tflearn.residual_block(net, 1,  8, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=40000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_mnist',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)

model.fit(X, Y, n_epoch=200, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_mnist')
# Validation
training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]

以上两个程序建立了只有1个基本模块的小型深度神经网络,MNIST图像数据中也没有添加任何噪声。训练和测试准确率如下表所示,可以看到,即使是对于不含噪声的数据,深度残差收缩网络的效果也是挺不错的:
1

参考文献:

M. Zhao, S. Zhong, X. Fu, et al., Deep residual shrinkage networks for fault diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898

https://ieeexplore.ieee.org/document/8850096

相关文章
|
2月前
|
安全 网络安全 数据安全/隐私保护
访问控制列表(ACL)是网络安全中的一种重要机制,用于定义和管理对网络资源的访问权限
访问控制列表(ACL)是网络安全中的一种重要机制,用于定义和管理对网络资源的访问权限。它通过设置一系列规则,控制谁可以访问特定资源、在什么条件下访问以及可以执行哪些操作。ACL 可以应用于路由器、防火墙等设备,分为标准、扩展、基于时间和基于用户等多种类型,广泛用于企业网络和互联网中,以增强安全性和精细管理。
312 7
|
3月前
|
机器学习/深度学习 数据可视化 测试技术
YOLO11实战:新颖的多尺度卷积注意力(MSCA)加在网络不同位置的涨点情况 | 创新点如何在自己数据集上高效涨点,解决不涨点掉点等问题
本文探讨了创新点在自定义数据集上表现不稳定的问题,分析了不同数据集和网络位置对创新效果的影响。通过在YOLO11的不同位置引入MSCAAttention模块,展示了三种不同的改进方案及其效果。实验结果显示,改进方案在mAP50指标上分别提升了至0.788、0.792和0.775。建议多尝试不同配置,找到最适合特定数据集的解决方案。
944 0
|
5月前
|
缓存 应用服务中间件 nginx
Web服务器的缓存机制与内容分发网络(CDN)
【8月更文第28天】随着互联网应用的发展,用户对网站响应速度的要求越来越高。为了提升用户体验,Web服务器通常会采用多种技术手段来优化页面加载速度,其中最重要的两种技术就是缓存机制和内容分发网络(CDN)。本文将深入探讨这两种技术的工作原理及其实现方法,并通过具体的代码示例加以说明。
526 1
|
2月前
|
机器学习/深度学习 计算机视觉 Python
【YOLOv11改进 - 注意力机制】SimAM:轻量级注意力机制,解锁卷积神经网络新潜力
【YOLOv11改进 - 注意力机制】SimAM:轻量级注意力机制,解锁卷积神经网络新潜力本文提出了一种简单且高效的卷积神经网络(ConvNets)注意力模块——SimAM。与现有模块不同,SimAM通过优化能量函数推断特征图的3D注意力权重,无需添加额外参数。SimAM基于空间抑制理论设计,通过简单的解决方案实现高效计算,提升卷积神经网络的表征能力。代码已在Pytorch-SimAM开源。
【YOLOv11改进 - 注意力机制】SimAM:轻量级注意力机制,解锁卷积神经网络新潜力
|
3月前
|
网络协议 Java 应用服务中间件
深入浅出Tomcat网络通信的高并发处理机制
【10月更文挑战第3天】本文详细解析了Tomcat在处理高并发网络请求时的机制,重点关注了其三种不同的IO模型:NioEndPoint、Nio2EndPoint 和 AprEndPoint。NioEndPoint 采用多路复用模型,通过 Acceptor 接收连接、Poller 监听事件及 Executor 处理请求;Nio2EndPoint 则使用 AIO 异步模型,通过回调函数处理连接和数据就绪事件;AprEndPoint 通过 JNI 调用本地库实现高性能,但已在 Tomcat 10 中弃用
深入浅出Tomcat网络通信的高并发处理机制
|
3月前
|
机器学习/深度学习 API 算法框架/工具
残差网络(ResNet) -深度学习(Residual Networks (ResNet) – Deep Learning)
残差网络(ResNet) -深度学习(Residual Networks (ResNet) – Deep Learning)
94 0
|
5月前
|
Java 网络安全 云计算
深入理解Java异常处理机制云计算与网络安全:技术挑战与应对策略
【8月更文挑战第27天】在Java编程的世界里,异常处理是维护程序健壮性的重要一环。本文将带你深入了解Java的异常处理机制,从基本的try-catch-finally结构到自定义异常类的设计,再到高级特性如try-with-resources和异常链的应用。通过具体代码示例,我们将探索如何优雅地管理错误和异常,确保你的程序即使在面对不可预见的情况时也能保持运行的稳定性。
|
6月前
|
机器学习/深度学习 计算机视觉
【YOLOv8改进 - 注意力机制】c2f结合CBAM:针对卷积神经网络(CNN)设计的新型注意力机制
【YOLOv8改进 - 注意力机制】c2f结合CBAM:针对卷积神经网络(CNN)设计的新型注意力机制
|
5月前
|
安全 网络安全 数据安全/隐私保护
|
6月前
|
机器学习/深度学习 计算机视觉
【YOLOv8改进 - 注意力机制】Gather-Excite : 提高网络捕获长距离特征交互的能力
【YOLOv8改进 - 注意力机制】Gather-Excite : 提高网络捕获长距离特征交互的能力