Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
fca7ec85
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fca7ec85
编写于
9月 22, 2020
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update ocrnet
上级
3b3a69b7
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
132 addition
and
110 deletion
+132
-110
dygraph/paddleseg/models/ocrnet.py
dygraph/paddleseg/models/ocrnet.py
+132
-110
未找到文件。
dygraph/paddleseg/models/ocrnet.py
浏览文件 @
fca7ec85
...
@@ -14,36 +14,41 @@
...
@@ -14,36 +14,41 @@
import
os
import
os
import
paddle.fluid
as
fluid
import
paddle
from
paddle.fluid.dygraph
import
Sequential
,
Conv2D
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddleseg.cvlibs
import
manager
from
paddleseg.models.common.layer_libs
import
ConvBnRelu
from
paddleseg
import
utils
from
paddleseg
import
utils
from
paddleseg.cvlibs
import
manager
,
param_init
from
paddleseg.models.common.layer_libs
import
ConvBNReLU
,
AuxLayer
class
SpatialGatherBlock
(
fluid
.
dygraph
.
Layer
):
class
SpatialGatherBlock
(
nn
.
Layer
):
"""Aggregation layer to compute the pixel-region representation"""
def
forward
(
self
,
pixels
,
regions
):
def
forward
(
self
,
pixels
,
regions
):
n
,
c
,
h
,
w
=
pixels
.
shape
n
,
c
,
h
,
w
=
pixels
.
shape
_
,
k
,
_
,
_
=
regions
.
shape
_
,
k
,
_
,
_
=
regions
.
shape
# pixels: from (n, c, h, w) to (n, h*w, c)
# pixels: from (n, c, h, w) to (n, h*w, c)
pixels
=
fluid
.
layers
.
reshape
(
pixels
,
(
n
,
c
,
h
*
w
))
pixels
=
paddle
.
reshape
(
pixels
,
(
n
,
c
,
h
*
w
))
pixels
=
fluid
.
layers
.
transpose
(
pixels
,
(
0
,
2
,
1
))
pixels
=
paddle
.
transpose
(
pixels
,
(
0
,
2
,
1
))
# regions: from (n, k, h, w) to (n, k, h*w)
# regions: from (n, k, h, w) to (n, k, h*w)
regions
=
fluid
.
layers
.
reshape
(
regions
,
(
n
,
k
,
h
*
w
))
regions
=
paddle
.
reshape
(
regions
,
(
n
,
k
,
h
*
w
))
regions
=
fluid
.
layers
.
softmax
(
regions
,
axis
=
2
)
regions
=
F
.
softmax
(
regions
,
axis
=
2
)
# feats: from (n, k, c) to (n, c, k, 1)
# feats: from (n, k, c) to (n, c, k, 1)
feats
=
fluid
.
layers
.
matmul
(
regions
,
pixels
)
feats
=
paddle
.
bmm
(
regions
,
pixels
)
feats
=
fluid
.
layers
.
transpose
(
feats
,
(
0
,
2
,
1
))
feats
=
paddle
.
transpose
(
feats
,
(
0
,
2
,
1
))
feats
=
fluid
.
layers
.
unsqueeze
(
feats
,
axes
=
[
-
1
]
)
feats
=
paddle
.
unsqueeze
(
feats
,
axis
=-
1
)
return
feats
return
feats
class
SpatialOCRModule
(
fluid
.
dygraph
.
Layer
):
class
SpatialOCRModule
(
nn
.
Layer
):
"""Aggregate the global object representation to update the representation for each pixel"""
def
__init__
(
self
,
def
__init__
(
self
,
in_channels
,
in_channels
,
key_channels
,
key_channels
,
...
@@ -53,163 +58,180 @@ class SpatialOCRModule(fluid.dygraph.Layer):
...
@@ -53,163 +58,180 @@ class SpatialOCRModule(fluid.dygraph.Layer):
self
.
attention_block
=
ObjectAttentionBlock
(
in_channels
,
key_channels
)
self
.
attention_block
=
ObjectAttentionBlock
(
in_channels
,
key_channels
)
self
.
dropout_rate
=
dropout_rate
self
.
dropout_rate
=
dropout_rate
self
.
conv1x1
=
Conv2D
(
2
*
in_channels
,
out_channels
,
1
)
self
.
conv1x1
=
nn
.
Sequential
(
nn
.
Conv2d
(
2
*
in_channels
,
out_channels
,
1
),
nn
.
Dropout2d
(
0.1
))
def
forward
(
self
,
pixels
,
regions
):
def
forward
(
self
,
pixels
,
regions
):
context
=
self
.
attention_block
(
pixels
,
regions
)
context
=
self
.
attention_block
(
pixels
,
regions
)
feats
=
fluid
.
layers
.
concat
([
context
,
pixels
],
axis
=
1
)
feats
=
paddle
.
concat
([
context
,
pixels
],
axis
=
1
)
feats
=
self
.
conv1x1
(
feats
)
feats
=
self
.
conv1x1
(
feats
)
feats
=
fluid
.
layers
.
dropout
(
feats
,
self
.
dropout_rate
)
return
feats
return
feats
class
ObjectAttentionBlock
(
fluid
.
dygraph
.
Layer
):
class
ObjectAttentionBlock
(
nn
.
Layer
):
"""A self-attention module."""
def
__init__
(
self
,
in_channels
,
key_channels
):
def
__init__
(
self
,
in_channels
,
key_channels
):
super
(
ObjectAttentionBlock
,
self
).
__init__
()
super
(
ObjectAttentionBlock
,
self
).
__init__
()
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
key_channels
=
key_channels
self
.
key_channels
=
key_channels
self
.
f_pixel
=
Sequential
(
self
.
f_pixel
=
nn
.
Sequential
(
ConvB
nRelu
(
in_channels
,
key_channels
,
1
),
ConvB
NReLU
(
in_channels
,
key_channels
,
1
),
ConvB
nRelu
(
key_channels
,
key_channels
,
1
))
ConvB
NReLU
(
key_channels
,
key_channels
,
1
))
self
.
f_object
=
Sequential
(
self
.
f_object
=
nn
.
Sequential
(
ConvB
nRelu
(
in_channels
,
key_channels
,
1
),
ConvB
NReLU
(
in_channels
,
key_channels
,
1
),
ConvB
nRelu
(
key_channels
,
key_channels
,
1
))
ConvB
NReLU
(
key_channels
,
key_channels
,
1
))
self
.
f_down
=
ConvB
nRelu
(
in_channels
,
key_channels
,
1
)
self
.
f_down
=
ConvB
NReLU
(
in_channels
,
key_channels
,
1
)
self
.
f_up
=
ConvB
nRelu
(
key_channels
,
in_channels
,
1
)
self
.
f_up
=
ConvB
NReLU
(
key_channels
,
in_channels
,
1
)
def
forward
(
self
,
x
,
proxy
):
def
forward
(
self
,
x
,
proxy
):
n
,
_
,
h
,
w
=
x
.
shape
n
,
_
,
h
,
w
=
x
.
shape
# query : from (n, c1, h1, w1) to (n, h1*w1, key_channels)
# query : from (n, c1, h1, w1) to (n, h1*w1, key_channels)
query
=
self
.
f_pixel
(
x
)
query
=
self
.
f_pixel
(
x
)
query
=
fluid
.
layers
.
reshape
(
query
,
(
n
,
self
.
key_channels
,
-
1
))
query
=
paddle
.
reshape
(
query
,
(
n
,
self
.
key_channels
,
-
1
))
query
=
fluid
.
layers
.
transpose
(
query
,
(
0
,
2
,
1
))
query
=
paddle
.
transpose
(
query
,
(
0
,
2
,
1
))
# key : from (n, c2, h2, w2) to (n, key_channels, h2*w2)
# key : from (n, c2, h2, w2) to (n, key_channels, h2*w2)
key
=
self
.
f_object
(
proxy
)
key
=
self
.
f_object
(
proxy
)
key
=
fluid
.
layers
.
reshape
(
key
,
(
n
,
self
.
key_channels
,
-
1
))
key
=
paddle
.
reshape
(
key
,
(
n
,
self
.
key_channels
,
-
1
))
# value : from (n, c2, h2, w2) to (n, h2*w2, key_channels)
# value : from (n, c2, h2, w2) to (n, h2*w2, key_channels)
value
=
self
.
f_down
(
proxy
)
value
=
self
.
f_down
(
proxy
)
value
=
fluid
.
layers
.
reshape
(
value
,
(
n
,
self
.
key_channels
,
-
1
))
value
=
paddle
.
reshape
(
value
,
(
n
,
self
.
key_channels
,
-
1
))
value
=
fluid
.
layers
.
transpose
(
value
,
(
0
,
2
,
1
))
value
=
paddle
.
transpose
(
value
,
(
0
,
2
,
1
))
# sim_map (n, h1*w1, h2*w2)
# sim_map (n, h1*w1, h2*w2)
sim_map
=
fluid
.
layers
.
matmul
(
query
,
key
)
sim_map
=
paddle
.
bmm
(
query
,
key
)
sim_map
=
(
self
.
key_channels
**-
.
5
)
*
sim_map
sim_map
=
(
self
.
key_channels
**-
.
5
)
*
sim_map
sim_map
=
fluid
.
layers
.
softmax
(
sim_map
,
axis
=-
1
)
sim_map
=
F
.
softmax
(
sim_map
,
axis
=-
1
)
# context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1)
# context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1)
context
=
fluid
.
layers
.
matmul
(
sim_map
,
value
)
context
=
paddle
.
bmm
(
sim_map
,
value
)
context
=
fluid
.
layers
.
transpose
(
context
,
(
0
,
2
,
1
))
context
=
paddle
.
transpose
(
context
,
(
0
,
2
,
1
))
context
=
fluid
.
layers
.
reshape
(
context
,
(
n
,
self
.
key_channels
,
h
,
w
))
context
=
paddle
.
reshape
(
context
,
(
n
,
self
.
key_channels
,
h
,
w
))
context
=
self
.
f_up
(
context
)
context
=
self
.
f_up
(
context
)
return
context
return
context
@
manager
.
MODELS
.
add_component
class
OCRHead
(
nn
.
Layer
):
class
OCRNet
(
fluid
.
dygraph
.
Layer
):
"""
The OCR Head.
Args:
num_classes(int): the unique number of target classes.
in_channels(tuple): the number of input channels.
ocr_mid_channels(int): the number of middle channels in OCRHead.
ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
"""
def
__init__
(
self
,
def
__init__
(
self
,
num_classes
,
num_classes
,
backbone
,
model_pretrained
=
None
,
in_channels
=
None
,
in_channels
=
None
,
ocr_mid_channels
=
512
,
ocr_mid_channels
=
512
,
ocr_key_channels
=
256
,
ocr_key_channels
=
256
):
ignore_index
=
255
):
super
(
OCRHead
,
self
).
__init__
()
super
(
OCRNet
,
self
).
__init__
()
self
.
ignore_index
=
ignore_index
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
EPS
=
1e-5
self
.
backbone
=
backbone
self
.
spatial_gather
=
SpatialGatherBlock
()
self
.
spatial_gather
=
SpatialGatherBlock
()
self
.
spatial_ocr
=
SpatialOCRModule
(
ocr_mid_channels
,
ocr_key_channels
,
self
.
spatial_ocr
=
SpatialOCRModule
(
ocr_mid_channels
,
ocr_key_channels
,
ocr_mid_channels
)
ocr_mid_channels
)
self
.
conv3x3_ocr
=
ConvBnRelu
(
in_channels
,
ocr_mid_channels
,
3
,
padding
=
1
)
self
.
cls_head
=
Conv2D
(
ocr_mid_channels
,
self
.
num_classes
,
1
)
self
.
aux_head
=
Sequential
(
self
.
indices
=
[
-
2
,
-
1
]
if
len
(
in_channels
)
>
1
else
[
-
1
,
-
1
]
ConvBnRelu
(
in_channels
,
in_channels
,
3
,
padding
=
1
),
Conv2D
(
in_channels
,
self
.
num_classes
,
1
))
self
.
init_weight
(
model_pretrained
)
self
.
conv3x3_ocr
=
ConvBNReLU
(
in_channels
[
self
.
indices
[
1
]],
ocr_mid_channels
,
3
,
padding
=
1
)
self
.
cls_head
=
nn
.
Conv2d
(
ocr_mid_channels
,
self
.
num_classes
,
1
)
self
.
aux_head
=
AuxLayer
(
in_channels
[
self
.
indices
[
0
]],
in_channels
[
self
.
indices
[
0
]],
self
.
num_classes
)
self
.
init_weight
()
def
forward
(
self
,
x
,
label
=
None
):
def
forward
(
self
,
x
,
label
=
None
):
feat
s
=
self
.
backbone
(
x
)
feat
_shallow
,
feat_deep
=
x
[
self
.
indices
[
0
]],
x
[
self
.
indices
[
1
]]
soft_regions
=
self
.
aux_head
(
feat
s
)
soft_regions
=
self
.
aux_head
(
feat
_shallow
)
pixels
=
self
.
conv3x3_ocr
(
feat
s
)
pixels
=
self
.
conv3x3_ocr
(
feat
_deep
)
object_regions
=
self
.
spatial_gather
(
pixels
,
soft_regions
)
object_regions
=
self
.
spatial_gather
(
pixels
,
soft_regions
)
ocr
=
self
.
spatial_ocr
(
pixels
,
object_regions
)
ocr
=
self
.
spatial_ocr
(
pixels
,
object_regions
)
logit
=
self
.
cls_head
(
ocr
)
logit
=
self
.
cls_head
(
ocr
)
logit
=
fluid
.
layers
.
resize_bilinear
(
logit
,
x
.
shape
[
2
:])
return
[
logit
,
soft_regions
]
if
self
.
training
:
def
init_weight
(
self
):
soft_regions
=
fluid
.
layers
.
resize_bilinear
(
soft_regions
,
"""Initialize the parameters of model parts."""
x
.
shape
[
2
:])
for
sublayer
in
self
.
sublayers
():
cls_loss
=
self
.
_get_loss
(
logit
,
label
)
if
isinstance
(
sublayer
,
nn
.
Conv2d
):
aux_loss
=
self
.
_get_loss
(
soft_regions
,
label
)
param_init
.
normal_init
(
sublayer
.
weight
,
scale
=
0.001
)
return
cls_loss
+
0.4
*
aux_loss
elif
isinstance
(
sublayer
,
nn
.
SyncBatchNorm
):
param_init
.
constant_init
(
sublayer
.
weight
,
value
=
1
)
score_map
=
fluid
.
layers
.
softmax
(
logit
,
axis
=
1
)
param_init
.
constant_init
(
sublayer
.
bias
,
value
=
0
)
score_map
=
fluid
.
layers
.
transpose
(
score_map
,
[
0
,
2
,
3
,
1
])
pred
=
fluid
.
layers
.
argmax
(
score_map
,
axis
=
3
)
pred
=
fluid
.
layers
.
unsqueeze
(
pred
,
axes
=
[
3
])
@
manager
.
MODELS
.
add_component
return
pred
,
score_map
class
OCRNet
(
nn
.
Layer
):
"""
def
init_weight
(
self
,
pretrained_model
=
None
):
The OCRNet implementation based on PaddlePaddle.
The orginal artile refers to
Yuan, Yuhui, et al. "Object-Contextual Representations for Semantic Segmentation"
(https://arxiv.org/pdf/1909.11065.pdf)
Args:
num_classes(int): the unique number of target classes.
backbone(Paddle.nn.Layer): backbone network.
pretrained(str): the path or url of pretrained model. Defaullt to None.
backbone_indices(tuple): two values in the tuple indicate the indices of output of backbone.
the first index will be taken as a deep-supervision feature in auxiliary layer;
the second one will be taken as input of pixel representation.
ocr_mid_channels(int): the number of middle channels in OCRHead.
ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
"""
def
__init__
(
self
,
num_classes
,
backbone
,
pretrained
=
None
,
backbone_indices
=
None
,
ocr_mid_channels
=
512
,
ocr_key_channels
=
256
):
super
(
OCRNet
,
self
).
__init__
()
self
.
backbone
=
backbone
self
.
backbone_indices
=
backbone_indices
in_channels
=
[
self
.
backbone
.
channels
[
i
]
for
i
in
backbone_indices
]
self
.
head
=
OCRHead
(
num_classes
=
num_classes
,
in_channels
=
in_channels
,
ocr_mid_channels
=
ocr_mid_channels
,
ocr_key_channels
=
ocr_key_channels
)
self
.
init_weight
(
pretrained
)
def
forward
(
self
,
x
,
label
=
None
):
feats
=
self
.
backbone
(
x
)
feats
=
[
feats
[
i
]
for
i
in
self
.
backbone_indices
]
preds
=
self
.
head
(
feats
,
label
)
preds
=
[
F
.
resize_bilinear
(
pred
,
x
.
shape
[
2
:])
for
pred
in
preds
]
return
preds
def
init_weight
(
self
,
pretrained
=
None
):
"""
"""
Initialize the parameters of model parts.
Initialize the parameters of model parts.
Args:
Args:
pretrained
_model
([str], optional): the path of pretrained model.. Defaults to None.
pretrained ([str], optional): the path of pretrained model.. Defaults to None.
"""
"""
if
pretrained
_model
is
not
None
:
if
pretrained
is
not
None
:
if
os
.
path
.
exists
(
pretrained
_model
):
if
os
.
path
.
exists
(
pretrained
):
utils
.
load_pretrained_model
(
self
,
pretrained
_model
)
utils
.
load_pretrained_model
(
self
,
pretrained
)
else
:
else
:
raise
Exception
(
'Pretrained model is not found: {}'
.
format
(
raise
Exception
(
pretrained_model
))
'Pretrained model is not found: {}'
.
format
(
pretrained
))
def
_get_loss
(
self
,
logit
,
label
):
"""
compute forward loss of the model
Args:
logit (tensor): the logit of model output
label (tensor): ground truth
Returns:
avg_loss (tensor): forward loss
"""
logit
=
fluid
.
layers
.
transpose
(
logit
,
[
0
,
2
,
3
,
1
])
label
=
fluid
.
layers
.
transpose
(
label
,
[
0
,
2
,
3
,
1
])
mask
=
label
!=
self
.
ignore_index
mask
=
fluid
.
layers
.
cast
(
mask
,
'float32'
)
loss
,
probs
=
fluid
.
layers
.
softmax_with_cross_entropy
(
logit
,
label
,
ignore_index
=
self
.
ignore_index
,
return_softmax
=
True
,
axis
=-
1
)
loss
=
loss
*
mask
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
/
(
fluid
.
layers
.
mean
(
mask
)
+
self
.
EPS
)
label
.
stop_gradient
=
True
mask
.
stop_gradient
=
True
return
avg_loss
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录