Commit f002083a authored by 别团等shy哥发育

CBAM: reproducing the channel attention + spatial attention mechanism

Parent a3154681
@@ -37,3 +37,4 @@
/经典网络/ShuffleNet/checkpoint_v2/
/经典网络/ResNext/checkpoint/
/经典网络/ECANet(通道注意力机制)/checkpoint/
/经典网络/CBAM(通道+空间注意力机制)/checkpoint/
import tensorflow as tf
from tensorflow.keras.layers import Activation, Add, BatchNormalization, Concatenate, Conv2D, Dense, Input
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Multiply, Reshape
from tensorflow.keras.models import Model

# Channel Attention Module
# Reference (Woo et al., 2018): Mc(F) = sigmoid(MLP(AvgPool(F)) + MLP(MaxPool(F))),
# where the MLP is a shared two-layer network with reduction ratio `ratio`.
def channelAttention(input_feature, ratio=16, name=""):
    # Number of channels in the input feature map
    channel = input_feature.shape[-1]
    # Shared two-layer MLP applied to both pooled descriptors
    shared_layer_one = Dense(channel // ratio,
                             activation='relu',
                             use_bias=False,
                             name="channel_attention_shared_one_" + str(name))
    shared_layer_two = Dense(channel,
                             use_bias=False,
                             name="channel_attention_shared_two_" + str(name))
    # Global average pooling and global max pooling over the spatial dims
    avg_pool = GlobalAveragePooling2D()(input_feature)
    max_pool = GlobalMaxPooling2D()(input_feature)
    avg_pool = Reshape((1, 1, channel))(avg_pool)
    max_pool = Reshape((1, 1, channel))(max_pool)
    # Run both descriptors through the shared MLP
    avg_pool = shared_layer_one(avg_pool)
    max_pool = shared_layer_one(max_pool)
    avg_pool = shared_layer_two(avg_pool)
    max_pool = shared_layer_two(max_pool)
    # Add the two branches
    cbam_feature = Add()([avg_pool, max_pool])
    # Sigmoid yields a weight in (0, 1) for every channel of the input
    cbam_feature = Activation('sigmoid')(cbam_feature)
    # Rescale the input feature map by the channel weights
    out = Multiply()([input_feature, cbam_feature])
    return out

# Spatial Attention Module
# Reference (Woo et al., 2018): Ms(F) = sigmoid(f7x7([AvgPool(F); MaxPool(F)])),
# pooling along the channel axis and convolving the 2-channel result.
def spatialAttention(input_feature, kernel_size=7, name=""):
    # Average-pool and max-pool along the channel dimension
    avg_pool = tf.reduce_mean(input_feature, axis=3, keepdims=True)
    max_pool = tf.reduce_max(input_feature, axis=3, keepdims=True)
    concat = Concatenate(axis=3)([avg_pool, max_pool])
    # A single convolution turns the 2-channel map into a spatial weight map
    cbam_feature = Conv2D(filters=1,
                          kernel_size=kernel_size,
                          strides=1,
                          padding='same',
                          use_bias=False,
                          name="spatial_attention_" + str(name))(concat)
    cbam_feature = Activation('sigmoid')(cbam_feature)
    # Rescale the input feature map by the spatial weights
    out = Multiply()([input_feature, cbam_feature])
    return out

# CBAM Block
def cbamBlock(cbam_feature, ratio=16, name=""):
    # Channel attention first, then spatial attention; the original paper
    # reports that this sequential ordering works better.
    cbam_feature = channelAttention(cbam_feature, ratio=ratio, name=name)
    # Pass name as a keyword so it is not consumed as kernel_size
    cbam_feature = spatialAttention(cbam_feature, name=name)
    return cbam_feature
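
# Illustrative sketch (not part of the original commit): one way to attach
# cbamBlock after a convolution in a backbone. The helper and layer names
# below are hypothetical.
def convBlockWithCBAM(x, filters, name=""):
    x = Conv2D(filters, kernel_size=3, padding='same', use_bias=False,
               name="conv_" + str(name))(x)
    x = BatchNormalization(name="bn_" + str(name))(x)
    x = Activation('relu')(x)
    # Refine the conv features with channel + spatial attention
    return cbamBlock(x, ratio=16, name=name)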

if __name__ == '__main__':
    # Quick demo on a 26x26x512 feature map
    inputs = Input([26, 26, 512])
    x = channelAttention(inputs)
    x = spatialAttention(x)
    model = Model(inputs, x)
    model.summary()
\ No newline at end of file
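
As a quick sanity check (a minimal sketch, not part of the commit, assuming TensorFlow 2.x and the definitions above), CBAM should preserve the input shape, since both attention stages only rescale the feature map:

inputs = Input([26, 26, 512])
outputs = cbamBlock(inputs, ratio=16, name="check")
model = Model(inputs, outputs)
x = tf.random.normal([2, 26, 26, 512])
y = model(x)
assert y.shape == x.shape  # attention rescales values, never reshapes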