未验证 提交 a40d348e 编写于 作者: L littletomatodonkey 提交者: GitHub

Update mobilenet_v3.py

上级 9ebbd78b
...@@ -21,7 +21,7 @@ import paddle ...@@ -21,7 +21,7 @@ import paddle
from paddle import ParamAttr from paddle import ParamAttr
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.nn.functional.activation import hard_sigmoid, hard_swish from paddle.nn.functional import hardswish, hardsigmoid
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.regularizer import L2Decay from paddle.regularizer import L2Decay
...@@ -64,15 +64,15 @@ class MobileNetV3(nn.Layer): ...@@ -64,15 +64,15 @@ class MobileNetV3(nn.Layer):
[5, 72, 40, True, "relu", 2], [5, 72, 40, True, "relu", 2],
[5, 120, 40, True, "relu", 1], [5, 120, 40, True, "relu", 1],
[5, 120, 40, True, "relu", 1], [5, 120, 40, True, "relu", 1],
[3, 240, 80, False, "hard_swish", 2], [3, 240, 80, False, "hardswish", 2],
[3, 200, 80, False, "hard_swish", 1], [3, 200, 80, False, "hardswish", 1],
[3, 184, 80, False, "hard_swish", 1], [3, 184, 80, False, "hardswish", 1],
[3, 184, 80, False, "hard_swish", 1], [3, 184, 80, False, "hardswish", 1],
[3, 480, 112, True, "hard_swish", 1], [3, 480, 112, True, "hardswish", 1],
[3, 672, 112, True, "hard_swish", 1], [3, 672, 112, True, "hardswish", 1],
[5, 672, 160, True, "hard_swish", 2], [5, 672, 160, True, "hardswish", 2],
[5, 960, 160, True, "hard_swish", 1], [5, 960, 160, True, "hardswish", 1],
[5, 960, 160, True, "hard_swish", 1], [5, 960, 160, True, "hardswish", 1],
] ]
self.cls_ch_squeeze = 960 self.cls_ch_squeeze = 960
self.cls_ch_expand = 1280 self.cls_ch_expand = 1280
...@@ -82,14 +82,14 @@ class MobileNetV3(nn.Layer): ...@@ -82,14 +82,14 @@ class MobileNetV3(nn.Layer):
[3, 16, 16, True, "relu", 2], [3, 16, 16, True, "relu", 2],
[3, 72, 24, False, "relu", 2], [3, 72, 24, False, "relu", 2],
[3, 88, 24, False, "relu", 1], [3, 88, 24, False, "relu", 1],
[5, 96, 40, True, "hard_swish", 2], [5, 96, 40, True, "hardswish", 2],
[5, 240, 40, True, "hard_swish", 1], [5, 240, 40, True, "hardswish", 1],
[5, 240, 40, True, "hard_swish", 1], [5, 240, 40, True, "hardswish", 1],
[5, 120, 48, True, "hard_swish", 1], [5, 120, 48, True, "hardswish", 1],
[5, 144, 48, True, "hard_swish", 1], [5, 144, 48, True, "hardswish", 1],
[5, 288, 96, True, "hard_swish", 2], [5, 288, 96, True, "hardswish", 2],
[5, 576, 96, True, "hard_swish", 1], [5, 576, 96, True, "hardswish", 1],
[5, 576, 96, True, "hard_swish", 1], [5, 576, 96, True, "hardswish", 1],
] ]
self.cls_ch_squeeze = 576 self.cls_ch_squeeze = 576
self.cls_ch_expand = 1280 self.cls_ch_expand = 1280
...@@ -105,7 +105,7 @@ class MobileNetV3(nn.Layer): ...@@ -105,7 +105,7 @@ class MobileNetV3(nn.Layer):
padding=1, padding=1,
num_groups=1, num_groups=1,
if_act=True, if_act=True,
act="hard_swish", act="hardswish",
name="conv1") name="conv1")
self.block_list = [] self.block_list = []
...@@ -135,7 +135,7 @@ class MobileNetV3(nn.Layer): ...@@ -135,7 +135,7 @@ class MobileNetV3(nn.Layer):
padding=0, padding=0,
num_groups=1, num_groups=1,
if_act=True, if_act=True,
act="hard_swish", act="hardswish",
name="conv_last") name="conv_last")
self.pool = AdaptiveAvgPool2D(1) self.pool = AdaptiveAvgPool2D(1)
...@@ -167,9 +167,9 @@ class MobileNetV3(nn.Layer): ...@@ -167,9 +167,9 @@ class MobileNetV3(nn.Layer):
x = self.pool(x) x = self.pool(x)
x = self.last_conv(x) x = self.last_conv(x)
x = hard_swish(x) x = hardswish(x)
x = self.dropout(x) x = self.dropout(x)
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]]) x = paddle.flatten(x, start_axis=1, stop_axis=-1)
x = self.out(x) x = self.out(x)
return x return x
...@@ -215,8 +215,8 @@ class ConvBNLayer(nn.Layer): ...@@ -215,8 +215,8 @@ class ConvBNLayer(nn.Layer):
if self.if_act: if self.if_act:
if self.act == "relu": if self.act == "relu":
x = F.relu(x) x = F.relu(x)
elif self.act == "hard_swish": elif self.act == "hardswish":
x = hard_swish(x) x = hardswish(x)
else: else:
print("The activation function is selected incorrectly.") print("The activation function is selected incorrectly.")
exit() exit()
...@@ -305,7 +305,7 @@ class SEModule(nn.Layer): ...@@ -305,7 +305,7 @@ class SEModule(nn.Layer):
outputs = self.conv1(outputs) outputs = self.conv1(outputs)
outputs = F.relu(outputs) outputs = F.relu(outputs)
outputs = self.conv2(outputs) outputs = self.conv2(outputs)
outputs = hard_sigmoid(outputs) outputs = hardsigmoid(outputs, slope=0.2, offset=0.5)
return paddle.multiply(x=inputs, y=outputs) return paddle.multiply(x=inputs, y=outputs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册