提交 fb631b38 编写于 作者: C cuicheng01

update esnet.py

上级 343c7d8f
......@@ -18,14 +18,13 @@ import paddle
from paddle import ParamAttr, reshape, transpose, concat, split
import paddle.nn as nn
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D
from paddle.nn.initializer import KaimingNormal
from paddle.regularizer import L2Decay
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
"ESNet_x0_25":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ESNet_x0_25_pretrained.pdparams",
......@@ -60,14 +59,13 @@ def make_divisible(v, divisor=8, min_value=None):
class ConvBNLayer(TheseusLayer):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
if_act=True):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
if_act=True):
super().__init__()
self.conv = Conv2D(
in_channels=in_channels,
......@@ -93,7 +91,7 @@ class ConvBNLayer(TheseusLayer):
x = self.hardswish(x)
return x
class SEModule(TheseusLayer):
def __init__(self, channel, reduction=4):
super().__init__()
......@@ -121,13 +119,11 @@ class SEModule(TheseusLayer):
x = self.conv2(x)
x = self.hardsigmoid(x)
x = paddle.multiply(x=identity, y=x)
return x
return x
class ESBlock1(TheseusLayer):
def __init__(self,
in_channels,
out_channels):
def __init__(self, in_channels, out_channels):
super().__init__()
self.pw_1_1 = ConvBNLayer(
in_channels=in_channels // 2,
......@@ -151,9 +147,7 @@ class ESBlock1(TheseusLayer):
def forward(self, x):
x1, x2 = split(
x,
num_or_sections=[x.shape[1] // 2, x.shape[1] // 2],
axis=1)
x, num_or_sections=[x.shape[1] // 2, x.shape[1] // 2], axis=1)
x2 = self.pw_1_1(x2)
x3 = self.dw_1(x2)
x3 = concat([x2, x3], axis=1)
......@@ -164,9 +158,7 @@ class ESBlock1(TheseusLayer):
class ESBlock2(TheseusLayer):
def __init__(self,
in_channels,
out_channels):
def __init__(self, in_channels, out_channels):
super().__init__()
# branch1
......@@ -205,9 +197,7 @@ class ESBlock2(TheseusLayer):
kernel_size=3,
groups=out_channels)
self.concat_pw = ConvBNLayer(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=1)
in_channels=out_channels, out_channels=out_channels, kernel_size=1)
def forward(self, x):
x1 = self.dw_1(x)
......@@ -223,13 +213,20 @@ class ESBlock2(TheseusLayer):
class ESNet(TheseusLayer):
def __init__(self, class_num=1000, scale=1.0, dropout_prob=0.2, class_expand=1280):
def __init__(self,
class_num=1000,
scale=1.0,
dropout_prob=0.2,
class_expand=1280):
super().__init__()
self.scale = scale
self.class_num = class_num
self.class_expand = class_expand
stage_repeats = [3, 7, 3]
stage_out_channels = [-1, 24, make_divisible(116*scale), make_divisible(232*scale), make_divisible(464*scale), 1024]
stage_out_channels = [
-1, 24, make_divisible(116 * scale), make_divisible(232 * scale),
make_divisible(464 * scale), 1024
]
self.conv1 = ConvBNLayer(
in_channels=3,
......@@ -240,18 +237,18 @@ class ESNet(TheseusLayer):
block_list = []
for stage_id, num_repeat in enumerate(stage_repeats):
for i in range(num_repeat):
for i in range(num_repeat):
if i == 0:
block = ESBlock2(
in_channels=stage_out_channels[stage_id + 1],
out_channels=stage_out_channels[stage_id + 2])
in_channels=stage_out_channels[stage_id + 1],
out_channels=stage_out_channels[stage_id + 2])
else:
block = ESBlock1(
in_channels=stage_out_channels[stage_id + 2],
out_channels=stage_out_channels[stage_id + 2])
in_channels=stage_out_channels[stage_id + 2],
out_channels=stage_out_channels[stage_id + 2])
block_list.append(block)
self.blocks = nn.Sequential(*block_list)
self.conv2 = ConvBNLayer(
in_channels=stage_out_channels[-2],
out_channels=stage_out_channels[-1],
......@@ -282,9 +279,9 @@ class ESNet(TheseusLayer):
x = self.dropout(x)
x = self.flatten(x)
x = self.fc(x)
return x
return x
def _load_pretrained(pretrained, model, model_url, use_ssld):
if pretrained is False:
pass
......@@ -295,8 +292,8 @@ def _load_pretrained(pretrained, model, model_url, use_ssld):
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type."
)
)
def ESNet_x0_25(pretrained=False, use_ssld=False, **kwargs):
"""
......@@ -310,9 +307,9 @@ def ESNet_x0_25(pretrained=False, use_ssld=False, **kwargs):
"""
model = ESNet(scale=0.25, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ESNet_x0_25"], use_ssld)
return model
return model
def ESNet_x0_5(pretrained=False, use_ssld=False, **kwargs):
"""
ESNet_x0_5
......@@ -356,4 +353,3 @@ def ESNet_x1_0(pretrained=False, use_ssld=False, **kwargs):
model = ESNet(scale=1.0, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ESNet_x1_0"], use_ssld)
return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册