PaddlePaddle / PaddleOCR
Commit d6115158
Authored on Sep 07, 2021 by andyjpaddle

fix code style

Parent: ae09ef60
Showing 1 changed file with 92 additions and 87 deletions  +92  −87
ppocr/modeling/heads/rec_sar_head.py  +92  −87

The commit is style-only: each hunk rewraps long lines, so its removed and added lines carry the same tokens. The hunks below show each affected region as it reads after the change.
@@ -19,6 +19,7 @@ class SAREncoder(nn.Layer):

```python
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
```
@@ -51,33 +52,31 @@ class SAREncoder(nn.Layer):

```python
            num_layers=2,
            time_major=False,
            dropout=enc_drop_rnn,
            direction=direction)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**kwargs)

        # global feature transformation
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def forward(self, feat, img_metas=None):
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        h_feat = feat.shape[2]  # bsz c h w
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * W * C
        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        if valid_ratios is not None:
            valid_hf = []
            T = holistic_feat.shape[1]
```
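The encoder flattens the 2D feature map into a width-wise sequence by max-pooling away the full height before the RNN. A minimal shape check with dummy tensors (the sizes are illustrative, not the configured model's):

```python
import paddle
import paddle.nn.functional as F

feat = paddle.randn([2, 512, 6, 40])  # bsz * C * h * w, dummy backbone output

h_feat = feat.shape[2]
feat_v = F.max_pool2d(
    feat, kernel_size=(h_feat, 1), stride=1, padding=0)  # bsz * C * 1 * w
feat_v = feat_v.squeeze(2)                               # bsz * C * w
feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])        # bsz * w * C
print(feat_v.shape)  # [2, 40, 512]: width plays the role of time for the RNN
```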
@@ -86,11 +85,11 @@ class SAREncoder(nn.Layer):

```python
                valid_hf.append(holistic_feat[i, valid_step, :])
            valid_hf = paddle.stack(valid_hf, axis=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C
        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat


class BaseDecoder(nn.Layer):
    def __init__(self, **kwargs):
```
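When `mask` is on, the encoder reads the holistic feature at the last RNN step that corresponds to real (non-padded) image width instead of always at step T-1. The lines computing `valid_step` sit in the collapsed region above, so the rounding below is only an assumption for illustration:

```python
import paddle

holistic_feat = paddle.randn([2, 40, 512])  # bsz * T * C, dummy RNN output
valid_ratios = [1.0, 0.5]                   # per-sample fraction of valid width

T = holistic_feat.shape[1]
valid_hf = []
for i, ratio in enumerate(valid_ratios):
    valid_step = min(T, int(T * ratio + 0.5)) - 1  # assumed rounding; not shown in the hunk
    valid_hf.append(holistic_feat[i, valid_step, :])
valid_hf = paddle.stack(valid_hf, axis=0)  # bsz * C
print(valid_hf.shape)  # [2, 512]
```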
@@ -102,7 +101,7 @@ class BaseDecoder(nn.Layer):

```python
    def forward_test(self, feat, out_enc, img_metas):
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
```
@@ -135,20 +134,21 @@ class ParallelSARDecoder(BaseDecoder):

```python
    attention with holistic feature and hidden state.
    """

    def __init__(self,
                 out_channels,  # 90 + unknown + start + padding
                 enc_bi_rnn=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_model=512,
                 d_enc=512,
                 d_k=64,
                 pred_dropout=0.1,
                 max_text_length=30,
                 mask=True,
                 pred_concat=True,
                 **kwargs):
        super().__init__()

        self.num_classes = out_channels
```
@@ -165,7 +165,8 @@ class ParallelSARDecoder(BaseDecoder):

```python
        # 2D attention layer
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2D(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
```
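`conv1x1_1` and `conv1x1_2` are `nn.Linear` layers because a 1x1 convolution applied position-wise is the same map as a linear layer; only the key path needs a real 3x3 convolution over the feature map. A shape sketch with dummy tensors, assuming for illustration that the decoder state size equals `d_model`:

```python
import paddle
import paddle.nn as nn

d_model, d_k = 512, 64
conv1x1_1 = nn.Linear(d_model, d_k)  # query projection, position-wise
conv3x3_1 = nn.Conv2D(d_model, d_k, kernel_size=3, stride=1, padding=1)  # key projection

y = paddle.randn([2, 31, d_model])        # decoder states: bsz * (seq_len + 1) * C
feat = paddle.randn([2, d_model, 6, 40])  # conv features: bsz * C * h * w

attn_query = paddle.unsqueeze(conv1x1_1(y), axis=[3, 4])  # bsz * T * d_k * 1 * 1
attn_key = paddle.unsqueeze(conv3x3_1(feat), axis=1)      # bsz * 1 * d_k * h * w
scores = paddle.tanh(paddle.add(attn_key, attn_query))    # broadcasts to bsz * T * d_k * h * w
print(scores.shape)  # [2, 31, 64, 6, 40]
```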
@@ -180,8 +181,7 @@ class ParallelSARDecoder(BaseDecoder):

```python
            num_layers=2,
            time_major=False,
            dropout=dec_drop_rnn,
            direction=direction)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**kwargs)
        else:
```
@@ -189,8 +189,10 @@ class ParallelSARDecoder(BaseDecoder):

```python
        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes,
            encoder_rnn_out_size,
            padding_idx=self.padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = num_classes - 1
```
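`padding_idx` keeps padded label positions from contributing to learning: the padding row of the embedding table maps to zeros and is excluded from gradient updates. A small sketch with hypothetical sizes (93 classes, 512-dim embeddings, padding id 92):

```python
import paddle
import paddle.nn as nn

num_classes, emb_dim, padding_idx = 93, 512, 92  # hypothetical sizes
embedding = nn.Embedding(num_classes, emb_dim, padding_idx=padding_idx)

labels = paddle.to_tensor([[5, 17, padding_idx, padding_idx]])  # right-padded labels
vecs = embedding(labels)
print(vecs.shape)                     # [1, 4, 512]
print(float(vecs[0, 2].abs().sum()))  # 0.0: the padding row embeds to zeros
```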
@@ -205,11 +207,11 @@ class ParallelSARDecoder(BaseDecoder):

```python
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.shape
        attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
        # (bsz, seq_len + 1, attn_size, 1, 1)
```
@@ -220,7 +222,7 @@ class ParallelSARDecoder(BaseDecoder):

```python
        # bsz * 1 * attn_size * h * w
        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
        # bsz * (seq_len + 1) * h * w * attn_size
```
@@ -237,25 +239,28 @@ class ParallelSARDecoder(BaseDecoder):

```python
        attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
        attn_weight = F.softmax(attn_weight, axis=-1)

        attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
        # attn_weight: bsz * T * c * h * w
        # feat: bsz * c * h * w
        attn_feat = paddle.sum(
            paddle.multiply(feat.unsqueeze(1), attn_weight), (3, 4),
            keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.shape[-1]
            holistic_feat = paddle.expand(
                holistic_feat, shape=[bsz, seq_len, hf_c])
            y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, label, img_metas):
```
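The tail of `_2d_attention` above is a softmax over all h*w positions per decoding step, followed by a weighted sum of the raw feature map. The pooling can be checked in isolation with dummy shapes (note the score channel is 1 after `conv1x1_2`):

```python
import paddle
import paddle.nn.functional as F

bsz, T, c, h, w = 2, 31, 512, 6, 40
attn_weight = paddle.randn([bsz, T, h, w, 1])  # one scalar score per spatial position
feat = paddle.randn([bsz, c, h, w])

attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
attn_weight = F.softmax(attn_weight, axis=-1)  # normalize over all h * w positions
attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, 1])
attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])  # bsz * T * 1 * h * w

attn_feat = paddle.sum(
    paddle.multiply(feat.unsqueeze(1), attn_weight), (3, 4), keepdim=False)
print(attn_feat.shape)  # [2, 31, 512]: one pooled feature vector per step
```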
@@ -268,7 +273,7 @@ class ParallelSARDecoder(BaseDecoder):

```python
        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        label = label.cuda()
        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
```
@@ -277,11 +282,10 @@ class ParallelSARDecoder(BaseDecoder):

```python
        in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)
        # bsz * (seq_len + 1) * num_classes

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        if img_metas is not None:
```
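In `forward_train` above, decoding is parallel: the ground-truth label embeddings are concatenated after the holistic feature along the time axis, the whole sequence passes through `_2d_attention` once, and position 0 (the holistic-feature slot) is sliced off the output. A sketch of that input layout with dummy tensors (the `unsqueeze` of `out_enc` happens in a collapsed region and is assumed here):

```python
import paddle

bsz, seq_len, emb_dim = 2, 30, 512
out_enc = paddle.randn([bsz, emb_dim])                 # holistic feature from the encoder
lab_embedding = paddle.randn([bsz, seq_len, emb_dim])  # embedded ground-truth labels

in_dec = paddle.concat((out_enc.unsqueeze(1), lab_embedding), axis=1)
print(in_dec.shape)  # [2, 31, 512]: bsz * (seq_len + 1) * C; slot 0 is the holistic feature
```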
@@ -289,13 +293,12 @@ class ParallelSARDecoder(BaseDecoder):

```python
        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        seq_len = self.max_seq_len
        bsz = feat.shape[0]
        start_token = paddle.full(
            (bsz, ), fill_value=self.start_idx, dtype='int64')
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
```
@@ -311,68 +314,70 @@ class ParallelSARDecoder(BaseDecoder):

```python
        outputs = []
        for i in range(1, seq_len + 1):
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs
```
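Test-time decoding in the loop above is greedy and sequential: each step's argmax is embedded and written into the next slot of `decoder_input` before the next pass. A stripped-down sketch of the feedback loop, with a hypothetical `proj` layer standing in for the full attention stack (the real loop re-runs `_2d_attention` over the whole sequence every iteration):

```python
import paddle
import paddle.nn.functional as F

bsz, seq_len, emb_dim, num_classes = 2, 5, 512, 93
embedding = paddle.nn.Embedding(num_classes, emb_dim)
proj = paddle.nn.Linear(emb_dim, num_classes)  # hypothetical stand-in for _2d_attention

decoder_input = paddle.zeros([bsz, seq_len + 1, emb_dim])  # slot 0 would hold the start embedding

outputs = []
for i in range(1, seq_len + 1):
    decoder_output = proj(decoder_input)               # stand-in for self._2d_attention(...)
    char_output = F.softmax(decoder_output[:, i, :], -1)
    outputs.append(char_output)
    max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
    if i < seq_len:
        decoder_input[:, i + 1, :] = embedding(max_idx)  # feed the prediction back in

outputs = paddle.stack(outputs, 1)
print(outputs.shape)  # [2, 5, 93]: bsz * seq_len * num_classes
```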
The same hunk continues with the `SARHead` wrapper that wires the encoder and decoder together:

```python
class SARHead(nn.Layer):
    def __init__(self,
                 out_channels,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 max_text_length=30,
                 pred_concat=True,
                 **kwargs):
        super(SARHead, self).__init__()

        # encoder module
        self.encoder = SAREncoder(
            enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru)

        # decoder module
        self.decoder = ParallelSARDecoder(
            out_channels=out_channels,
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            d_k=d_k,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            pred_concat=pred_concat)

    def forward(self, feat, targets=None):
        '''
        img_metas: [label, valid_ratio]
        '''
        holistic_feat = self.encoder(feat, targets)  # bsz c

        if self.training:
            label = targets[0]  # label
            label = paddle.to_tensor(label, dtype='int64')
            final_out = self.decoder(
                feat, holistic_feat, label, img_metas=targets)
        if not self.training:
            final_out = self.decoder(
                feat,
                holistic_feat,
                label=None,
                img_metas=targets,
                train_mode=False)
            # (bsz, seq_len, num_classes)

        return final_out
```
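End to end, the head takes backbone features plus `targets` laid out as `[label, ..., valid_ratio]` per the docstring, and returns per-step class scores. A usage sketch, assuming this module's imports, default sizes, and an `out_channels` of 93; it is a sketch against the class as shown above, not a guaranteed run against this exact revision:

```python
import paddle

# Assumes SARHead is importable from this module and feat matches d_model=512.
head = SARHead(out_channels=93)
head.eval()

feat = paddle.randn([2, 512, 6, 40])           # bsz * C * h * w backbone features
label = paddle.zeros([2, 30], dtype='int64')   # dummy labels to satisfy the encoder's assert
targets = [label, paddle.to_tensor([1.0, 0.8])]  # valid_ratio comes last

with paddle.no_grad():
    out = head(feat, targets)
print(out.shape)  # (bsz, seq_len, num_classes), e.g. [2, 30, 93]
```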