Commit a3bfb074
Authored Sep 22, 2020 by chenguowei01
Parents: fad18563, 23d69271

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSeg into dygraph

Showing 9 changed files with 444 additions and 126 deletions (+444 −126)
dygraph/paddleseg/datasets/ade.py             +3    −3
dygraph/paddleseg/datasets/optic_disc_seg.py  +2    −2
dygraph/paddleseg/datasets/voc.py             +3    −3
dygraph/paddleseg/env.py                      +50   −0
dygraph/paddleseg/models/danet.py             +217  −0
dygraph/paddleseg/models/ocrnet.py            +132  −110
dygraph/paddleseg/utils/utils.py              +28   −1
dygraph/train.py                              +4    −3
dygraph/val.py                                +5    −4
dygraph/paddleseg/datasets/ade.py
@@ -17,12 +17,12 @@ import os
 import numpy as np
 from PIL import Image
 
+import paddleseg.env as segenv
 from .dataset import Dataset
 from paddleseg.utils.download import download_file_and_uncompress
 from paddleseg.cvlibs import manager
 from paddleseg.transforms import Compose
 
-DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
 URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"

@@ -61,8 +61,8 @@ class ADE20K(Dataset):
                 "`dataset_root` not set and auto download disabled.")
             self.dataset_root = download_file_and_uncompress(
                 url=URL,
-                savepath=DATA_HOME,
-                extrapath=DATA_HOME,
+                savepath=segenv.DATA_HOME,
+                extrapath=segenv.DATA_HOME,
                 extraname='ADEChallengeData2016')
         elif not os.path.exists(self.dataset_root):
             raise Exception('there is not `dataset_root`: {}.'.format(
dygraph/paddleseg/datasets/optic_disc_seg.py
@@ -14,12 +14,12 @@
 import os
 
+import paddleseg.env as segenv
 from .dataset import Dataset
 from paddleseg.utils.download import download_file_and_uncompress
 from paddleseg.cvlibs import manager
 from paddleseg.transforms import Compose
 
-DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
 URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"

@@ -49,7 +49,7 @@ class OpticDiscSeg(Dataset):
                 raise Exception(
                     "`data_root` not set and auto download disabled.")
             self.dataset_root = download_file_and_uncompress(
-                url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
+                url=URL, savepath=segenv.DATA_HOME, extrapath=segenv.DATA_HOME)
         elif not os.path.exists(self.dataset_root):
             raise Exception('there is not `dataset_root`: {}.'.format(
                 self.dataset_root))
dygraph/paddleseg/datasets/voc.py
@@ -14,12 +14,12 @@
 import os
 
+import paddleseg.env as segenv
 from .dataset import Dataset
 from paddleseg.utils.download import download_file_and_uncompress
 from paddleseg.cvlibs import manager
 from paddleseg.transforms import Compose
 
-DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
 URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"

@@ -59,8 +59,8 @@ class PascalVOC(Dataset):
                 "`dataset_root` not set and auto download disabled.")
             self.dataset_root = download_file_and_uncompress(
                 url=URL,
-                savepath=DATA_HOME,
-                extrapath=DATA_HOME,
+                savepath=segenv.DATA_HOME,
+                extrapath=segenv.DATA_HOME,
                 extraname='VOCdevkit')
         elif not os.path.exists(self.dataset_root):
             raise Exception('there is not `dataset_root`: {}.'.format(
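All three dataset files above make the same substitution: each module's hard-coded DATA_HOME constant is dropped and the auto-download location now comes from the new paddleseg.env module (added below). A minimal sketch of the effect, assuming a default environment without SEG_HOME set:

import os
import paddleseg.env as segenv

# The constant removed from ade.py / optic_disc_seg.py / voc.py:
old_data_home = os.path.expanduser('~/.cache/paddle/dataset')

# The shared location the datasets now download into:
print(old_data_home)      # ~/.cache/paddle/dataset
print(segenv.DATA_HOME)   # ~/.paddleseg/dataset by default (see env.py below)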
dygraph/paddleseg/env.py
new file mode 100644

# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil

from paddleseg.utils import logger


def _get_user_home():
    return os.path.expanduser('~')


def _get_seg_home():
    if 'SEG_HOME' in os.environ:
        home_path = os.environ['SEG_HOME']
        if os.path.exists(home_path):
            if os.path.isdir(home_path):
                return home_path
            else:
                logger.warning('SEG_HOME {} is a file!'.format(home_path))
        else:
            return home_path
    return os.path.join(_get_user_home(), '.paddleseg')


def _get_sub_home(directory):
    home = os.path.join(_get_seg_home(), directory)
    if not os.path.exists(home):
        os.makedirs(home)
    return home


USER_HOME = _get_user_home()
SEG_HOME = _get_seg_home()
DATA_HOME = _get_sub_home('dataset')
TMP_HOME = _get_sub_home('tmp')
PRETRAINED_MODEL_HOME = _get_sub_home('pretrained_model')
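The cache root can be redirected with the SEG_HOME environment variable. A hedged sketch of that resolution, assuming the variable is set before paddleseg.env is first imported (the constants are computed at import time); the path below is made up:

import os
os.environ['SEG_HOME'] = '/data/paddleseg_cache'   # hypothetical location

import paddleseg.env as segenv

# _get_seg_home() honours SEG_HOME unless it points at an existing file,
# and _get_sub_home() creates the sub-directories on import.
print(segenv.SEG_HOME)                # /data/paddleseg_cache
print(segenv.DATA_HOME)               # /data/paddleseg_cache/dataset
print(segenv.TMP_HOME)                # /data/paddleseg_cache/tmp
print(segenv.PRETRAINED_MODEL_HOME)   # /data/paddleseg_cache/pretrained_model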
dygraph/paddleseg/models/danet.py
new file mode 100644

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg.utils import utils
from paddleseg.cvlibs import manager, param_init
from paddleseg.models.common.layer_libs import ConvBNReLU


class PAM(nn.Layer):
    """Position attention module"""

    def __init__(self, in_channels):
        super(PAM, self).__init__()
        mid_channels = in_channels // 8

        self.query_conv = nn.Conv2d(in_channels, mid_channels, 1, 1)
        self.key_conv = nn.Conv2d(in_channels, mid_channels, 1, 1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, 1, 1)

        self.gamma = self.create_parameter(
            shape=[1],
            dtype='float32',
            default_initializer=nn.initializer.Constant(0))

    def forward(self, x):
        n, _, h, w = x.shape

        # query: n, h * w, c1
        query = self.query_conv(x)
        query = paddle.reshape(query, (n, -1, h * w))
        query = paddle.transpose(query, (0, 2, 1))

        # key: n, c1, h * w
        key = self.key_conv(x)
        key = paddle.reshape(key, (n, -1, h * w))

        # sim: n, h * w, h * w
        sim = paddle.bmm(query, key)
        sim = F.softmax(sim, axis=-1)

        value = self.value_conv(x)
        value = paddle.reshape(value, (n, -1, h * w))
        sim = paddle.transpose(sim, (0, 2, 1))

        # feat: from (n, c2, h * w) -> (n, c2, h, w)
        feat = paddle.bmm(value, sim)
        feat = paddle.reshape(feat, (n, -1, h, w))

        out = self.gamma * feat + x
        return out


class CAM(nn.Layer):
    """Channel attention module"""

    def __init__(self):
        super(CAM, self).__init__()

        self.gamma = self.create_parameter(
            shape=[1],
            dtype='float32',
            default_initializer=nn.initializer.Constant(0))

    def forward(self, x):
        n, c, h, w = x.shape

        # query: n, c, h * w
        query = paddle.reshape(x, (n, c, h * w))
        # key: n, h * w, c
        key = paddle.reshape(x, (n, c, h * w))
        key = paddle.transpose(key, (0, 2, 1))

        # sim: n, c, c
        sim = paddle.bmm(query, key)
        # The danet author claims that this can avoid gradient divergence
        sim = paddle.max(sim, axis=-1, keepdim=True).expand_as(sim) - sim
        sim = F.softmax(sim, axis=-1)

        # feat: from (n, c, h * w) to (n, c, h, w)
        value = paddle.reshape(x, (n, c, h * w))
        feat = paddle.bmm(sim, value)
        feat = paddle.reshape(feat, (n, c, h, w))

        out = self.gamma * feat + x
        return out


class DAHead(nn.Layer):
    """
    The Dual attention head.

    Args:
        num_classes(int): the unique number of target classes.
        in_channels(tuple): the number of input channels.
    """

    def __init__(self, num_classes, in_channels=None):
        super(DAHead, self).__init__()
        in_channels = in_channels[-1]
        inter_channels = in_channels // 4

        self.channel_conv = ConvBNReLU(in_channels, inter_channels, 3, padding=1)
        self.position_conv = ConvBNReLU(in_channels, inter_channels, 3, padding=1)
        self.pam = PAM(inter_channels)
        self.cam = CAM()
        self.conv1 = ConvBNReLU(inter_channels, inter_channels, 3, padding=1)
        self.conv2 = ConvBNReLU(inter_channels, inter_channels, 3, padding=1)

        self.aux_head_pam = nn.Sequential(
            nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1))

        self.aux_head_cam = nn.Sequential(
            nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1))

        self.cls_head = nn.Sequential(
            nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1))

        self.init_weight()

    def forward(self, x, label=None):
        feats = x[-1]
        channel_feats = self.channel_conv(feats)
        channel_feats = self.cam(channel_feats)
        channel_feats = self.conv1(channel_feats)
        cam_head = self.aux_head_cam(channel_feats)

        position_feats = self.position_conv(feats)
        position_feats = self.pam(position_feats)
        position_feats = self.conv2(position_feats)
        pam_head = self.aux_head_pam(position_feats)

        feats_sum = position_feats + channel_feats
        cam_logit = self.aux_head_cam(channel_feats)
        pam_logit = self.aux_head_cam(position_feats)
        logit = self.cls_head(feats_sum)
        return [logit, cam_logit, pam_logit]

    def init_weight(self):
        """Initialize the parameters of model parts."""
        for sublayer in self.sublayers():
            if isinstance(sublayer, nn.Conv2d):
                param_init.normal_init(sublayer.weight, scale=0.001)
            elif isinstance(sublayer, nn.SyncBatchNorm):
                param_init.constant_init(sublayer.weight, value=1)
                param_init.constant_init(sublayer.bias, value=0)


@manager.MODELS.add_component
class DANet(nn.Layer):
    """
    The DANet implementation based on PaddlePaddle.

    The original article refers to
    Fu, jun, et al. "Dual Attention Network for Scene Segmentation"
    (https://arxiv.org/pdf/1809.02983.pdf)

    Args:
        num_classes(int): the unique number of target classes.
        backbone(Paddle.nn.Layer): backbone network.
        pretrained(str): the path or url of pretrained model. Default to None.
        backbone_indices(tuple): values in the tuple indicate the indices of output of backbone.
            Only the last indice is used.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 pretrained=None,
                 backbone_indices=None):
        super(DANet, self).__init__()

        self.backbone = backbone
        self.backbone_indices = backbone_indices
        in_channels = [self.backbone.channels[i] for i in backbone_indices]

        self.head = DAHead(num_classes=num_classes, in_channels=in_channels)

        self.init_weight(pretrained)

    def forward(self, x, label=None):
        feats = self.backbone(x)
        feats = [feats[i] for i in self.backbone_indices]
        preds = self.head(feats, label)
        preds = [F.resize_bilinear(pred, x.shape[2:]) for pred in preds]
        return preds

    def init_weight(self, pretrained=None):
        """
        Initialize the parameters of model parts.

        Args:
            pretrained ([str], optional): the path of pretrained model. Defaults to None.
        """
        if pretrained is not None:
            if os.path.exists(pretrained):
                utils.load_pretrained_model(self, pretrained)
            else:
                raise Exception(
                    'Pretrained model is not found: {}'.format(pretrained))
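A small smoke test of the two attention modules added above, as a sketch only: it assumes the file is importable as paddleseg.models.danet and a PaddlePaddle dygraph build matching the nn.Conv2d / paddle.bmm API used in the code.

import paddle
from paddleseg.models.danet import PAM, CAM

x = paddle.ones([2, 64, 32, 32], dtype='float32')   # (n, c, h, w) dummy feature map

pam = PAM(in_channels=64)   # position attention over the h*w locations
cam = CAM()                 # channel attention over the c channels

# Both modules are residual (gamma * attention + x) with gamma initialised to 0,
# so right after construction they act as identity mappings.
print(pam(x).shape)   # [2, 64, 32, 32]
print(cam(x).shape)   # [2, 64, 32, 32]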
dygraph/paddleseg/models/ocrnet.py
@@ -14,36 +14,41 @@
 import os
 
-import paddle.fluid as fluid
-from paddle.fluid.dygraph import Sequential, Conv2D
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
 
-from paddleseg.cvlibs import manager
-from paddleseg.models.common.layer_libs import ConvBnRelu
 from paddleseg import utils
+from paddleseg.cvlibs import manager, param_init
+from paddleseg.models.common.layer_libs import ConvBNReLU, AuxLayer
 
 
-class SpatialGatherBlock(fluid.dygraph.Layer):
+class SpatialGatherBlock(nn.Layer):
+    """Aggregation layer to compute the pixel-region representation"""
+
     def forward(self, pixels, regions):
         n, c, h, w = pixels.shape
         _, k, _, _ = regions.shape
 
         # pixels: from (n, c, h, w) to (n, h*w, c)
-        pixels = fluid.layers.reshape(pixels, (n, c, h * w))
-        pixels = fluid.layers.transpose(pixels, (0, 2, 1))
+        pixels = paddle.reshape(pixels, (n, c, h * w))
+        pixels = paddle.transpose(pixels, (0, 2, 1))
 
         # regions: from (n, k, h, w) to (n, k, h*w)
-        regions = fluid.layers.reshape(regions, (n, k, h * w))
-        regions = fluid.layers.softmax(regions, axis=2)
+        regions = paddle.reshape(regions, (n, k, h * w))
+        regions = F.softmax(regions, axis=2)
 
         # feats: from (n, k, c) to (n, c, k, 1)
-        feats = fluid.layers.matmul(regions, pixels)
-        feats = fluid.layers.transpose(feats, (0, 2, 1))
-        feats = fluid.layers.unsqueeze(feats, axes=[-1])
+        feats = paddle.bmm(regions, pixels)
+        feats = paddle.transpose(feats, (0, 2, 1))
+        feats = paddle.unsqueeze(feats, axis=-1)
 
         return feats
 
 
-class SpatialOCRModule(fluid.dygraph.Layer):
+class SpatialOCRModule(nn.Layer):
+    """Aggregate the global object representation to update the representation for each pixel"""
+
     def __init__(self,
                  in_channels,
                  key_channels,

@@ -53,163 +58,180 @@ class SpatialOCRModule(fluid.dygraph.Layer):
         self.attention_block = ObjectAttentionBlock(in_channels, key_channels)
         self.dropout_rate = dropout_rate
-        self.conv1x1 = Conv2D(2 * in_channels, out_channels, 1)
+        self.conv1x1 = nn.Sequential(
+            nn.Conv2d(2 * in_channels, out_channels, 1), nn.Dropout2d(0.1))
 
     def forward(self, pixels, regions):
         context = self.attention_block(pixels, regions)
 
-        feats = fluid.layers.concat([context, pixels], axis=1)
+        feats = paddle.concat([context, pixels], axis=1)
         feats = self.conv1x1(feats)
-        feats = fluid.layers.dropout(feats, self.dropout_rate)
 
         return feats
 
 
-class ObjectAttentionBlock(fluid.dygraph.Layer):
+class ObjectAttentionBlock(nn.Layer):
+    """A self-attention module."""
+
     def __init__(self, in_channels, key_channels):
         super(ObjectAttentionBlock, self).__init__()
 
         self.in_channels = in_channels
         self.key_channels = key_channels
 
-        self.f_pixel = Sequential(
-            ConvBnRelu(in_channels, key_channels, 1),
-            ConvBnRelu(key_channels, key_channels, 1))
+        self.f_pixel = nn.Sequential(
+            ConvBNReLU(in_channels, key_channels, 1),
+            ConvBNReLU(key_channels, key_channels, 1))
 
-        self.f_object = Sequential(
-            ConvBnRelu(in_channels, key_channels, 1),
-            ConvBnRelu(key_channels, key_channels, 1))
+        self.f_object = nn.Sequential(
+            ConvBNReLU(in_channels, key_channels, 1),
+            ConvBNReLU(key_channels, key_channels, 1))
 
-        self.f_down = ConvBnRelu(in_channels, key_channels, 1)
+        self.f_down = ConvBNReLU(in_channels, key_channels, 1)
 
-        self.f_up = ConvBnRelu(key_channels, in_channels, 1)
+        self.f_up = ConvBNReLU(key_channels, in_channels, 1)
 
     def forward(self, x, proxy):
         n, _, h, w = x.shape
 
         # query : from (n, c1, h1, w1) to (n, h1*w1, key_channels)
         query = self.f_pixel(x)
-        query = fluid.layers.reshape(query, (n, self.key_channels, -1))
-        query = fluid.layers.transpose(query, (0, 2, 1))
+        query = paddle.reshape(query, (n, self.key_channels, -1))
+        query = paddle.transpose(query, (0, 2, 1))
 
         # key : from (n, c2, h2, w2) to (n, key_channels, h2*w2)
         key = self.f_object(proxy)
-        key = fluid.layers.reshape(key, (n, self.key_channels, -1))
+        key = paddle.reshape(key, (n, self.key_channels, -1))
 
         # value : from (n, c2, h2, w2) to (n, h2*w2, key_channels)
         value = self.f_down(proxy)
-        value = fluid.layers.reshape(value, (n, self.key_channels, -1))
-        value = fluid.layers.transpose(value, (0, 2, 1))
+        value = paddle.reshape(value, (n, self.key_channels, -1))
+        value = paddle.transpose(value, (0, 2, 1))
 
         # sim_map (n, h1*w1, h2*w2)
-        sim_map = fluid.layers.matmul(query, key)
+        sim_map = paddle.bmm(query, key)
         sim_map = (self.key_channels**-.5) * sim_map
-        sim_map = fluid.layers.softmax(sim_map, axis=-1)
+        sim_map = F.softmax(sim_map, axis=-1)
 
         # context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1)
-        context = fluid.layers.matmul(sim_map, value)
-        context = fluid.layers.transpose(context, (0, 2, 1))
-        context = fluid.layers.reshape(context, (n, self.key_channels, h, w))
+        context = paddle.bmm(sim_map, value)
+        context = paddle.transpose(context, (0, 2, 1))
+        context = paddle.reshape(context, (n, self.key_channels, h, w))
         context = self.f_up(context)
 
         return context
 
 
-@manager.MODELS.add_component
-class OCRNet(fluid.dygraph.Layer):
+class OCRHead(nn.Layer):
+    """
+    The Object contextual representation head.
+
+    Args:
+        num_classes(int): the unique number of target classes.
+        in_channels(tuple): the number of input channels.
+        ocr_mid_channels(int): the number of middle channels in OCRHead.
+        ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
+    """
+
     def __init__(self,
                  num_classes,
-                 backbone,
-                 model_pretrained=None,
                  in_channels=None,
                  ocr_mid_channels=512,
-                 ocr_key_channels=256,
-                 ignore_index=255):
-        super(OCRNet, self).__init__()
+                 ocr_key_channels=256):
+        super(OCRHead, self).__init__()
 
-        self.ignore_index = ignore_index
         self.num_classes = num_classes
-        self.EPS = 1e-5
-        self.backbone = backbone
         self.spatial_gather = SpatialGatherBlock()
         self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels,
                                             ocr_mid_channels)
-        self.conv3x3_ocr = ConvBnRelu(in_channels, ocr_mid_channels, 3, padding=1)
-        self.cls_head = Conv2D(ocr_mid_channels, self.num_classes, 1)
-        self.aux_head = Sequential(
-            ConvBnRelu(in_channels, in_channels, 3, padding=1),
-            Conv2D(in_channels, self.num_classes, 1))
-
-        self.init_weight(model_pretrained)
+        self.indices = [-2, -1] if len(in_channels) > 1 else [-1, -1]
+
+        self.conv3x3_ocr = ConvBNReLU(
+            in_channels[self.indices[1]], ocr_mid_channels, 3, padding=1)
+        self.cls_head = nn.Conv2d(ocr_mid_channels, self.num_classes, 1)
+        self.aux_head = AuxLayer(in_channels[self.indices[0]],
+                                 in_channels[self.indices[0]],
+                                 self.num_classes)
+
+        self.init_weight()
 
     def forward(self, x, label=None):
-        feats = self.backbone(x)
-        soft_regions = self.aux_head(feats)
-        pixels = self.conv3x3_ocr(feats)
+        feat_shallow, feat_deep = x[self.indices[0]], x[self.indices[1]]
+
+        soft_regions = self.aux_head(feat_shallow)
+        pixels = self.conv3x3_ocr(feat_deep)
 
         object_regions = self.spatial_gather(pixels, soft_regions)
         ocr = self.spatial_ocr(pixels, object_regions)
 
         logit = self.cls_head(ocr)
-        logit = fluid.layers.resize_bilinear(logit, x.shape[2:])
-
-        if self.training:
-            soft_regions = fluid.layers.resize_bilinear(soft_regions,
-                                                        x.shape[2:])
-            cls_loss = self._get_loss(logit, label)
-            aux_loss = self._get_loss(soft_regions, label)
-            return cls_loss + 0.4 * aux_loss
-
-        score_map = fluid.layers.softmax(logit, axis=1)
-        score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
-        pred = fluid.layers.argmax(score_map, axis=3)
-        pred = fluid.layers.unsqueeze(pred, axes=[3])
-        return pred, score_map
-
-    def init_weight(self, pretrained_model=None):
-        """
-        Initialize the parameters of model parts.
-        Args:
-            pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
-        """
-        if pretrained_model is not None:
-            if os.path.exists(pretrained_model):
-                utils.load_pretrained_model(self, pretrained_model)
-            else:
-                raise Exception('Pretrained model is not found: {}'.format(
-                    pretrained_model))
-
-    def _get_loss(self, logit, label):
-        """
-        compute forward loss of the model
-
-        Args:
-            logit (tensor): the logit of model output
-            label (tensor): ground truth
-
-        Returns:
-            avg_loss (tensor): forward loss
-        """
-        logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
-        label = fluid.layers.transpose(label, [0, 2, 3, 1])
-        mask = label != self.ignore_index
-        mask = fluid.layers.cast(mask, 'float32')
-        loss, probs = fluid.layers.softmax_with_cross_entropy(
-            logit,
-            label,
-            ignore_index=self.ignore_index,
-            return_softmax=True,
-            axis=-1)
-
-        loss = loss * mask
-        avg_loss = fluid.layers.mean(loss) / (
-            fluid.layers.mean(mask) + self.EPS)
-
-        label.stop_gradient = True
-        mask.stop_gradient = True
-        return avg_loss
+        return [logit, soft_regions]
+
+    def init_weight(self):
+        """Initialize the parameters of model parts."""
+        for sublayer in self.sublayers():
+            if isinstance(sublayer, nn.Conv2d):
+                param_init.normal_init(sublayer.weight, scale=0.001)
+            elif isinstance(sublayer, nn.SyncBatchNorm):
+                param_init.constant_init(sublayer.weight, value=1)
+                param_init.constant_init(sublayer.bias, value=0)
+
+
+@manager.MODELS.add_component
+class OCRNet(nn.Layer):
+    """
+    The OCRNet implementation based on PaddlePaddle.
+
+    The original article refers to
+    Yuan, Yuhui, et al. "Object-Contextual Representations for Semantic Segmentation"
+    (https://arxiv.org/pdf/1909.11065.pdf)
+
+    Args:
+        num_classes(int): the unique number of target classes.
+        backbone(Paddle.nn.Layer): backbone network.
+        pretrained(str): the path or url of pretrained model. Default to None.
+        backbone_indices(tuple): two values in the tuple indicate the indices of output of backbone.
+            the first index will be taken as a deep-supervision feature in auxiliary layer;
+            the second one will be taken as input of pixel representation.
+        ocr_mid_channels(int): the number of middle channels in OCRHead.
+        ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
+    """
+
+    def __init__(self,
+                 num_classes,
+                 backbone,
+                 pretrained=None,
+                 backbone_indices=None,
+                 ocr_mid_channels=512,
+                 ocr_key_channels=256):
+        super(OCRNet, self).__init__()
+
+        self.backbone = backbone
+        self.backbone_indices = backbone_indices
+        in_channels = [self.backbone.channels[i] for i in backbone_indices]
+
+        self.head = OCRHead(
+            num_classes=num_classes,
+            in_channels=in_channels,
+            ocr_mid_channels=ocr_mid_channels,
+            ocr_key_channels=ocr_key_channels)
+
+        self.init_weight(pretrained)
+
+    def forward(self, x, label=None):
+        feats = self.backbone(x)
+        feats = [feats[i] for i in self.backbone_indices]
+        preds = self.head(feats, label)
+        preds = [F.resize_bilinear(pred, x.shape[2:]) for pred in preds]
+        return preds
+
+    def init_weight(self, pretrained=None):
+        """
+        Initialize the parameters of model parts.
+
+        Args:
+            pretrained ([str], optional): the path of pretrained model. Defaults to None.
+        """
+        if pretrained is not None:
+            if os.path.exists(pretrained):
+                utils.load_pretrained_model(self, pretrained)
+            else:
+                raise Exception('Pretrained model is not found: {}'.format(
+                    pretrained))
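For orientation, a hedged sketch of what the rewritten SpatialGatherBlock computes: one softmax-weighted, c-dimensional descriptor per soft region. The tensors are dummies, and the import path assumes the file above is importable as paddleseg.models.ocrnet.

import paddle
from paddleseg.models.ocrnet import SpatialGatherBlock

pixels = paddle.ones([2, 512, 16, 16], dtype='float32')   # (n, c, h, w) pixel features
regions = paddle.ones([2, 19, 16, 16], dtype='float32')   # (n, k, h, w) coarse region logits

gather = SpatialGatherBlock()
region_feats = gather(pixels, regions)

# The region logits are softmax-normalised over the h*w positions and used as
# weights to pool the pixel features, giving one descriptor per region:
print(region_feats.shape)   # [2, 512, 19, 1]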
dygraph/paddleseg/utils/utils.py
@@ -12,13 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
 import os
 import numpy as np
 import math
 import cv2
+import tempfile
 import paddle.fluid as fluid
+from urllib.parse import urlparse, unquote
 
-from . import logger
+import filelock
+
+import paddleseg.env as segenv
+from paddleseg.utils import logger
+from paddleseg.utils.download import download_file_and_uncompress
+
+
+@contextlib.contextmanager
+def generate_tempdir(directory: str = None, **kwargs):
+    '''Generate a temporary directory'''
+    directory = segenv.TMP_HOME if not directory else directory
+    with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir:
+        yield _dir
 
 
 def seconds_to_hms(seconds):

@@ -32,6 +47,18 @@ def seconds_to_hms(seconds):
 def load_pretrained_model(model, pretrained_model):
     if pretrained_model is not None:
         logger.info('Load pretrained model from {}'.format(pretrained_model))
+        # download pretrained model from url
+        if urlparse(pretrained_model).netloc:
+            pretrained_model = unquote(pretrained_model)
+            savename = pretrained_model.split('/')[-1].split('.')[0]
+            with generate_tempdir() as _dir:
+                with filelock.FileLock(os.path.join(segenv.TMP_HOME, savename)):
+                    pretrained_model = download_file_and_uncompress(
+                        pretrained_model,
+                        savepath=_dir,
+                        extrapath=segenv.PRETRAINED_MODEL_HOME,
+                        extraname=savename)
+
         if os.path.exists(pretrained_model):
             ckpt_path = os.path.join(pretrained_model, 'model')
             try:
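With this change a URL can be passed straight to load_pretrained_model; the new generate_tempdir helper is the piece it relies on for the download staging area. A minimal sketch, assuming the function is importable from paddleseg.utils.utils:

import os
from paddleseg.utils.utils import generate_tempdir

# generate_tempdir() yields a fresh directory under segenv.TMP_HOME (or under an
# explicit `directory` argument) and removes it when the block exits, which is
# plain tempfile.TemporaryDirectory behaviour.
with generate_tempdir() as tmp:
    print(os.path.isdir(tmp))   # True
print(os.path.isdir(tmp))       # False after cleanup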
dygraph/train.py
@@ -112,9 +112,10 @@ def main(args):
     val_dataset = cfg.val_dataset if args.do_eval else None
     losses = cfg.loss
 
-    print('---------------Config Information---------------')
-    print(cfg)
-    print('------------------------------------------------')
+    msg = '\n---------------Config Information---------------\n'
+    msg += str(cfg)
+    msg += '------------------------------------------------'
+    logger.info(msg)
 
     train(
         cfg.model,
dygraph/val.py
@@ -19,7 +19,7 @@ from paddle.distributed import ParallelEnv
 import paddleseg
 from paddleseg.cvlibs import manager
-from paddleseg.utils import get_environ_info, Config
+from paddleseg.utils import get_environ_info, Config, logger
 from paddleseg.core import evaluate

@@ -56,9 +56,10 @@ def main(args):
         'The verification dataset is not specified in the configuration file.')
 
-    print('---------------Config Information---------------')
-    print(cfg)
-    print('------------------------------------------------')
+    msg = '\n---------------Config Information---------------\n'
+    msg += str(cfg)
+    msg += '------------------------------------------------'
+    logger.info(msg)
 
     evaluate(
         cfg.model,