Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
2ea481f5
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2ea481f5
编写于
8月 28, 2020
作者:
W
wuzewu
提交者:
GitHub
8月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #361 from michaelowenliu/develop
上级
44b420ae
f85f1b0f
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
427 addition
and
71 deletion
+427
-71
dygraph/cvlibs/manager.py
dygraph/cvlibs/manager.py
+3
-2
dygraph/models/__init__.py
dygraph/models/__init__.py
+1
-0
dygraph/models/architectures/layer_utils.py
dygraph/models/architectures/layer_utils.py
+55
-34
dygraph/models/architectures/mobilenetv3.py
dygraph/models/architectures/mobilenetv3.py
+11
-10
dygraph/models/architectures/resnet_vd.py
dygraph/models/architectures/resnet_vd.py
+8
-7
dygraph/models/architectures/xception_deeplab.py
dygraph/models/architectures/xception_deeplab.py
+22
-18
dygraph/models/model_utils.py
dygraph/models/model_utils.py
+102
-0
dygraph/models/pspnet.py
dygraph/models/pspnet.py
+225
-0
未找到文件。
dygraph/cvlibs/manager.py
浏览文件 @
2ea481f5
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
collections
from
collections.abc
import
Sequence
import
inspect
import
inspect
...
@@ -98,13 +98,14 @@ class ComponentManager:
...
@@ -98,13 +98,14 @@ class ComponentManager:
"""
"""
# Check whether the type is a sequence
# Check whether the type is a sequence
if
isinstance
(
components
,
collections
.
Sequence
):
if
isinstance
(
components
,
Sequence
):
for
component
in
components
:
for
component
in
components
:
self
.
_add_single_component
(
component
)
self
.
_add_single_component
(
component
)
else
:
else
:
component
=
components
component
=
components
self
.
_add_single_component
(
component
)
self
.
_add_single_component
(
component
)
return
components
MODELS
=
ComponentManager
()
MODELS
=
ComponentManager
()
BACKBONES
=
ComponentManager
()
BACKBONES
=
ComponentManager
()
\ No newline at end of file
dygraph/models/__init__.py
浏览文件 @
2ea481f5
...
@@ -16,3 +16,4 @@ from .architectures import *
...
@@ -16,3 +16,4 @@ from .architectures import *
from
.unet
import
UNet
from
.unet
import
UNet
from
.deeplab
import
*
from
.deeplab
import
*
from
.fcn
import
*
from
.fcn
import
*
from
.pspnet
import
*
dygraph/models/architectures/layer_utils.py
浏览文件 @
2ea481f5
...
@@ -13,24 +13,22 @@
...
@@ -13,24 +13,22 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
paddle.nn.functional
as
F
from
paddle
import
fluid
from
paddle
import
fluid
from
paddle.fluid
import
dygraph
from
paddle.fluid
import
dygraph
from
paddle.fluid.dygraph
import
Conv2D
from
paddle.fluid.dygraph
import
Conv2D
from
paddle.fluid.dygraph
import
SyncBatchNorm
as
BatchNorm
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
import
cv2
from
paddle.nn.layer
import
activation
import
os
import
sys
class
ConvBnRelu
(
dygraph
.
Layer
):
class
ConvBnRelu
(
dygraph
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
num_channels
,
num_channels
,
num_filters
,
num_filters
,
filter_size
,
filter_size
,
using_sep_conv
=
False
,
using_sep_conv
=
False
,
**
kwargs
):
**
kwargs
):
super
(
ConvBnRelu
,
self
).
__init__
()
super
(
ConvBnRelu
,
self
).
__init__
()
if
using_sep_conv
:
if
using_sep_conv
:
...
@@ -41,16 +39,16 @@ class ConvBnRelu(dygraph.Layer):
...
@@ -41,16 +39,16 @@ class ConvBnRelu(dygraph.Layer):
else
:
else
:
self
.
conv
=
Conv2D
(
num_channels
,
self
.
conv
=
Conv2D
(
num_channels
,
num_filters
,
num_filters
,
filter_size
,
filter_size
,
**
kwargs
)
**
kwargs
)
self
.
batch_norm
=
BatchNorm
(
num_filters
)
self
.
batch_norm
=
BatchNorm
(
num_filters
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
conv
(
x
)
x
=
self
.
batch_norm
(
x
)
x
=
self
.
batch_norm
(
x
)
x
=
fluid
.
layers
.
relu
(
x
)
x
=
F
.
relu
(
x
)
return
x
return
x
...
@@ -81,7 +79,7 @@ class ConvReluPool(dygraph.Layer):
...
@@ -81,7 +79,7 @@ class ConvReluPool(dygraph.Layer):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
conv
(
x
)
x
=
fluid
.
layers
.
relu
(
x
)
x
=
F
.
relu
(
x
)
x
=
fluid
.
layers
.
pool2d
(
x
,
pool_size
=
2
,
pool_type
=
"max"
,
pool_stride
=
2
)
x
=
fluid
.
layers
.
pool2d
(
x
,
pool_size
=
2
,
pool_type
=
"max"
,
pool_stride
=
2
)
return
x
return
x
...
@@ -106,15 +104,15 @@ class DepthwiseConvBnRelu(dygraph.Layer):
...
@@ -106,15 +104,15 @@ class DepthwiseConvBnRelu(dygraph.Layer):
**
kwargs
):
**
kwargs
):
super
(
DepthwiseConvBnRelu
,
self
).
__init__
()
super
(
DepthwiseConvBnRelu
,
self
).
__init__
()
self
.
depthwise_conv
=
ConvBn
(
num_channels
,
self
.
depthwise_conv
=
ConvBn
(
num_channels
,
num_filters
=
num_channels
,
num_filters
=
num_channels
,
filter_size
=
filter_size
,
filter_size
=
filter_size
,
groups
=
num_channels
,
groups
=
num_channels
,
use_cudnn
=
False
,
use_cudnn
=
False
,
**
kwargs
)
**
kwargs
)
self
.
piontwise_conv
=
ConvBnRelu
(
num_channels
,
self
.
piontwise_conv
=
ConvBnRelu
(
num_channels
,
num_filters
,
num_filters
,
filter_size
=
1
,
filter_size
=
1
,
groups
=
1
)
groups
=
1
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
depthwise_conv
(
x
)
x
=
self
.
depthwise_conv
(
x
)
...
@@ -122,20 +120,43 @@ class DepthwiseConvBnRelu(dygraph.Layer):
...
@@ -122,20 +120,43 @@ class DepthwiseConvBnRelu(dygraph.Layer):
return
x
return
x
def
compute_loss
(
logits
,
label
,
ignore_index
=
255
):
class
Activation
(
fluid
.
dygraph
.
Layer
):
mask
=
label
!=
ignore_index
"""
mask
=
fluid
.
layers
.
cast
(
mask
,
'float32'
)
The wrapper of activations
loss
,
probs
=
fluid
.
layers
.
softmax_with_cross_entropy
(
For example:
logits
,
>>> relu = Activation("relu")
label
,
>>> print(relu)
ignore_index
=
ignore_index
,
<class 'paddle.nn.layer.activation.ReLU'>
return_softmax
=
True
,
>>> sigmoid = Activation("sigmoid")
axis
=
1
)
>>> print(sigmoid)
<class 'paddle.nn.layer.activation.Sigmoid'>
>>> not_exit_one = Activation("not_exit_one")
KeyError: "not_exit_one does not exist in the current dict_keys(['elu', 'gelu', 'hardshrink',
'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', 'softmax',
'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', 'hsigmoid'])"
Args:
act (str): the activation name in lowercase
"""
def
__init__
(
self
,
act
=
None
):
super
(
Activation
,
self
).
__init__
()
self
.
_act
=
act
upper_act_names
=
activation
.
__all__
lower_act_names
=
[
act
.
lower
()
for
act
in
upper_act_names
]
act_dict
=
dict
(
zip
(
lower_act_names
,
upper_act_names
))
if
act
is
not
None
:
if
act
in
act_dict
.
keys
():
act_name
=
act_dict
[
act
]
self
.
act_func
=
eval
(
"activation.{}()"
.
format
(
act_name
))
else
:
raise
KeyError
(
"{} does not exist in the current {}"
.
format
(
act
,
act_dict
.
keys
()))
loss
=
loss
*
mask
def
forward
(
self
,
x
):
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
/
(
fluid
.
layers
.
mean
(
mask
)
+
1e-5
)
label
.
stop_gradient
=
True
if
self
.
_act
is
not
None
:
mask
.
stop_gradient
=
True
return
self
.
act_func
(
x
)
return
avg_loss
else
:
\ No newline at end of file
return
x
\ No newline at end of file
dygraph/models/architectures/mobilenetv3.py
浏览文件 @
2ea481f5
...
@@ -16,15 +16,17 @@ from __future__ import absolute_import
...
@@ -16,15 +16,17 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
math
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
BatchNorm
,
Linear
,
Dropout
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
,
Dropout
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
import
math
from
dygraph.models.architectures
import
layer_utils
from
dygraph.cvlibs
import
manager
from
dygraph.cvlibs
import
manager
__all__
=
[
__all__
=
[
...
@@ -251,19 +253,18 @@ class ConvBNLayer(fluid.dygraph.Layer):
...
@@ -251,19 +253,18 @@ class ConvBNLayer(fluid.dygraph.Layer):
bias_attr
=
False
,
bias_attr
=
False
,
use_cudnn
=
use_cudnn
,
use_cudnn
=
use_cudnn
,
act
=
None
)
act
=
None
)
self
.
bn
=
fluid
.
dygraph
.
BatchNorm
(
self
.
bn
=
BatchNorm
(
num_channels
=
out_c
,
num_features
=
out_c
,
act
=
None
,
weight_attr
=
ParamAttr
(
param_attr
=
ParamAttr
(
name
=
name
+
"_bn_scale"
,
name
=
name
+
"_bn_scale"
,
regularizer
=
fluid
.
regularizer
.
L2DecayRegularizer
(
regularizer
=
fluid
.
regularizer
.
L2DecayRegularizer
(
regularization_coeff
=
0.0
)),
regularization_coeff
=
0.0
)),
bias_attr
=
ParamAttr
(
bias_attr
=
ParamAttr
(
name
=
name
+
"_bn_offset"
,
name
=
name
+
"_bn_offset"
,
regularizer
=
fluid
.
regularizer
.
L2DecayRegularizer
(
regularizer
=
fluid
.
regularizer
.
L2DecayRegularizer
(
regularization_coeff
=
0.0
))
,
regularization_coeff
=
0.0
))
)
moving_mean_name
=
name
+
"_bn_mean"
,
moving_variance_name
=
name
+
"_bn_variance"
)
self
.
_act_op
=
layer_utils
.
Activation
(
act
=
None
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
conv
(
x
)
...
...
dygraph/models/architectures/resnet_vd.py
浏览文件 @
2ea481f5
...
@@ -24,10 +24,11 @@ import paddle
...
@@ -24,10 +24,11 @@ import paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
BatchNorm
,
Linear
,
Dropout
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
,
Dropout
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
dygraph.utils
import
utils
from
dygraph.utils
import
utils
from
dygraph.models.architectures
import
layer_utils
from
dygraph.cvlibs
import
manager
from
dygraph.cvlibs
import
manager
__all__
=
[
__all__
=
[
...
@@ -69,17 +70,17 @@ class ConvBNLayer(fluid.dygraph.Layer):
...
@@ -69,17 +70,17 @@ class ConvBNLayer(fluid.dygraph.Layer):
bn_name
=
"bn"
+
name
[
3
:]
bn_name
=
"bn"
+
name
[
3
:]
self
.
_batch_norm
=
BatchNorm
(
self
.
_batch_norm
=
BatchNorm
(
num_filters
,
num_filters
,
act
=
act
,
weight_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
),
param_attr
=
ParamAttr
(
name
=
bn_name
+
'_scale'
),
bias_attr
=
ParamAttr
(
bn_name
+
'_offset'
))
bias_attr
=
ParamAttr
(
bn_name
+
'_offset'
),
self
.
_act_op
=
layer_utils
.
Activation
(
act
=
act
)
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_variance'
)
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
if
self
.
is_vd_mode
:
if
self
.
is_vd_mode
:
inputs
=
self
.
_pool2d_avg
(
inputs
)
inputs
=
self
.
_pool2d_avg
(
inputs
)
y
=
self
.
_conv
(
inputs
)
y
=
self
.
_conv
(
inputs
)
y
=
self
.
_batch_norm
(
y
)
y
=
self
.
_batch_norm
(
y
)
y
=
self
.
_act_op
(
y
)
return
y
return
y
...
...
dygraph/models/architectures/xception_deeplab.py
浏览文件 @
2ea481f5
...
@@ -2,8 +2,10 @@ import paddle
...
@@ -2,8 +2,10 @@ import paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
BatchNorm
,
Linear
,
Dropout
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
Linear
,
Dropout
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
dygraph.models.architectures
import
layer_utils
from
dygraph.cvlibs
import
manager
from
dygraph.cvlibs
import
manager
__all__
=
[
"Xception41_deeplab"
,
"Xception65_deeplab"
,
"Xception71_deeplab"
]
__all__
=
[
"Xception41_deeplab"
,
"Xception65_deeplab"
,
"Xception71_deeplab"
]
...
@@ -79,17 +81,17 @@ class ConvBNLayer(fluid.dygraph.Layer):
...
@@ -79,17 +81,17 @@ class ConvBNLayer(fluid.dygraph.Layer):
param_attr
=
ParamAttr
(
name
=
name
+
"/weights"
),
param_attr
=
ParamAttr
(
name
=
name
+
"/weights"
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
_bn
=
BatchNorm
(
self
.
_bn
=
BatchNorm
(
num_channels
=
output_channels
,
num_features
=
output_channels
,
act
=
act
,
epsilon
=
1e-3
,
epsilon
=
1e-3
,
momentum
=
0.99
,
momentum
=
0.99
,
param
_attr
=
ParamAttr
(
name
=
name
+
"/BatchNorm/gamma"
),
weight
_attr
=
ParamAttr
(
name
=
name
+
"/BatchNorm/gamma"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"/BatchNorm/beta"
)
,
bias_attr
=
ParamAttr
(
name
=
name
+
"/BatchNorm/beta"
)
)
moving_mean_name
=
name
+
"/BatchNorm/moving_mean"
,
moving_variance_name
=
name
+
"/BatchNorm/moving_variance"
)
self
.
_act_op
=
layer_utils
.
Activation
(
act
=
act
)
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
return
self
.
_bn
(
self
.
_conv
(
inputs
))
return
self
.
_act_op
(
self
.
_bn
(
self
.
_conv
(
inputs
)))
class
Seperate_Conv
(
fluid
.
dygraph
.
Layer
):
class
Seperate_Conv
(
fluid
.
dygraph
.
Layer
):
...
@@ -115,13 +117,13 @@ class Seperate_Conv(fluid.dygraph.Layer):
...
@@ -115,13 +117,13 @@ class Seperate_Conv(fluid.dygraph.Layer):
bias_attr
=
False
)
bias_attr
=
False
)
self
.
_bn1
=
BatchNorm
(
self
.
_bn1
=
BatchNorm
(
input_channels
,
input_channels
,
act
=
act
,
epsilon
=
1e-3
,
epsilon
=
1e-3
,
momentum
=
0.99
,
momentum
=
0.99
,
param_attr
=
ParamAttr
(
name
=
name
+
"/depthwise/BatchNorm/gamma"
),
weight_attr
=
ParamAttr
(
name
=
name
+
"/depthwise/BatchNorm/gamma"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"/depthwise/BatchNorm/beta"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"/depthwise/BatchNorm/beta"
))
moving_mean_name
=
name
+
"/depthwise/BatchNorm/moving_mean"
,
moving_variance_name
=
name
+
"/depthwise/BatchNorm/moving_variance"
)
self
.
_act_op1
=
layer_utils
.
Activation
(
act
=
act
)
self
.
_conv2
=
Conv2D
(
self
.
_conv2
=
Conv2D
(
input_channels
,
input_channels
,
output_channels
,
output_channels
,
...
@@ -133,19 +135,21 @@ class Seperate_Conv(fluid.dygraph.Layer):
...
@@ -133,19 +135,21 @@ class Seperate_Conv(fluid.dygraph.Layer):
bias_attr
=
False
)
bias_attr
=
False
)
self
.
_bn2
=
BatchNorm
(
self
.
_bn2
=
BatchNorm
(
output_channels
,
output_channels
,
act
=
act
,
epsilon
=
1e-3
,
epsilon
=
1e-3
,
momentum
=
0.99
,
momentum
=
0.99
,
param_attr
=
ParamAttr
(
name
=
name
+
"/pointwise/BatchNorm/gamma"
),
weight_attr
=
ParamAttr
(
name
=
name
+
"/pointwise/BatchNorm/gamma"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"/pointwise/BatchNorm/beta"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"/pointwise/BatchNorm/beta"
))
moving_mean_name
=
name
+
"/pointwise/BatchNorm/moving_mean"
,
moving_variance_name
=
name
+
"/pointwise/BatchNorm/moving_variance"
)
self
.
_act_op2
=
layer_utils
.
Activation
(
act
=
act
)
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
x
=
self
.
_conv1
(
inputs
)
x
=
self
.
_conv1
(
inputs
)
x
=
self
.
_bn1
(
x
)
x
=
self
.
_bn1
(
x
)
x
=
self
.
_act_op1
(
x
)
x
=
self
.
_conv2
(
x
)
x
=
self
.
_conv2
(
x
)
x
=
self
.
_bn2
(
x
)
x
=
self
.
_bn2
(
x
)
x
=
self
.
_act_op2
(
x
)
return
x
return
x
...
...
dygraph/models/model_utils.py
0 → 100644
浏览文件 @
2ea481f5
# -*- encoding: utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn.functional
as
F
from
paddle
import
fluid
from
paddle.fluid
import
dygraph
from
paddle.fluid.dygraph
import
Conv2D
from
paddle.nn
import
SyncBatchNorm
as
BatchNorm
from
dygraph.models.architectures
import
layer_utils
class
FCNHead
(
fluid
.
dygraph
.
Layer
):
"""
The FCNHead implementation used in auxilary layer
Args:
in_channels (int): the number of input channels
out_channels (int): the number of output channels
"""
def
__init__
(
self
,
in_channels
,
out_channels
):
super
(
FCNHead
,
self
).
__init__
()
inter_channels
=
in_channels
//
4
self
.
conv_bn_relu
=
layer_utils
.
ConvBnRelu
(
num_channels
=
in_channels
,
num_filters
=
inter_channels
,
filter_size
=
3
,
padding
=
1
)
self
.
conv
=
Conv2D
(
num_channels
=
inter_channels
,
num_filters
=
out_channels
,
filter_size
=
1
)
def
forward
(
self
,
x
):
x
=
self
.
conv_bn_relu
(
x
)
x
=
F
.
dropout
(
x
,
p
=
0.1
)
x
=
self
.
conv
(
x
)
return
x
def
get_loss
(
logit
,
label
,
ignore_index
=
255
,
EPS
=
1e-5
):
"""
compute forward loss of the model
Args:
logit (tensor): the logit of model output
label (tensor): ground truth
Returns:
avg_loss (tensor): forward loss
"""
logit
=
fluid
.
layers
.
transpose
(
logit
,
[
0
,
2
,
3
,
1
])
label
=
fluid
.
layers
.
transpose
(
label
,
[
0
,
2
,
3
,
1
])
mask
=
label
!=
ignore_index
mask
=
fluid
.
layers
.
cast
(
mask
,
'float32'
)
loss
,
probs
=
fluid
.
layers
.
softmax_with_cross_entropy
(
logit
,
label
,
ignore_index
=
ignore_index
,
return_softmax
=
True
,
axis
=-
1
)
loss
=
loss
*
mask
avg_loss
=
paddle
.
mean
(
loss
)
/
(
paddle
.
mean
(
mask
)
+
EPS
)
label
.
stop_gradient
=
True
mask
.
stop_gradient
=
True
return
avg_loss
def
get_pred_score_map
(
logit
):
"""
Get prediction and score map output in inference phase.
Args:
logit (tensor): output logit of network
Returns:
pred (tensor): predition map
score_map (tensor): score map
"""
score_map
=
F
.
softmax
(
logit
,
axis
=
1
)
score_map
=
fluid
.
layers
.
transpose
(
score_map
,
[
0
,
2
,
3
,
1
])
pred
=
fluid
.
layers
.
argmax
(
score_map
,
axis
=
3
)
pred
=
fluid
.
layers
.
unsqueeze
(
pred
,
axes
=
[
3
])
return
pred
,
score_map
\ No newline at end of file
dygraph/models/pspnet.py
0 → 100644
浏览文件 @
2ea481f5
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle.nn.functional
as
F
from
paddle
import
fluid
from
paddle.fluid.dygraph
import
Conv2D
from
dygraph.cvlibs
import
manager
from
dygraph.models
import
model_utils
from
dygraph.models.architectures
import
layer_utils
from
dygraph.utils
import
utils
class
PSPNet
(
fluid
.
dygraph
.
Layer
):
"""
The PSPNet implementation
The orginal artile refers to
Zhao, Hengshuang, et al. "Pyramid scene parsing network."
Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.
(https://openaccess.thecvf.com/content_cvpr_2017/papers/Zhao_Pyramid_Scene_Parsing_CVPR_2017_paper.pdf)
Args:
backbone (str): backbone name, currently support Resnet50/101.
num_classes (int): the unique number of target classes. Default 2.
output_stride (int): the ratio of input size and final feature size. Default 16.
backbone_indices (tuple): two values in the tuple indicte the indices of output of backbone.
the first index will be taken as a deep-supervision feature in auxiliary layer;
the second one will be taken as input of Pyramid Pooling Module (PPModule).
Usually backbone consists of four downsampling stage, and return an output of
each stage, so we set default (2, 3), which means taking feature map of the third
stage (res4b22) in backbone, and feature map of the fourth stage (res5c) as input of PPModule.
backbone_channels (tuple): the same length with "backbone_indices". It indicates the channels of corresponding index.
pp_out_channels (int): output channels after Pyramid Pooling Module. Default to 1024.
bin_sizes (tuple): the out size of pooled feature maps. Default to (1,2,3,6).
enable_auxiliary_loss (bool): a bool values indictes whether adding auxiliary loss. Default to True.
ignore_index (int): the value of ground-truth mask would be ignored while doing evaluation. Default to 255.
pretrained_model (str): the pretrained_model path of backbone.
"""
def
__init__
(
self
,
backbone
,
num_classes
=
2
,
output_stride
=
16
,
backbone_indices
=
(
2
,
3
),
backbone_channels
=
(
1024
,
2048
),
pp_out_channels
=
1024
,
bin_sizes
=
(
1
,
2
,
3
,
6
),
enable_auxiliary_loss
=
True
,
ignore_index
=
255
,
pretrained_model
=
None
):
super
(
PSPNet
,
self
).
__init__
()
self
.
backbone
=
manager
.
BACKBONES
[
backbone
](
output_stride
=
output_stride
,
multi_grid
=
(
1
,
1
,
1
))
self
.
backbone_indices
=
backbone_indices
self
.
psp_module
=
PPModule
(
in_channels
=
backbone_channels
[
1
],
out_channels
=
pp_out_channels
,
bin_sizes
=
bin_sizes
)
self
.
conv
=
Conv2D
(
num_channels
=
pp_out_channels
,
num_filters
=
num_classes
,
filter_size
=
1
)
if
enable_auxiliary_loss
:
self
.
fcn_head
=
model_utils
.
FCNHead
(
in_channels
=
backbone_channels
[
0
],
out_channels
=
num_classes
)
self
.
enable_auxiliary_loss
=
enable_auxiliary_loss
self
.
ignore_index
=
ignore_index
self
.
init_weight
(
pretrained_model
)
def
forward
(
self
,
input
,
label
=
None
):
_
,
feat_list
=
self
.
backbone
(
input
)
x
=
feat_list
[
self
.
backbone_indices
[
1
]]
x
=
self
.
psp_module
(
x
)
x
=
F
.
dropout
(
x
,
dropout_prob
=
0.1
)
logit
=
self
.
conv
(
x
)
logit
=
fluid
.
layers
.
resize_bilinear
(
logit
,
input
.
shape
[
2
:])
if
self
.
enable_auxiliary_loss
:
auxiliary_feat
=
feat_list
[
self
.
backbone_indices
[
0
]]
auxiliary_logit
=
self
.
fcn_head
(
auxiliary_feat
)
auxiliary_logit
=
fluid
.
layers
.
resize_bilinear
(
auxiliary_logit
,
input
.
shape
[
2
:])
if
self
.
training
:
loss
=
model_utils
.
get_loss
(
logit
,
label
)
if
self
.
enable_auxiliary_loss
:
auxiliary_loss
=
model_utils
.
get_loss
(
auxiliary_logit
,
label
)
loss
+=
(
0.4
*
auxiliary_loss
)
return
loss
else
:
pred
,
score_map
=
model_utils
.
get_pred_score_map
(
logit
)
return
pred
,
score_map
def
init_weight
(
self
,
pretrained_model
=
None
):
"""
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the pretrained_model path of backbone. Defaults to None.
"""
if
pretrained_model
is
not
None
:
if
os
.
path
.
exists
(
pretrained_model
):
utils
.
load_pretrained_model
(
self
.
backbone
,
pretrained_model
)
class
PPModule
(
fluid
.
dygraph
.
Layer
):
"""
Pyramid pooling module
Args:
in_channels (int): the number of intput channels to pyramid pooling module.
out_channels (int): the number of output channels after pyramid pooling module.
bin_sizes (tuple): the out size of pooled feature maps. Default to (1,2,3,6).
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
bin_sizes
=
(
1
,
2
,
3
,
6
)):
super
(
PPModule
,
self
).
__init__
()
self
.
bin_sizes
=
bin_sizes
# we use dimension reduction after pooling mentioned in original implementation.
self
.
stages
=
fluid
.
dygraph
.
LayerList
([
self
.
_make_stage
(
in_channels
,
size
)
for
size
in
bin_sizes
])
self
.
conv_bn_relu2
=
layer_utils
.
ConvBnRelu
(
num_channels
=
in_channels
*
2
,
num_filters
=
out_channels
,
filter_size
=
3
,
padding
=
1
)
def
_make_stage
(
self
,
in_channels
,
size
):
"""
Create one pooling layer.
In our implementation, we adopt the same dimention reduction as the original paper that might be
slightly different with other implementations.
After pooling, the channels are reduced to 1/len(bin_sizes) immediately, while some other implementations
keep the channels to be same.
Args:
in_channels (int): the number of intput channels to pyramid pooling module.
size (int): the out size of the pooled layer.
Returns:
conv (tensor): a tensor after Pyramid Pooling Module
"""
# this paddle version does not support AdaptiveAvgPool2d, so skip it here.
# prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
conv
=
layer_utils
.
ConvBnRelu
(
num_channels
=
in_channels
,
num_filters
=
in_channels
//
len
(
self
.
bin_sizes
),
filter_size
=
1
)
return
conv
def
forward
(
self
,
input
):
cat_layers
=
[]
for
i
,
stage
in
enumerate
(
self
.
stages
):
size
=
self
.
bin_sizes
[
i
]
x
=
fluid
.
layers
.
adaptive_pool2d
(
input
,
pool_size
=
(
size
,
size
),
pool_type
=
"max"
)
x
=
stage
(
x
)
x
=
fluid
.
layers
.
resize_bilinear
(
x
,
out_shape
=
input
.
shape
[
2
:])
cat_layers
.
append
(
x
)
cat_layers
=
[
input
]
+
cat_layers
[::
-
1
]
cat
=
fluid
.
layers
.
concat
(
cat_layers
,
axis
=
1
)
out
=
self
.
conv_bn_relu2
(
cat
)
return
out
@
manager
.
MODELS
.
add_component
def
pspnet_resnet101_vd
(
*
args
,
**
kwargs
):
pretrained_model
=
None
return
PSPNet
(
backbone
=
'ResNet101_vd'
,
pretrained_model
=
pretrained_model
,
**
kwargs
)
@
manager
.
MODELS
.
add_component
def
pspnet_resnet101_vd_os8
(
*
args
,
**
kwargs
):
pretrained_model
=
None
return
PSPNet
(
backbone
=
'ResNet101_vd'
,
output_stride
=
8
,
pretrained_model
=
pretrained_model
,
**
kwargs
)
@
manager
.
MODELS
.
add_component
def
pspnet_resnet50_vd
(
*
args
,
**
kwargs
):
pretrained_model
=
None
return
PSPNet
(
backbone
=
'ResNet50_vd'
,
pretrained_model
=
pretrained_model
,
**
kwargs
)
@
manager
.
MODELS
.
add_component
def
pspnet_resnet50_vd_os8
(
*
args
,
**
kwargs
):
pretrained_model
=
None
return
PSPNet
(
backbone
=
'ResNet50_vd'
,
output_stride
=
8
,
pretrained_model
=
pretrained_model
,
**
kwargs
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录