Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
b6f0a903
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b6f0a903
编写于
8月 16, 2021
作者:
T
Topdu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add rec_nrtr
上级
6127aad9
变更
7
展开全部
隐藏空白更改
内联
并排
Showing
7 changed file
with
1338 addition
and
4 deletion
+1338
-4
configs/rec/rec_mtb_nrtr.yml
configs/rec/rec_mtb_nrtr.yml
+100
-0
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+28
-0
ppocr/modeling/heads/__init__.py
ppocr/modeling/heads/__init__.py
+3
-2
ppocr/modeling/heads/multiheadAttention.py
ppocr/modeling/heads/multiheadAttention.py
+365
-0
ppocr/modeling/heads/rec_nrtr_optim_head.py
ppocr/modeling/heads/rec_nrtr_optim_head.py
+779
-0
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+63
-0
tools/eval.py
tools/eval.py
+0
-2
未找到文件。
configs/rec/rec_mtb_nrtr.yml
0 → 100644
浏览文件 @
b6f0a903
Global
:
use_gpu
:
True
epoch_num
:
21
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/rec/nrtr_final/
save_epoch_step
:
1
# evaluation is run every 2000 iterations
eval_batch_step
:
[
0
,
2000
]
cal_metric_during_train
:
True
pretrained_model
:
checkpoints
:
save_inference_dir
:
use_visualdl
:
False
infer_img
:
doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path
:
character_type
:
EN_symbol
max_text_length
:
25
infer_mode
:
False
use_space_char
:
True
save_res_path
:
./output/rec/predicts_nrtr.txt
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.99
clip_norm
:
5.0
lr
:
name
:
Cosine
learning_rate
:
0.0005
warmup_epoch
:
2
regularizer
:
name
:
'
L2'
factor
:
0.
Architecture
:
model_type
:
rec
algorithm
:
NRTR
in_channels
:
1
Transform
:
Backbone
:
name
:
MTB
cnn_num
:
2
Head
:
name
:
TransformerOptim
d_model
:
512
num_encoder_layers
:
6
beam_size
:
-1
# When Beam size is greater than 0, it means to use beam search when evaluation.
Loss
:
name
:
NRTRLoss
smoothing
:
True
PostProcess
:
name
:
NRTRLabelDecode
Metric
:
name
:
RecMetric
main_indicator
:
acc
Train
:
dataset
:
name
:
LMDBDataSet
data_dir
:
/paddle/data/ocr_data/training/
transforms
:
-
NRTRDecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
NRTRLabelEncode
:
# Class handling label
-
PILResize
:
image_shape
:
[
100
,
32
]
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader will return list in this order
loader
:
shuffle
:
True
batch_size_per_card
:
512
drop_last
:
True
num_workers
:
8
Eval
:
dataset
:
name
:
LMDBDataSet
data_dir
:
/paddle/data/ocr_data/evaluation/
transforms
:
-
NRTRDecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
NRTRLabelEncode
:
# Class handling label
-
PILResize
:
image_shape
:
[
100
,
32
]
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
label'
,
'
length'
]
# dataloader will return list in this order
loader
:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
256
num_workers
:
1
use_shared_memory
:
False
ppocr/data/imaug/label_ops.py
浏览文件 @
b6f0a903
...
...
@@ -159,6 +159,34 @@ class BaseRecLabelEncode(object):
return
text_list
class
NRTRLabelEncode
(
BaseRecLabelEncode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
max_text_length
,
character_dict_path
=
None
,
character_type
=
'EN_symbol'
,
use_space_char
=
False
,
**
kwargs
):
super
(
NRTRLabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
character_type
,
use_space_char
)
def
__call__
(
self
,
data
):
text
=
data
[
'label'
]
text
=
self
.
encode
(
text
)
if
text
is
None
:
return
None
data
[
'length'
]
=
np
.
array
(
len
(
text
))
text
.
insert
(
0
,
2
)
text
.
append
(
3
)
text
=
text
+
[
0
]
*
(
self
.
max_text_len
-
len
(
text
))
data
[
'label'
]
=
np
.
array
(
text
)
return
data
def
add_special_char
(
self
,
dict_character
):
dict_character
=
[
'blank'
,
'<unk>'
,
'<s>'
,
'</s>'
]
+
dict_character
return
dict_character
class
CTCLabelEncode
(
BaseRecLabelEncode
):
""" Convert between text-label and text-index """
...
...
ppocr/modeling/heads/__init__.py
浏览文件 @
b6f0a903
...
...
@@ -26,12 +26,13 @@ def build_head(config):
from
.rec_ctc_head
import
CTCHead
from
.rec_att_head
import
AttentionHead
from
.rec_srn_head
import
SRNHead
from
.rec_nrtr_optim_head
import
TransformerOptim
# cls head
from
.cls_head
import
ClsHead
support_dict
=
[
'DBHead'
,
'EASTHead'
,
'SASTHead'
,
'CTCHead'
,
'ClsHead'
,
'AttentionHead'
,
'SRNHead'
,
'PGHead'
]
'SRNHead'
,
'PGHead'
,
'TransformerOptim'
]
module_name
=
config
.
pop
(
'name'
)
...
...
ppocr/modeling/heads/multiheadAttention.py
0 → 100755
浏览文件 @
b6f0a903
import
paddle
from
paddle
import
nn
import
paddle.nn.functional
as
F
from
paddle.nn
import
Linear
from
paddle.nn.initializer
import
XavierUniform
as
xavier_uniform_
from
paddle.nn.initializer
import
Constant
as
constant_
from
paddle.nn.initializer
import
XavierNormal
as
xavier_normal_
zeros_
=
constant_
(
value
=
0.
)
ones_
=
constant_
(
value
=
1.
)
class
MultiheadAttention
(
nn
.
Layer
):
r
"""Allows the model to jointly attend to information
from different representation subspaces.
See reference: Attention Is All You Need
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
Args:
embed_dim: total dimension of the model
num_heads: parallel attention layers, or heads
Examples::
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
def
__init__
(
self
,
embed_dim
,
num_heads
,
dropout
=
0.
,
bias
=
True
,
add_bias_kv
=
False
,
add_zero_attn
=
False
):
super
(
MultiheadAttention
,
self
).
__init__
()
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
dropout
=
dropout
self
.
head_dim
=
embed_dim
//
num_heads
assert
self
.
head_dim
*
num_heads
==
self
.
embed_dim
,
"embed_dim must be divisible by num_heads"
self
.
scaling
=
self
.
head_dim
**
-
0.5
self
.
out_proj
=
Linear
(
embed_dim
,
embed_dim
,
bias_attr
=
bias
)
if
add_bias_kv
:
self
.
bias_k
=
self
.
create_parameter
(
shape
=
(
1
,
1
,
embed_dim
),
default_initializer
=
zeros_
)
self
.
add_parameter
(
"bias_k"
,
self
.
bias_k
)
self
.
bias_v
=
self
.
create_parameter
(
shape
=
(
1
,
1
,
embed_dim
),
default_initializer
=
zeros_
)
self
.
add_parameter
(
"bias_v"
,
self
.
bias_v
)
else
:
self
.
bias_k
=
self
.
bias_v
=
None
self
.
add_zero_attn
=
add_zero_attn
self
.
_reset_parameters
()
self
.
conv1
=
paddle
.
nn
.
Conv2D
(
in_channels
=
embed_dim
,
out_channels
=
embed_dim
,
kernel_size
=
(
1
,
1
))
self
.
conv2
=
paddle
.
nn
.
Conv2D
(
in_channels
=
embed_dim
,
out_channels
=
embed_dim
*
2
,
kernel_size
=
(
1
,
1
))
self
.
conv3
=
paddle
.
nn
.
Conv2D
(
in_channels
=
embed_dim
,
out_channels
=
embed_dim
*
3
,
kernel_size
=
(
1
,
1
))
def
_reset_parameters
(
self
):
xavier_uniform_
(
self
.
out_proj
.
weight
)
if
self
.
bias_k
is
not
None
:
xavier_normal_
(
self
.
bias_k
)
if
self
.
bias_v
is
not
None
:
xavier_normal_
(
self
.
bias_v
)
def
forward
(
self
,
query
,
key
,
value
,
key_padding_mask
=
None
,
incremental_state
=
None
,
need_weights
=
True
,
static_kv
=
False
,
attn_mask
=
None
,
qkv_
=
[
False
,
False
,
False
]):
"""
Inputs of forward function
query: [target length, batch size, embed dim]
key: [sequence length, batch size, embed dim]
value: [sequence length, batch size, embed dim]
key_padding_mask: if True, mask padding based on batch size
incremental_state: if provided, previous time steps are cashed
need_weights: output attn_output_weights
static_kv: key and value are static
Outputs of forward function
attn_output: [target length, batch size, embed dim]
attn_output_weights: [batch size, target length, sequence length]
"""
qkv_same
=
qkv_
[
0
]
kv_same
=
qkv_
[
1
]
tgt_len
,
bsz
,
embed_dim
=
query
.
shape
assert
embed_dim
==
self
.
embed_dim
assert
list
(
query
.
shape
)
==
[
tgt_len
,
bsz
,
embed_dim
]
assert
key
.
shape
==
value
.
shape
if
qkv_same
:
# self-attention
q
,
k
,
v
=
self
.
_in_proj_qkv
(
query
)
elif
kv_same
:
# encoder-decoder attention
q
=
self
.
_in_proj_q
(
query
)
if
key
is
None
:
assert
value
is
None
k
=
v
=
None
else
:
k
,
v
=
self
.
_in_proj_kv
(
key
)
else
:
q
=
self
.
_in_proj_q
(
query
)
k
=
self
.
_in_proj_k
(
key
)
v
=
self
.
_in_proj_v
(
value
)
q
*=
self
.
scaling
if
self
.
bias_k
is
not
None
:
assert
self
.
bias_v
is
not
None
self
.
bias_k
=
paddle
.
concat
([
self
.
bias_k
for
i
in
range
(
bsz
)],
axis
=
1
)
self
.
bias_v
=
paddle
.
concat
([
self
.
bias_v
for
i
in
range
(
bsz
)],
axis
=
1
)
k
=
paddle
.
concat
([
k
,
self
.
bias_k
])
v
=
paddle
.
concat
([
v
,
self
.
bias_v
])
if
attn_mask
is
not
None
:
attn_mask
=
paddle
.
concat
([
attn_mask
,
paddle
.
zeros
([
attn_mask
.
shape
[
0
],
1
],
dtype
=
attn_mask
.
dtype
)],
axis
=
1
)
if
key_padding_mask
is
not
None
:
key_padding_mask
=
paddle
.
concat
(
[
key_padding_mask
,
paddle
.
zeros
([
key_padding_mask
.
shape
[
0
],
1
],
dtype
=
key_padding_mask
.
dtype
)],
axis
=
1
)
q
=
q
.
reshape
([
tgt_len
,
bsz
*
self
.
num_heads
,
self
.
head_dim
]).
transpose
([
1
,
0
,
2
])
if
k
is
not
None
:
k
=
k
.
reshape
([
-
1
,
bsz
*
self
.
num_heads
,
self
.
head_dim
]).
transpose
([
1
,
0
,
2
])
if
v
is
not
None
:
v
=
v
.
reshape
([
-
1
,
bsz
*
self
.
num_heads
,
self
.
head_dim
]).
transpose
([
1
,
0
,
2
])
src_len
=
k
.
shape
[
1
]
if
key_padding_mask
is
not
None
:
assert
key_padding_mask
.
shape
[
0
]
==
bsz
assert
key_padding_mask
.
shape
[
1
]
==
src_len
if
self
.
add_zero_attn
:
src_len
+=
1
k
=
paddle
.
concat
([
k
,
paddle
.
zeros
((
k
.
shape
[
0
],
1
)
+
k
.
shape
[
2
:],
dtype
=
k
.
dtype
)],
axis
=
1
)
v
=
paddle
.
concat
([
v
,
paddle
.
zeros
((
v
.
shape
[
0
],
1
)
+
v
.
shape
[
2
:],
dtype
=
v
.
dtype
)],
axis
=
1
)
if
attn_mask
is
not
None
:
attn_mask
=
paddle
.
concat
([
attn_mask
,
paddle
.
zeros
([
attn_mask
.
shape
[
0
],
1
],
dtype
=
attn_mask
.
dtype
)],
axis
=
1
)
if
key_padding_mask
is
not
None
:
key_padding_mask
=
paddle
.
concat
(
[
key_padding_mask
,
paddle
.
zeros
([
key_padding_mask
.
shape
[
0
],
1
],
dtype
=
key_padding_mask
.
dtype
)],
axis
=
1
)
attn_output_weights
=
paddle
.
bmm
(
q
,
k
.
transpose
([
0
,
2
,
1
]))
assert
list
(
attn_output_weights
.
shape
)
==
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
]
if
attn_mask
is
not
None
:
attn_mask
=
attn_mask
.
unsqueeze
(
0
)
attn_output_weights
+=
attn_mask
if
key_padding_mask
is
not
None
:
attn_output_weights
=
attn_output_weights
.
reshape
([
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
])
key
=
key_padding_mask
.
unsqueeze
(
1
).
unsqueeze
(
2
).
astype
(
'float32'
)
y
=
paddle
.
full
(
shape
=
key
.
shape
,
dtype
=
'float32'
,
fill_value
=
'-inf'
)
y
=
paddle
.
where
(
key
==
0.
,
key
,
y
)
attn_output_weights
+=
y
attn_output_weights
=
attn_output_weights
.
reshape
([
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
])
attn_output_weights
=
F
.
softmax
(
attn_output_weights
.
astype
(
'float32'
),
axis
=-
1
,
dtype
=
paddle
.
float32
if
attn_output_weights
.
dtype
==
paddle
.
float16
else
attn_output_weights
.
dtype
)
attn_output_weights
=
F
.
dropout
(
attn_output_weights
,
p
=
self
.
dropout
,
training
=
self
.
training
)
attn_output
=
paddle
.
bmm
(
attn_output_weights
,
v
)
assert
list
(
attn_output
.
shape
)
==
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
]
attn_output
=
attn_output
.
transpose
([
1
,
0
,
2
]).
reshape
([
tgt_len
,
bsz
,
embed_dim
])
attn_output
=
self
.
out_proj
(
attn_output
)
if
need_weights
:
# average attention weights over heads
attn_output_weights
=
attn_output_weights
.
reshape
([
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
])
attn_output_weights
=
attn_output_weights
.
sum
(
axis
=
1
)
/
self
.
num_heads
else
:
attn_output_weights
=
None
return
attn_output
,
attn_output_weights
def
_in_proj_qkv
(
self
,
query
):
query
=
query
.
transpose
([
1
,
2
,
0
])
query
=
paddle
.
unsqueeze
(
query
,
axis
=
2
)
res
=
self
.
conv3
(
query
)
res
=
paddle
.
squeeze
(
res
,
axis
=
2
)
res
=
res
.
transpose
([
2
,
0
,
1
])
return
res
.
chunk
(
3
,
axis
=-
1
)
def
_in_proj_kv
(
self
,
key
):
key
=
key
.
transpose
([
1
,
2
,
0
])
key
=
paddle
.
unsqueeze
(
key
,
axis
=
2
)
res
=
self
.
conv2
(
key
)
res
=
paddle
.
squeeze
(
res
,
axis
=
2
)
res
=
res
.
transpose
([
2
,
0
,
1
])
return
res
.
chunk
(
2
,
axis
=-
1
)
def
_in_proj_q
(
self
,
query
):
query
=
query
.
transpose
([
1
,
2
,
0
])
query
=
paddle
.
unsqueeze
(
query
,
axis
=
2
)
res
=
self
.
conv1
(
query
)
res
=
paddle
.
squeeze
(
res
,
axis
=
2
)
res
=
res
.
transpose
([
2
,
0
,
1
])
return
res
def
_in_proj_k
(
self
,
key
):
key
=
key
.
transpose
([
1
,
2
,
0
])
key
=
paddle
.
unsqueeze
(
key
,
axis
=
2
)
res
=
self
.
conv1
(
key
)
res
=
paddle
.
squeeze
(
res
,
axis
=
2
)
res
=
res
.
transpose
([
2
,
0
,
1
])
return
res
def
_in_proj_v
(
self
,
value
):
value
=
value
.
transpose
([
1
,
2
,
0
])
#(1, 2, 0)
value
=
paddle
.
unsqueeze
(
value
,
axis
=
2
)
res
=
self
.
conv1
(
value
)
res
=
paddle
.
squeeze
(
res
,
axis
=
2
)
res
=
res
.
transpose
([
2
,
0
,
1
])
return
res
class
MultiheadAttentionOptim
(
nn
.
Layer
):
r
"""Allows the model to jointly attend to information
from different representation subspaces.
See reference: Attention Is All You Need
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
Args:
embed_dim: total dimension of the model
num_heads: parallel attention layers, or heads
Examples::
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
def
__init__
(
self
,
embed_dim
,
num_heads
,
dropout
=
0.
,
bias
=
True
,
add_bias_kv
=
False
,
add_zero_attn
=
False
):
super
(
MultiheadAttentionOptim
,
self
).
__init__
()
self
.
embed_dim
=
embed_dim
self
.
num_heads
=
num_heads
self
.
dropout
=
dropout
self
.
head_dim
=
embed_dim
//
num_heads
assert
self
.
head_dim
*
num_heads
==
self
.
embed_dim
,
"embed_dim must be divisible by num_heads"
self
.
scaling
=
self
.
head_dim
**
-
0.5
self
.
out_proj
=
Linear
(
embed_dim
,
embed_dim
,
bias_attr
=
bias
)
self
.
_reset_parameters
()
self
.
conv1
=
paddle
.
nn
.
Conv2D
(
in_channels
=
embed_dim
,
out_channels
=
embed_dim
,
kernel_size
=
(
1
,
1
))
self
.
conv2
=
paddle
.
nn
.
Conv2D
(
in_channels
=
embed_dim
,
out_channels
=
embed_dim
,
kernel_size
=
(
1
,
1
))
self
.
conv3
=
paddle
.
nn
.
Conv2D
(
in_channels
=
embed_dim
,
out_channels
=
embed_dim
,
kernel_size
=
(
1
,
1
))
def
_reset_parameters
(
self
):
xavier_uniform_
(
self
.
out_proj
.
weight
)
def
forward
(
self
,
query
,
key
,
value
,
key_padding_mask
=
None
,
incremental_state
=
None
,
need_weights
=
True
,
static_kv
=
False
,
attn_mask
=
None
):
"""
Inputs of forward function
query: [target length, batch size, embed dim]
key: [sequence length, batch size, embed dim]
value: [sequence length, batch size, embed dim]
key_padding_mask: if True, mask padding based on batch size
incremental_state: if provided, previous time steps are cashed
need_weights: output attn_output_weights
static_kv: key and value are static
Outputs of forward function
attn_output: [target length, batch size, embed dim]
attn_output_weights: [batch size, target length, sequence length]
"""
tgt_len
,
bsz
,
embed_dim
=
query
.
shape
assert
embed_dim
==
self
.
embed_dim
assert
list
(
query
.
shape
)
==
[
tgt_len
,
bsz
,
embed_dim
]
assert
key
.
shape
==
value
.
shape
q
=
self
.
_in_proj_q
(
query
)
k
=
self
.
_in_proj_k
(
key
)
v
=
self
.
_in_proj_v
(
value
)
q
*=
self
.
scaling
q
=
q
.
reshape
([
tgt_len
,
bsz
*
self
.
num_heads
,
self
.
head_dim
]).
transpose
([
1
,
0
,
2
])
k
=
k
.
reshape
([
-
1
,
bsz
*
self
.
num_heads
,
self
.
head_dim
]).
transpose
([
1
,
0
,
2
])
v
=
v
.
reshape
([
-
1
,
bsz
*
self
.
num_heads
,
self
.
head_dim
]).
transpose
([
1
,
0
,
2
])
src_len
=
k
.
shape
[
1
]
if
key_padding_mask
is
not
None
:
assert
key_padding_mask
.
shape
[
0
]
==
bsz
assert
key_padding_mask
.
shape
[
1
]
==
src_len
attn_output_weights
=
paddle
.
bmm
(
q
,
k
.
transpose
([
0
,
2
,
1
]))
assert
list
(
attn_output_weights
.
shape
)
==
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
]
if
attn_mask
is
not
None
:
attn_mask
=
attn_mask
.
unsqueeze
(
0
)
attn_output_weights
+=
attn_mask
if
key_padding_mask
is
not
None
:
attn_output_weights
=
attn_output_weights
.
reshape
([
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
])
key
=
key_padding_mask
.
unsqueeze
(
1
).
unsqueeze
(
2
).
astype
(
'float32'
)
y
=
paddle
.
full
(
shape
=
key
.
shape
,
dtype
=
'float32'
,
fill_value
=
'-inf'
)
y
=
paddle
.
where
(
key
==
0.
,
key
,
y
)
attn_output_weights
+=
y
attn_output_weights
=
attn_output_weights
.
reshape
([
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
])
attn_output_weights
=
F
.
softmax
(
attn_output_weights
.
astype
(
'float32'
),
axis
=-
1
,
dtype
=
paddle
.
float32
if
attn_output_weights
.
dtype
==
paddle
.
float16
else
attn_output_weights
.
dtype
)
attn_output_weights
=
F
.
dropout
(
attn_output_weights
,
p
=
self
.
dropout
,
training
=
self
.
training
)
attn_output
=
paddle
.
bmm
(
attn_output_weights
,
v
)
assert
list
(
attn_output
.
shape
)
==
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
]
attn_output
=
attn_output
.
transpose
([
1
,
0
,
2
]).
reshape
([
tgt_len
,
bsz
,
embed_dim
])
attn_output
=
self
.
out_proj
(
attn_output
)
if
need_weights
:
# average attention weights over heads
attn_output_weights
=
attn_output_weights
.
reshape
([
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
])
attn_output_weights
=
attn_output_weights
.
sum
(
axis
=
1
)
/
self
.
num_heads
else
:
attn_output_weights
=
None
return
attn_output
,
attn_output_weights
def
_in_proj_q
(
self
,
query
):
query
=
query
.
transpose
([
1
,
2
,
0
])
query
=
paddle
.
unsqueeze
(
query
,
axis
=
2
)
res
=
self
.
conv1
(
query
)
res
=
paddle
.
squeeze
(
res
,
axis
=
2
)
res
=
res
.
transpose
([
2
,
0
,
1
])
return
res
def
_in_proj_k
(
self
,
key
):
key
=
key
.
transpose
([
1
,
2
,
0
])
key
=
paddle
.
unsqueeze
(
key
,
axis
=
2
)
res
=
self
.
conv2
(
key
)
res
=
paddle
.
squeeze
(
res
,
axis
=
2
)
res
=
res
.
transpose
([
2
,
0
,
1
])
return
res
def
_in_proj_v
(
self
,
value
):
value
=
value
.
transpose
([
1
,
2
,
0
])
#(1, 2, 0)
value
=
paddle
.
unsqueeze
(
value
,
axis
=
2
)
res
=
self
.
conv3
(
value
)
res
=
paddle
.
squeeze
(
res
,
axis
=
2
)
res
=
res
.
transpose
([
2
,
0
,
1
])
return
res
\ No newline at end of file
ppocr/modeling/heads/rec_nrtr_optim_head.py
0 → 100644
浏览文件 @
b6f0a903
此差异已折叠。
点击以展开。
ppocr/postprocess/rec_postprocess.py
浏览文件 @
b6f0a903
...
...
@@ -156,6 +156,69 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
return
output
class
NRTRLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
character_dict_path
=
None
,
character_type
=
'EN_symbol'
,
use_space_char
=
True
,
**
kwargs
):
super
(
NRTRLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
if
preds
.
dtype
==
paddle
.
int64
:
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
if
preds
[
0
][
0
]
==
2
:
preds_idx
=
preds
[:,
1
:]
else
:
preds_idx
=
preds
text
=
self
.
decode
(
preds_idx
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
[:,
1
:])
else
:
if
isinstance
(
preds
,
paddle
.
Tensor
):
preds
=
preds
.
numpy
()
preds_idx
=
preds
.
argmax
(
axis
=
2
)
preds_prob
=
preds
.
max
(
axis
=
2
)
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
return
text
label
=
self
.
decode
(
label
[:,
1
:])
return
text
,
label
def
add_special_char
(
self
,
dict_character
):
dict_character
=
[
'blank'
,
'<unk>'
,
'<s>'
,
'</s>'
]
+
dict_character
return
dict_character
def
decode
(
self
,
text_index
,
text_prob
=
None
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
result_list
=
[]
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
conf_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
if
text_index
[
batch_idx
][
idx
]
==
3
:
# end
break
try
:
char_list
.
append
(
self
.
character
[
int
(
text_index
[
batch_idx
][
idx
])])
except
:
continue
if
text_prob
is
not
None
:
conf_list
.
append
(
text_prob
[
batch_idx
][
idx
])
else
:
conf_list
.
append
(
1
)
text
=
''
.
join
(
char_list
)
result_list
.
append
((
text
.
lower
(),
np
.
mean
(
conf_list
)))
return
result_list
class
AttnLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
...
...
tools/eval.py
浏览文件 @
b6f0a903
...
...
@@ -22,7 +22,6 @@ import sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
from
ppocr.data
import
build_dataloader
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
...
...
@@ -31,7 +30,6 @@ from ppocr.utils.save_load import init_model
from
ppocr.utils.utility
import
print_dict
import
tools.program
as
program
def
main
():
global_config
=
config
[
'Global'
]
# build dataloader
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录