Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
8cf54a47
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8cf54a47
编写于
9月 23, 2020
作者:
C
chenguowei01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update xception_deeplab and mobilenetv3 to 2.0 beta
上级
f7e5320e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
139 addition
and
189 deletion
+139
-189
dygraph/configs/deeplabv3p/deeplabv3p_mobilenetv3_large_cityscapes_769x769_160k.yml
.../deeplabv3p_mobilenetv3_large_cityscapes_769x769_160k.yml
+13
-0
dygraph/configs/deeplabv3p/deeplabv3p_xception65_cityscapes_769x769_160k.yml
...plabv3p/deeplabv3p_xception65_cityscapes_769x769_160k.yml
+13
-0
dygraph/paddleseg/models/backbones/mobilenetv3.py
dygraph/paddleseg/models/backbones/mobilenetv3.py
+80
-111
dygraph/paddleseg/models/backbones/xception_deeplab.py
dygraph/paddleseg/models/backbones/xception_deeplab.py
+33
-78
未找到文件。
dygraph/configs/deeplabv3p/deeplabv3p_mobilenetv3_large_cityscapes_769x769_160k.yml
0 → 100644
浏览文件 @
8cf54a47
_base_
:
'
../_base_/cityscapes.yml'
model
:
type
:
DeepLabV3
backbone
:
type
:
MobileNetV3_small_x1_0
pretrained
:
Null
num_classes
:
19
pretrained
:
Null
backbone_indices
:
[
0
,
3
]
optimizer
:
weight_decay
:
0.00004
dygraph/configs/deeplabv3p/deeplabv3p_xception65_cityscapes_769x769_160k.yml
0 → 100644
浏览文件 @
8cf54a47
_base_
:
'
../_base_/cityscapes.yml'
model
:
type
:
DeepLabV3
backbone
:
type
:
Xception65_deeplab
pretrained
:
Null
num_classes
:
19
pretrained
:
Null
backbone_indices
:
[
0
,
1
]
optimizer
:
weight_decay
:
0.00004
dygraph/paddleseg/models/backbones/mobilenetv3.py
浏览文件 @
8cf54a47
...
@@ -21,13 +21,14 @@ import os
...
@@ -21,13 +21,14 @@ import os
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.nn
as
nn
from
paddle.fluid.param_attr
import
ParamAttr
import
paddle.nn.functional
as
F
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.nn
import
Conv2d
,
AdaptiveAvgPool2d
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
,
Dropout
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
paddle.regularizer
import
L2Decay
from
paddle
import
ParamAttr
from
paddleseg.models.common
import
layer_libs
from
paddleseg.models.common
import
layer_libs
,
activation
from
paddleseg.cvlibs
import
manager
from
paddleseg.cvlibs
import
manager
from
paddleseg.utils
import
utils
from
paddleseg.utils
import
utils
...
@@ -71,9 +72,9 @@ def get_padding_same(kernel_size, dilation_rate):
...
@@ -71,9 +72,9 @@ def get_padding_same(kernel_size, dilation_rate):
return
padding_same
return
padding_same
class
MobileNetV3
(
fluid
.
dygraph
.
Layer
):
class
MobileNetV3
(
nn
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
backbone_
pretrained
=
None
,
pretrained
=
None
,
scale
=
1.0
,
scale
=
1.0
,
model_name
=
"small"
,
model_name
=
"small"
,
class_dim
=
1000
,
class_dim
=
1000
,
...
@@ -103,6 +104,9 @@ class MobileNetV3(fluid.dygraph.Layer):
...
@@ -103,6 +104,9 @@ class MobileNetV3(fluid.dygraph.Layer):
1
],
# output 3 -> out_index=14
1
],
# output 3 -> out_index=14
]
]
self
.
out_indices
=
[
2
,
5
,
11
,
14
]
self
.
out_indices
=
[
2
,
5
,
11
,
14
]
self
.
feat_channels
=
[
make_divisible
(
i
*
scale
)
for
i
in
[
24
,
40
,
112
,
160
]
]
self
.
cls_ch_squeeze
=
960
self
.
cls_ch_squeeze
=
960
self
.
cls_ch_expand
=
1280
self
.
cls_ch_expand
=
1280
...
@@ -122,6 +126,9 @@ class MobileNetV3(fluid.dygraph.Layer):
...
@@ -122,6 +126,9 @@ class MobileNetV3(fluid.dygraph.Layer):
[
5
,
576
,
96
,
True
,
"hard_swish"
,
1
],
# output 4 -> out_index=10
[
5
,
576
,
96
,
True
,
"hard_swish"
,
1
],
# output 4 -> out_index=10
]
]
self
.
out_indices
=
[
0
,
3
,
7
,
10
]
self
.
out_indices
=
[
0
,
3
,
7
,
10
]
self
.
feat_channels
=
[
make_divisible
(
i
*
scale
)
for
i
in
[
16
,
24
,
48
,
96
]
]
self
.
cls_ch_squeeze
=
576
self
.
cls_ch_squeeze
=
576
self
.
cls_ch_expand
=
1280
self
.
cls_ch_expand
=
1280
...
@@ -169,37 +176,33 @@ class MobileNetV3(fluid.dygraph.Layer):
...
@@ -169,37 +176,33 @@ class MobileNetV3(fluid.dygraph.Layer):
sublayer
=
self
.
block_list
[
-
1
],
name
=
"conv"
+
str
(
i
+
2
))
sublayer
=
self
.
block_list
[
-
1
],
name
=
"conv"
+
str
(
i
+
2
))
inplanes
=
make_divisible
(
scale
*
c
)
inplanes
=
make_divisible
(
scale
*
c
)
self
.
last_second_conv
=
ConvBNLayer
(
# self.last_second_conv = ConvBNLayer(
in_c
=
inplanes
,
# in_c=inplanes,
out_c
=
make_divisible
(
scale
*
self
.
cls_ch_squeeze
),
# out_c=make_divisible(scale * self.cls_ch_squeeze),
filter_size
=
1
,
# filter_size=1,
stride
=
1
,
# stride=1,
padding
=
0
,
# padding=0,
num_groups
=
1
,
# num_groups=1,
if_act
=
True
,
# if_act=True,
act
=
"hard_swish"
,
# act="hard_swish",
name
=
"conv_last"
)
# name="conv_last")
self
.
pool
=
Pool2D
(
# self.pool = Pool2D(
pool_type
=
"avg"
,
global_pooling
=
True
,
use_cudnn
=
False
)
# pool_type="avg", global_pooling=True, use_cudnn=False)
self
.
last_conv
=
Conv2D
(
# self.last_conv = Conv2d(
num_channels
=
make_divisible
(
scale
*
self
.
cls_ch_squeeze
),
# in_channels=make_divisible(scale * self.cls_ch_squeeze),
num_filters
=
self
.
cls_ch_expand
,
# out_channels=self.cls_ch_expand,
filter_size
=
1
,
# kernel_size=1,
stride
=
1
,
# stride=1,
padding
=
0
,
# padding=0,
act
=
None
,
# bias_attr=False)
param_attr
=
ParamAttr
(
name
=
"last_1x1_conv_weights"
),
bias_attr
=
False
)
# self.out = Linear(
# input_dim=self.cls_ch_expand,
self
.
out
=
Linear
(
# output_dim=class_dim)
input_dim
=
self
.
cls_ch_expand
,
output_dim
=
class_dim
,
utils
.
load_pretrained_model
(
self
,
pretrained
)
param_attr
=
ParamAttr
(
"fc_weights"
),
bias_attr
=
ParamAttr
(
name
=
"fc_offset"
))
self
.
init_weight
(
backbone_pretrained
)
def
modify_bottle_params
(
self
,
output_stride
=
None
):
def
modify_bottle_params
(
self
,
output_stride
=
None
):
...
@@ -216,7 +219,7 @@ class MobileNetV3(fluid.dygraph.Layer):
...
@@ -216,7 +219,7 @@ class MobileNetV3(fluid.dygraph.Layer):
self
.
dilation_cfg
[
i
]
=
rate
self
.
dilation_cfg
[
i
]
=
rate
def
forward
(
self
,
inputs
,
label
=
None
,
dropout_prob
=
0.2
):
def
forward
(
self
,
inputs
,
label
=
None
):
x
=
self
.
conv1
(
inputs
)
x
=
self
.
conv1
(
inputs
)
# A feature list saves each downsampling feature.
# A feature list saves each downsampling feature.
feat_list
=
[]
feat_list
=
[]
...
@@ -225,31 +228,18 @@ class MobileNetV3(fluid.dygraph.Layer):
...
@@ -225,31 +228,18 @@ class MobileNetV3(fluid.dygraph.Layer):
if
i
in
self
.
out_indices
:
if
i
in
self
.
out_indices
:
feat_list
.
append
(
x
)
feat_list
.
append
(
x
)
#print("block {}:".format(i),x.shape, self.dilation_cfg[i])
#print("block {}:".format(i),x.shape, self.dilation_cfg[i])
x
=
self
.
last_second_conv
(
x
)
# x = self.last_second_conv(x)
x
=
self
.
pool
(
x
)
# x = self.pool(x)
x
=
self
.
last_conv
(
x
)
# x = self.last_conv(x)
x
=
fluid
.
layers
.
hard_swish
(
x
)
# x = F.hard_swish(x)
x
=
fluid
.
layers
.
dropout
(
x
=
x
,
dropout_prob
=
dropout_prob
)
# x = F.dropout(x=x, dropout_prob=dropout_prob)
x
=
fluid
.
layers
.
reshape
(
x
,
shape
=
[
x
.
shape
[
0
],
x
.
shape
[
1
]])
# x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
x
=
self
.
out
(
x
)
# x = self.out(x)
return
x
,
feat_list
return
feat_list
def
init_weight
(
self
,
pretrained_model
=
None
):
"""
class
ConvBNLayer
(
nn
.
Layer
):
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
"""
if
pretrained_model
is
not
None
:
if
os
.
path
.
exists
(
pretrained_model
):
utils
.
load_pretrained_model
(
self
,
pretrained_model
)
else
:
raise
Exception
(
'Pretrained model is not found: {}'
.
format
(
pretrained_model
))
class
ConvBNLayer
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
in_c
,
in_c
,
out_c
,
out_c
,
...
@@ -266,46 +256,31 @@ class ConvBNLayer(fluid.dygraph.Layer):
...
@@ -266,46 +256,31 @@ class ConvBNLayer(fluid.dygraph.Layer):
self
.
if_act
=
if_act
self
.
if_act
=
if_act
self
.
act
=
act
self
.
act
=
act
self
.
conv
=
fluid
.
dygraph
.
Conv2D
(
self
.
conv
=
Conv2d
(
num
_channels
=
in_c
,
in
_channels
=
in_c
,
num_filter
s
=
out_c
,
out_channel
s
=
out_c
,
filter
_size
=
filter_size
,
kernel
_size
=
filter_size
,
stride
=
stride
,
stride
=
stride
,
padding
=
padding
,
padding
=
padding
,
dilation
=
dilation
,
dilation
=
dilation
,
groups
=
num_groups
,
groups
=
num_groups
,
param_attr
=
ParamAttr
(
name
=
name
+
"_weights"
),
bias_attr
=
False
)
bias_attr
=
False
,
use_cudnn
=
use_cudnn
,
act
=
None
)
self
.
bn
=
BatchNorm
(
self
.
bn
=
BatchNorm
(
num_features
=
out_c
,
num_features
=
out_c
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
name
=
name
+
"_bn_scale"
,
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)))
regularizer
=
fluid
.
regularizer
.
L2DecayRegularizer
(
regularization_coeff
=
0.0
)),
bias_attr
=
ParamAttr
(
name
=
name
+
"_bn_offset"
,
regularizer
=
fluid
.
regularizer
.
L2DecayRegularizer
(
regularization_coeff
=
0.0
)))
self
.
_act_op
=
layer_utils
.
Activation
(
act
=
None
)
self
.
_act_op
=
activation
.
Activation
(
act
=
None
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
x
=
self
.
bn
(
x
)
if
self
.
if_act
:
if
self
.
if_act
:
if
self
.
act
==
"relu"
:
x
=
self
.
_act_op
(
x
)
x
=
fluid
.
layers
.
relu
(
x
)
elif
self
.
act
==
"hard_swish"
:
x
=
fluid
.
layers
.
hard_swish
(
x
)
else
:
print
(
"The activation function is selected incorrectly."
)
exit
()
return
x
return
x
class
ResidualUnit
(
fluid
.
dygraph
.
Layer
):
class
ResidualUnit
(
nn
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
in_c
,
in_c
,
mid_c
,
mid_c
,
...
@@ -363,40 +338,34 @@ class ResidualUnit(fluid.dygraph.Layer):
...
@@ -363,40 +338,34 @@ class ResidualUnit(fluid.dygraph.Layer):
x
=
self
.
mid_se
(
x
)
x
=
self
.
mid_se
(
x
)
x
=
self
.
linear_conv
(
x
)
x
=
self
.
linear_conv
(
x
)
if
self
.
if_shortcut
:
if
self
.
if_shortcut
:
x
=
fluid
.
layers
.
elementwise_add
(
inputs
,
x
)
x
=
inputs
+
x
return
x
return
x
class
SEModule
(
fluid
.
dygraph
.
Layer
):
class
SEModule
(
nn
.
Layer
):
def
__init__
(
self
,
channel
,
reduction
=
4
,
name
=
""
):
def
__init__
(
self
,
channel
,
reduction
=
4
,
name
=
""
):
super
(
SEModule
,
self
).
__init__
()
super
(
SEModule
,
self
).
__init__
()
self
.
avg_pool
=
fluid
.
dygraph
.
Pool2D
(
self
.
avg_pool
=
AdaptiveAvgPool2d
(
1
)
pool_type
=
"avg"
,
global_pooling
=
True
,
use_cudnn
=
False
)
self
.
conv1
=
Conv2d
(
self
.
conv1
=
fluid
.
dygraph
.
Conv2D
(
in_channels
=
channel
,
num_channels
=
channel
,
out_channels
=
channel
//
reduction
,
num_filters
=
channel
//
reduction
,
kernel_size
=
1
,
filter_size
=
1
,
stride
=
1
,
stride
=
1
,
padding
=
0
,
padding
=
0
)
act
=
"relu"
,
self
.
conv2
=
Conv2d
(
param_attr
=
ParamAttr
(
name
=
name
+
"_1_weights"
),
in_channels
=
channel
//
reduction
,
bias_attr
=
ParamAttr
(
name
=
name
+
"_1_offset"
))
out_channels
=
channel
,
self
.
conv2
=
fluid
.
dygraph
.
Conv2D
(
kernel_size
=
1
,
num_channels
=
channel
//
reduction
,
num_filters
=
channel
,
filter_size
=
1
,
stride
=
1
,
stride
=
1
,
padding
=
0
,
padding
=
0
)
act
=
None
,
param_attr
=
ParamAttr
(
name
+
"_2_weights"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"_2_offset"
))
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
outputs
=
self
.
avg_pool
(
inputs
)
outputs
=
self
.
avg_pool
(
inputs
)
outputs
=
self
.
conv1
(
outputs
)
outputs
=
self
.
conv1
(
outputs
)
outputs
=
F
.
relu
(
outputs
)
outputs
=
self
.
conv2
(
outputs
)
outputs
=
self
.
conv2
(
outputs
)
outputs
=
fluid
.
layers
.
hard_sigmoid
(
outputs
)
outputs
=
F
.
hard_sigmoid
(
outputs
)
return
fluid
.
layers
.
elementwise_mul
(
x
=
inputs
,
y
=
outputs
,
axis
=
0
)
return
paddle
.
multiply
(
x
=
inputs
,
y
=
outputs
,
axis
=
0
)
def
MobileNetV3_small_x0_35
(
**
kwargs
):
def
MobileNetV3_small_x0_35
(
**
kwargs
):
...
...
dygraph/paddleseg/models/backbones/xception_deeplab.py
浏览文件 @
8cf54a47
...
@@ -15,13 +15,12 @@
...
@@ -15,13 +15,12 @@
import
os
import
os
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.nn
as
nn
from
paddle.fluid.param_attr
import
ParamAttr
import
paddle.nn.functional
as
F
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.nn
import
Conv2d
,
Linear
,
Dropout
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
,
Dropout
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
paddleseg.models.common
import
layer_libs
from
paddleseg.models.common
import
layer_libs
,
activation
from
paddleseg.cvlibs
import
manager
from
paddleseg.cvlibs
import
manager
from
paddleseg.utils
import
utils
from
paddleseg.utils
import
utils
...
@@ -78,7 +77,7 @@ def gen_bottleneck_params(backbone='xception_65'):
...
@@ -78,7 +77,7 @@ def gen_bottleneck_params(backbone='xception_65'):
return
bottleneck_params
return
bottleneck_params
class
ConvBNLayer
(
fluid
.
dygraph
.
Layer
):
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
input_channels
,
input_channels
,
output_channels
,
output_channels
,
...
@@ -89,29 +88,24 @@ class ConvBNLayer(fluid.dygraph.Layer):
...
@@ -89,29 +88,24 @@ class ConvBNLayer(fluid.dygraph.Layer):
name
=
None
):
name
=
None
):
super
(
ConvBNLayer
,
self
).
__init__
()
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
_conv
=
Conv2
D
(
self
.
_conv
=
Conv2
d
(
num
_channels
=
input_channels
,
in
_channels
=
input_channels
,
num_filter
s
=
output_channels
,
out_channel
s
=
output_channels
,
filter
_size
=
filter_size
,
kernel
_size
=
filter_size
,
stride
=
stride
,
stride
=
stride
,
padding
=
padding
,
padding
=
padding
,
param_attr
=
ParamAttr
(
name
=
name
+
"/weights"
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
_bn
=
BatchNorm
(
self
.
_bn
=
BatchNorm
(
num_features
=
output_channels
,
num_features
=
output_channels
,
epsilon
=
1e-3
,
momentum
=
0.99
)
epsilon
=
1e-3
,
momentum
=
0.99
,
weight_attr
=
ParamAttr
(
name
=
name
+
"/BatchNorm/gamma"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"/BatchNorm/beta"
))
self
.
_act_op
=
layer_utils
.
Activation
(
act
=
act
)
self
.
_act_op
=
activation
.
Activation
(
act
=
act
)
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
return
self
.
_act_op
(
self
.
_bn
(
self
.
_conv
(
inputs
)))
return
self
.
_act_op
(
self
.
_bn
(
self
.
_conv
(
inputs
)))
class
Seperate_Conv
(
fluid
.
dygraph
.
Layer
):
class
Seperate_Conv
(
nn
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
input_channels
,
input_channels
,
output_channels
,
output_channels
,
...
@@ -122,42 +116,30 @@ class Seperate_Conv(fluid.dygraph.Layer):
...
@@ -122,42 +116,30 @@ class Seperate_Conv(fluid.dygraph.Layer):
name
=
None
):
name
=
None
):
super
(
Seperate_Conv
,
self
).
__init__
()
super
(
Seperate_Conv
,
self
).
__init__
()
self
.
_conv1
=
Conv2
D
(
self
.
_conv1
=
Conv2
d
(
num
_channels
=
input_channels
,
in
_channels
=
input_channels
,
num_filter
s
=
input_channels
,
out_channel
s
=
input_channels
,
filter
_size
=
filter
,
kernel
_size
=
filter
,
stride
=
stride
,
stride
=
stride
,
groups
=
input_channels
,
groups
=
input_channels
,
padding
=
(
filter
)
//
2
*
dilation
,
padding
=
(
filter
)
//
2
*
dilation
,
dilation
=
dilation
,
dilation
=
dilation
,
param_attr
=
ParamAttr
(
name
=
name
+
"/depthwise/weights"
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
_bn1
=
BatchNorm
(
self
.
_bn1
=
BatchNorm
(
input_channels
,
epsilon
=
1e-3
,
momentum
=
0.99
)
input_channels
,
epsilon
=
1e-3
,
momentum
=
0.99
,
weight_attr
=
ParamAttr
(
name
=
name
+
"/depthwise/BatchNorm/gamma"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"/depthwise/BatchNorm/beta"
))
self
.
_act_op1
=
layer_utils
.
Activation
(
act
=
act
)
self
.
_act_op1
=
activation
.
Activation
(
act
=
act
)
self
.
_conv2
=
Conv2
D
(
self
.
_conv2
=
Conv2
d
(
input_channels
,
input_channels
,
output_channels
,
output_channels
,
1
,
1
,
stride
=
1
,
stride
=
1
,
groups
=
1
,
groups
=
1
,
padding
=
0
,
padding
=
0
,
param_attr
=
ParamAttr
(
name
=
name
+
"/pointwise/weights"
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
_bn2
=
BatchNorm
(
self
.
_bn2
=
BatchNorm
(
output_channels
,
epsilon
=
1e-3
,
momentum
=
0.99
)
output_channels
,
epsilon
=
1e-3
,
momentum
=
0.99
,
weight_attr
=
ParamAttr
(
name
=
name
+
"/pointwise/BatchNorm/gamma"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"/pointwise/BatchNorm/beta"
))
self
.
_act_op2
=
layer_utils
.
Activation
(
act
=
act
)
self
.
_act_op2
=
activation
.
Activation
(
act
=
act
)
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
x
=
self
.
_conv1
(
inputs
)
x
=
self
.
_conv1
(
inputs
)
...
@@ -169,7 +151,7 @@ class Seperate_Conv(fluid.dygraph.Layer):
...
@@ -169,7 +151,7 @@ class Seperate_Conv(fluid.dygraph.Layer):
return
x
return
x
class
Xception_Block
(
fluid
.
dygraph
.
Layer
):
class
Xception_Block
(
nn
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
input_channels
,
input_channels
,
output_channels
,
output_channels
,
...
@@ -248,13 +230,12 @@ class Xception_Block(fluid.dygraph.Layer):
...
@@ -248,13 +230,12 @@ class Xception_Block(fluid.dygraph.Layer):
name
=
name
+
"/shortcut"
)
name
=
name
+
"/shortcut"
)
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
layer_helper
=
LayerHelper
(
self
.
full_name
(),
act
=
'relu'
)
if
not
self
.
activation_fn_in_separable_conv
:
if
not
self
.
activation_fn_in_separable_conv
:
x
=
layer_helper
.
append_activation
(
inputs
)
x
=
F
.
relu
(
inputs
)
x
=
self
.
_conv1
(
x
)
x
=
self
.
_conv1
(
x
)
x
=
layer_helper
.
append_activation
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
_conv2
(
x
)
x
=
self
.
_conv2
(
x
)
x
=
layer_helper
.
append_activation
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
_conv3
(
x
)
x
=
self
.
_conv3
(
x
)
else
:
else
:
x
=
self
.
_conv1
(
inputs
)
x
=
self
.
_conv1
(
inputs
)
...
@@ -266,16 +247,16 @@ class Xception_Block(fluid.dygraph.Layer):
...
@@ -266,16 +247,16 @@ class Xception_Block(fluid.dygraph.Layer):
skip
=
self
.
_short
(
inputs
)
skip
=
self
.
_short
(
inputs
)
else
:
else
:
skip
=
inputs
skip
=
inputs
return
fluid
.
layers
.
elementwise_add
(
x
,
skip
)
return
x
+
skip
class
XceptionDeeplab
(
fluid
.
dygraph
.
Layer
):
class
XceptionDeeplab
(
nn
.
Layer
):
#def __init__(self, backbone, class_dim=1000):
#def __init__(self, backbone, class_dim=1000):
# add output_stride
# add output_stride
def
__init__
(
self
,
def
__init__
(
self
,
backbone
,
backbone
,
backbone_
pretrained
=
None
,
pretrained
=
None
,
output_stride
=
16
,
output_stride
=
16
,
class_dim
=
1000
):
class_dim
=
1000
):
...
@@ -283,6 +264,7 @@ class XceptionDeeplab(fluid.dygraph.Layer):
...
@@ -283,6 +264,7 @@ class XceptionDeeplab(fluid.dygraph.Layer):
bottleneck_params
=
gen_bottleneck_params
(
backbone
)
bottleneck_params
=
gen_bottleneck_params
(
backbone
)
self
.
backbone
=
backbone
self
.
backbone
=
backbone
self
.
feat_channels
=
[
128
,
2048
]
self
.
_conv1
=
ConvBNLayer
(
self
.
_conv1
=
ConvBNLayer
(
3
,
3
,
...
@@ -388,19 +370,8 @@ class XceptionDeeplab(fluid.dygraph.Layer):
...
@@ -388,19 +370,8 @@ class XceptionDeeplab(fluid.dygraph.Layer):
has_skip
=
False
,
has_skip
=
False
,
activation_fn_in_separable_conv
=
True
,
activation_fn_in_separable_conv
=
True
,
name
=
self
.
backbone
+
"/exit_flow/block2"
)
name
=
self
.
backbone
+
"/exit_flow/block2"
)
s
=
s
*
stride
self
.
stride
=
s
self
.
_drop
=
Dropout
(
p
=
0.5
)
self
.
_pool
=
Pool2D
(
pool_type
=
"avg"
,
global_pooling
=
True
)
self
.
_fc
=
Linear
(
self
.
chns
[
1
][
-
1
],
class_dim
,
param_attr
=
ParamAttr
(
name
=
"fc_weights"
),
bias_attr
=
ParamAttr
(
name
=
"fc_bias"
))
self
.
init_weight
(
backbone_
pretrained
)
utils
.
load_pretrained_model
(
self
,
pretrained
)
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
x
=
self
.
_conv1
(
inputs
)
x
=
self
.
_conv1
(
inputs
)
...
@@ -415,27 +386,10 @@ class XceptionDeeplab(fluid.dygraph.Layer):
...
@@ -415,27 +386,10 @@ class XceptionDeeplab(fluid.dygraph.Layer):
x
=
self
.
_exit_flow_1
(
x
)
x
=
self
.
_exit_flow_1
(
x
)
x
=
self
.
_exit_flow_2
(
x
)
x
=
self
.
_exit_flow_2
(
x
)
feat_list
.
append
(
x
)
feat_list
.
append
(
x
)
return
feat_list
x
=
self
.
_drop
(
x
)
x
=
self
.
_pool
(
x
)
x
=
fluid
.
layers
.
squeeze
(
x
,
axes
=
[
2
,
3
])
x
=
self
.
_fc
(
x
)
return
x
,
feat_list
def
init_weight
(
self
,
pretrained_model
=
None
):
"""
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
"""
if
pretrained_model
is
not
None
:
if
os
.
path
.
exists
(
pretrained_model
):
utils
.
load_pretrained_model
(
self
,
pretrained_model
)
else
:
raise
Exception
(
'Pretrained model is not found: {}'
.
format
(
pretrained_model
))
@
manager
.
BACKBONES
.
add_component
def
Xception41_deeplab
(
**
args
):
def
Xception41_deeplab
(
**
args
):
model
=
XceptionDeeplab
(
'xception_41'
,
**
args
)
model
=
XceptionDeeplab
(
'xception_41'
,
**
args
)
return
model
return
model
...
@@ -447,6 +401,7 @@ def Xception65_deeplab(**args):
...
@@ -447,6 +401,7 @@ def Xception65_deeplab(**args):
return
model
return
model
@
manager
.
BACKBONES
.
add_component
def
Xception71_deeplab
(
**
args
):
def
Xception71_deeplab
(
**
args
):
model
=
XceptionDeeplab
(
"xception_71"
,
**
args
)
model
=
XceptionDeeplab
(
"xception_71"
,
**
args
)
return
model
return
model
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录