提交 69ff4ed9 编写于 作者: 别团等shy哥发育's avatar 别团等shy哥发育

RegNet中的ZBlock实现。

上级 9f1687df
...@@ -198,10 +198,10 @@ def Stem(x, name=None): ...@@ -198,10 +198,10 @@ def Stem(x, name=None):
# SE注意力机制模块 # SE注意力机制模块
def SqueezeAndExciteBlock(inputs, filters_in, se_filters, name=None): def SqueezeAndExciteBlock(inputs, filters_in, se_filters, name=None):
# 获得通道数 # 获得通道数
channel=inputs.shape[-1] channel = inputs.shape[-1]
x = layers.GlobalAveragePooling2D(name=name + '_squeeze_and_excite_gap')(inputs) x = layers.GlobalAveragePooling2D(name=name + '_squeeze_and_excite_gap')(inputs)
x=layers.Reshape((1,1,channel))(x) x = layers.Reshape((1, 1, channel))(x)
# 两个全连接层(目前看到的所有源码都是使用两个1x1卷积层代替) # 两个全连接层(目前看到的所有源码都是使用两个1x1卷积层代替)
x = layers.Conv2D(filters=se_filters, x = layers.Conv2D(filters=se_filters,
...@@ -330,6 +330,63 @@ def YBlock(inputs, ...@@ -330,6 +330,63 @@ def YBlock(inputs,
x = layers.ReLU(name=name + "_exit_relu")(x + skip) x = layers.ReLU(name=name + "_exit_relu")(x + skip)
return x return x
# ZBlock实现:
def ZBlock(inputs,
filters_in,
filters_out,
group_width,
stride=1,
squeeze_excite_ratio=0.25, # SE中Squeeze 和 Excite 块的膨胀率
bottleneck_ratio=0.25, # bottlenect比率,用来改变网络宽度(通道数)
name=None):
groups = filters_out // group_width
se_filters = int(filters_in * squeeze_excite_ratio)
inv_btlneck_filters = int(filters_out / bottleneck_ratio)
# Build block
# conv_1x1_1
x = layers.Conv2D(inv_btlneck_filters,
(1, 1),
use_bias=False,
kernel_initializer='he_normal',
name=name + '_conv_1x1_1')(inputs)
x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name=name + '_conv_1x1_1_bn')(x)
# 使用SiLU代替ReLU
x = tf.nn.silu(x)
# conv_3x3
x = layers.Conv2D(inv_btlneck_filters,
(3, 3),
use_bias=False,
groups=groups,
padding='same',
kernel_initializer='he_normal',
name=name + '_conv_3x3')(x)
x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name=name + '_conv_3x3_bn')(x)
x = tf.nn.silu(x)
# squeeze-Excitation block
x = SqueezeAndExciteBlock(x, inv_btlneck_filters, se_filters, name=name)
# conv_1x1_2
x = layers.Conv2D(filters_out,
(1, 1),
use_bias=False,
kernel_initializer='he_normal',
name=name + '_conv_1x1_2')(x)
x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name=name + '_conv_1x1_2_bn')(x)
# 最后的1x1卷积后面没有非线性
# stride=2的块没有残差边
if stride != 1:
return x
else:
x = layers.Add([x, inputs])
return x
def Stage(inputs, def Stage(inputs,
block_type, # 必须是X、Y、Z之一 block_type, # 必须是X、Y、Z之一
depth, # stage深度,要使用的块数 depth, # stage深度,要使用的块数
...@@ -339,7 +396,7 @@ def Stage(inputs, ...@@ -339,7 +396,7 @@ def Stage(inputs,
name=None): # 名称前缀 name=None): # 名称前缀
x = inputs x = inputs
if block_type == "X": if block_type == "X":
# 论文原话:Stage的第一个block的步长为2 # 论文原话:Stage的第一个block的步长为2,其他默认为1
x = XBlock( x = XBlock(
x, x,
filters_in, filters_in,
...@@ -349,11 +406,14 @@ def Stage(inputs, ...@@ -349,11 +406,14 @@ def Stage(inputs,
name=f"{name}_XBlock_0") name=f"{name}_XBlock_0")
for i in range(1, depth): for i in range(1, depth):
x = XBlock(x, filters_out, filters_out, group_width, name=f"{name}_XBlock_{i}") x = XBlock(x, filters_out, filters_out, group_width, name=f"{name}_XBlock_{i}")
elif block_type== "Y": elif block_type == "Y":
x=YBlock(x,filters_in,filters_out,group_width,stride=2,name=name+'_YBlock_0') x = YBlock(x, filters_in, filters_out, group_width, stride=2, name=name + '_YBlock_0')
for i in range(1, depth):
x = YBlock(x, filters_out, filters_out, group_width, name=f"{name}_YBlock_{i}")
elif block_type == "Z":
x=ZBlock(x,filters_in,filters_out,group_width,stride=2,name=f"{name}_ZBlock_{i}")
for i in range(1,depth): for i in range(1,depth):
x=YBlock(x,filters_out,filters_out,group_width,name=f"{name}_YBlock_{i}") x=ZBlock(x,filters_out,filters_out,group_width,name=f"{name}_ZBlock_{i}")
# TODO ZBlock
return x return x
...@@ -477,6 +537,7 @@ def RegNetX004(model_name="regnetx004", ...@@ -477,6 +537,7 @@ def RegNetX004(model_name="regnetx004",
classes=classes, classes=classes,
classifier_activation=classifier_activation) classifier_activation=classifier_activation)
def RegNetY002(model_name="regnety002", def RegNetY002(model_name="regnety002",
include_top=True, include_top=True,
include_preprocessing=True, include_preprocessing=True,
...@@ -486,21 +547,23 @@ def RegNetY002(model_name="regnety002", ...@@ -486,21 +547,23 @@ def RegNetY002(model_name="regnety002",
pooling=None, pooling=None,
classes=1000, classes=1000,
classifier_activation="softmax"): classifier_activation="softmax"):
return RegNet( return RegNet(
MODEL_CONFIGS["y002"]["depths"], MODEL_CONFIGS["y002"]["depths"],
MODEL_CONFIGS["y002"]["widths"], MODEL_CONFIGS["y002"]["widths"],
MODEL_CONFIGS["y002"]["group_width"], MODEL_CONFIGS["y002"]["group_width"],
MODEL_CONFIGS["y002"]["block_type"], MODEL_CONFIGS["y002"]["block_type"],
MODEL_CONFIGS["y002"]["default_size"], MODEL_CONFIGS["y002"]["default_size"],
model_name=model_name, model_name=model_name,
include_top=include_top, include_top=include_top,
include_preprocessing=include_preprocessing, include_preprocessing=include_preprocessing,
weights=weights, weights=weights,
input_tensor=input_tensor, input_tensor=input_tensor,
input_shape=input_shape, input_shape=input_shape,
pooling=pooling, pooling=pooling,
classes=classes, classes=classes,
classifier_activation=classifier_activation) classifier_activation=classifier_activation)
if __name__ == '__main__': if __name__ == '__main__':
# model = RegNetX002(input_shape=(224, 224, 3)) # model = RegNetX002(input_shape=(224, 224, 3))
# model.summary() # model.summary()
...@@ -508,5 +571,5 @@ if __name__ == '__main__': ...@@ -508,5 +571,5 @@ if __name__ == '__main__':
# model1=RegNetX004(input_shape=(224,224,3)) # model1=RegNetX004(input_shape=(224,224,3))
# model1.summary() # model1.summary()
model2=RegNetY002(input_shape=(224,224,3)) model2 = RegNetY002(input_shape=(224, 224, 3))
model2.summary() model2.summary()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册