提交 69ff4ed9 编写于 作者: 别团等shy哥发育's avatar 别团等shy哥发育

RegNet中的ZBlock实现。

上级 9f1687df
......@@ -198,10 +198,10 @@ def Stem(x, name=None):
# SE注意力机制模块
def SqueezeAndExciteBlock(inputs, filters_in, se_filters, name=None):
# 获得通道数
channel=inputs.shape[-1]
channel = inputs.shape[-1]
x = layers.GlobalAveragePooling2D(name=name + '_squeeze_and_excite_gap')(inputs)
x=layers.Reshape((1,1,channel))(x)
x = layers.Reshape((1, 1, channel))(x)
# 两个全连接层(目前看到的所有源码都是使用两个1x1卷积层代替)
x = layers.Conv2D(filters=se_filters,
......@@ -330,6 +330,63 @@ def YBlock(inputs,
x = layers.ReLU(name=name + "_exit_relu")(x + skip)
return x
# ZBlock实现:
def ZBlock(inputs,
filters_in,
filters_out,
group_width,
stride=1,
squeeze_excite_ratio=0.25, # SE中Squeeze 和 Excite 块的膨胀率
bottleneck_ratio=0.25, # bottlenect比率,用来改变网络宽度(通道数)
name=None):
groups = filters_out // group_width
se_filters = int(filters_in * squeeze_excite_ratio)
inv_btlneck_filters = int(filters_out / bottleneck_ratio)
# Build block
# conv_1x1_1
x = layers.Conv2D(inv_btlneck_filters,
(1, 1),
use_bias=False,
kernel_initializer='he_normal',
name=name + '_conv_1x1_1')(inputs)
x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name=name + '_conv_1x1_1_bn')(x)
# 使用SiLU代替ReLU
x = tf.nn.silu(x)
# conv_3x3
x = layers.Conv2D(inv_btlneck_filters,
(3, 3),
use_bias=False,
groups=groups,
padding='same',
kernel_initializer='he_normal',
name=name + '_conv_3x3')(x)
x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name=name + '_conv_3x3_bn')(x)
x = tf.nn.silu(x)
# squeeze-Excitation block
x = SqueezeAndExciteBlock(x, inv_btlneck_filters, se_filters, name=name)
# conv_1x1_2
x = layers.Conv2D(filters_out,
(1, 1),
use_bias=False,
kernel_initializer='he_normal',
name=name + '_conv_1x1_2')(x)
x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name=name + '_conv_1x1_2_bn')(x)
# 最后的1x1卷积后面没有非线性
# stride=2的块没有残差边
if stride != 1:
return x
else:
x = layers.Add([x, inputs])
return x
def Stage(inputs,
block_type, # 必须是X、Y、Z之一
depth, # stage深度,要使用的块数
......@@ -339,7 +396,7 @@ def Stage(inputs,
name=None): # 名称前缀
x = inputs
if block_type == "X":
# 论文原话:Stage的第一个block的步长为2
# 论文原话:Stage的第一个block的步长为2,其他默认为1
x = XBlock(
x,
filters_in,
......@@ -349,11 +406,14 @@ def Stage(inputs,
name=f"{name}_XBlock_0")
for i in range(1, depth):
x = XBlock(x, filters_out, filters_out, group_width, name=f"{name}_XBlock_{i}")
elif block_type== "Y":
x=YBlock(x,filters_in,filters_out,group_width,stride=2,name=name+'_YBlock_0')
elif block_type == "Y":
x = YBlock(x, filters_in, filters_out, group_width, stride=2, name=name + '_YBlock_0')
for i in range(1, depth):
x = YBlock(x, filters_out, filters_out, group_width, name=f"{name}_YBlock_{i}")
elif block_type == "Z":
x=ZBlock(x,filters_in,filters_out,group_width,stride=2,name=f"{name}_ZBlock_{i}")
for i in range(1,depth):
x=YBlock(x,filters_out,filters_out,group_width,name=f"{name}_YBlock_{i}")
# TODO ZBlock
x=ZBlock(x,filters_out,filters_out,group_width,name=f"{name}_ZBlock_{i}")
return x
......@@ -477,6 +537,7 @@ def RegNetX004(model_name="regnetx004",
classes=classes,
classifier_activation=classifier_activation)
def RegNetY002(model_name="regnety002",
include_top=True,
include_preprocessing=True,
......@@ -486,21 +547,23 @@ def RegNetY002(model_name="regnety002",
pooling=None,
classes=1000,
classifier_activation="softmax"):
return RegNet(
MODEL_CONFIGS["y002"]["depths"],
MODEL_CONFIGS["y002"]["widths"],
MODEL_CONFIGS["y002"]["group_width"],
MODEL_CONFIGS["y002"]["block_type"],
MODEL_CONFIGS["y002"]["default_size"],
model_name=model_name,
include_top=include_top,
include_preprocessing=include_preprocessing,
weights=weights,
input_tensor=input_tensor,
input_shape=input_shape,
pooling=pooling,
classes=classes,
classifier_activation=classifier_activation)
return RegNet(
MODEL_CONFIGS["y002"]["depths"],
MODEL_CONFIGS["y002"]["widths"],
MODEL_CONFIGS["y002"]["group_width"],
MODEL_CONFIGS["y002"]["block_type"],
MODEL_CONFIGS["y002"]["default_size"],
model_name=model_name,
include_top=include_top,
include_preprocessing=include_preprocessing,
weights=weights,
input_tensor=input_tensor,
input_shape=input_shape,
pooling=pooling,
classes=classes,
classifier_activation=classifier_activation)
if __name__ == '__main__':
# model = RegNetX002(input_shape=(224, 224, 3))
# model.summary()
......@@ -508,5 +571,5 @@ if __name__ == '__main__':
# model1=RegNetX004(input_shape=(224,224,3))
# model1.summary()
model2=RegNetY002(input_shape=(224,224,3))
model2 = RegNetY002(input_shape=(224, 224, 3))
model2.summary()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册