weixin_41840029 / PaddleOCR (forked from PaddlePaddle / PaddleOCR)

Commit f6e03a51, authored Jan 29, 2021 by LDOUBLEV
upload rare code
Parent: 95a60fa4

Showing 9 changed files, with 525 additions and 13 deletions (+525 −13):
configs/rec/rec_mv3_tps_bilstm_att.yml     +102 −0
configs/rec/rec_r34_vd_tps_bilstm_att.yml  +103 −0
ppocr/data/imaug/label_ops.py              +19 −5
ppocr/losses/__init__.py                   +4 −1
ppocr/losses/rec_att_loss.py               +39 −0
ppocr/modeling/heads/__init__.py           +4 −1
ppocr/modeling/heads/rec_att_head.py       +211 −0
ppocr/postprocess/__init__.py              +2 −1
ppocr/postprocess/rec_postprocess.py       +41 −5
configs/rec/rec_mv3_tps_bilstm_att.yml (new file, mode 100644)

Global:
  use_gpu: true
  epoch_num: 72
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec/rec_mv3_tps_bilstm_att/
  save_epoch_step: 3
  # evaluation is run every 2000 iterations
  eval_batch_step: [0, 2000]
  # if pretrained_model is saved in static mode, load_static_weights must be set to True
  cal_metric_during_train: True
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words/ch/word_1.jpg
  # for data or label process
  character_dict_path:
  character_type: en
  max_text_length: 25
  infer_mode: False
  use_space_char: False

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    learning_rate: 0.0005
  regularizer:
    name: 'L2'
    factor: 0.00001

Architecture:
  model_type: rec
  algorithm: RARE
  Transform:
    name: TPS
    num_fiducial: 20
    loc_lr: 0.1
    model_name: small
  Backbone:
    name: MobileNetV3
    scale: 0.5
    model_name: large
  Neck:
    name: SequenceEncoder
    encoder_type: rnn
    hidden_size: 96
  Head:
    name: AttentionHead
    hidden_size: 96

Loss:
  name: AttentionLoss

PostProcess:
  name: AttnLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: LMDBDateSet
    data_dir: ../training/
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - AttnLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 256
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: LMDBDateSet
    data_dir: ../validation/
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - AttnLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 1
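Both RARE configs are consumed by PaddleOCR's standard training entry point; a typical launch from the repository root looks like the line below, with the ResNet-34 config that follows substituted for the larger variant:

    python3 tools/train.py -c configs/rec/rec_mv3_tps_bilstm_att.yml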
configs/rec/rec_r34_vd_tps_bilstm_att.yml (new file, mode 100644)

Global:
  use_gpu: true
  epoch_num: 400
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec/b3_rare_r34_none_gru/
  save_epoch_step: 3
  # evaluation is run every 2000 iterations
  eval_batch_step: [0, 2000]
  # if pretrained_model is saved in static mode, load_static_weights must be set to True
  cal_metric_during_train: True
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words/ch/word_1.jpg
  # for data or label process
  character_dict_path:
  character_type: en
  max_text_length: 25
  infer_mode: False
  use_space_char: False

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    learning_rate: 0.0005
  regularizer:
    name: 'L2'
    factor: 0.00000

Architecture:
  model_type: rec
  algorithm: RARE
  Transform:
    name: TPS
    num_fiducial: 20
    loc_lr: 0.1
    model_name: large
  Backbone:
    name: ResNet
    layers: 34
  Neck:
    name: SequenceEncoder
    encoder_type: rnn
    hidden_size: 256 #96
  Head:
    name: AttentionHead
    hidden_size: 256
    # l2_decay: 0.00001

Loss:
  name: AttentionLoss

PostProcess:
  name: AttnLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: LMDBDateSet
    data_dir: ../training/
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - AttnLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 256
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: LMDBDateSet
    data_dir: ../validation/
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - AttnLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 8
ppocr/data/imaug/label_ops.py

@@ -197,16 +197,30 @@ class AttnLabelEncode(BaseRecLabelEncode):
         super(AttnLabelEncode, self).__init__(max_text_length, character_dict_path,
                                               character_type, use_space_char)
-        self.beg_str = "sos"
-        self.end_str = "eos"
 
     def add_special_char(self, dict_character):
-        dict_character = [self.beg_str, self.end_str] + dict_character
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
         return dict_character
 
-    def __call__(self, text):
+    def __call__(self, data):
+        text = data['label']
         text = self.encode(text)
-        return text
+        if text is None:
+            return None
+        if len(text) > self.max_text_len:
+            return None
+        data['length'] = np.array(len(text))
+        text = [0] + text + [len(self.character) - 1] + [0] * (
+            self.max_text_len - len(text) - 1)
+        data['label'] = np.array(text)
+        return data
+
+    def get_ignored_tokens(self):
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
 
     def get_beg_end_flag_idx(self, beg_or_end):
         if beg_or_end == "beg":
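The reworked __call__ lays every label out as sos + characters + eos + padding at a fixed length, recording the unpadded length separately. A minimal sketch of that layout, assuming max_text_len = 25 and a 39-symbol alphabet where add_special_char has placed 'sos' at index 0 and 'eos' at index 38:

text = [35, 25, 31]            # encoded characters, e.g. "you"
max_text_len = 25
num_chars = 39                 # ['sos'] + 37 characters + ['eos']

# the same expression used in AttnLabelEncode.__call__
label = [0] + text + [num_chars - 1] + [0] * (max_text_len - len(text) - 1)
print(label[:6])               # [0, 35, 25, 31, 38, 0]
print(len(label))              # 26 == max_text_len + 1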
ppocr/losses/__init__.py

@@ -23,11 +23,14 @@ def build_loss(config):
     # rec loss
     from .rec_ctc_loss import CTCLoss
+    from .rec_att_loss import AttentionLoss
 
     # cls loss
     from .cls_loss import ClsLoss
 
-    support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss']
+    support_dict = [
+        'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss'
+    ]
 
     config = copy.deepcopy(config)
     module_name = config.pop('name')
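With AttentionLoss registered, the loss can be built from a config dict like the existing ones; a minimal sketch, assuming a PaddleOCR checkout on the Python path:

from ppocr.losses import build_loss

# 'name' is popped to select the class; any remaining keys become kwargs
loss_func = build_loss({'name': 'AttentionLoss'})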
ppocr/losses/rec_att_loss.py (new file, mode 100644)

# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn


class AttentionLoss(nn.Layer):
    def __init__(self, **kwargs):
        super(AttentionLoss, self).__init__()
        self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')

    def forward(self, predicts, batch):
        targets = batch[1].astype("int64")
        label_lengths = batch[2].astype('int64')
        batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[
            1], predicts.shape[2]
        assert len(targets.shape) == len(list(predicts.shape)) - 1, \
            "targets should be [N, num_steps] when predicts is [N, num_steps, num_classes]"

        inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
        targets = paddle.reshape(targets, [-1])

        return {'loss': paddle.sum(self.loss_func(inputs, targets))}
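The loss flattens the [N, num_steps, num_classes] logits and the [N, num_steps] integer targets, then sums per-step cross-entropy. A hedged smoke test with arbitrary shapes (batch[0], the image tensor, is unused by the loss, so None stands in for it):

import paddle
from ppocr.losses.rec_att_loss import AttentionLoss

loss_fn = AttentionLoss()
N, T, C = 2, 26, 39                    # batch, decoding steps, alphabet size
predicts = paddle.rand([N, T, C])      # per-step class logits
labels = paddle.randint(0, C, [N, T])  # padded target indices
lengths = paddle.to_tensor([4, 7])     # unpadded label lengths
out = loss_fn(predicts, [None, labels, lengths])
print(float(out['loss']))              # summed cross-entropy over all steps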
ppocr/modeling/heads/__init__.py

@@ -23,10 +23,13 @@ def build_head(config):
     # rec head
     from .rec_ctc_head import CTCHead
+    from .rec_att_head import AttentionHead
 
     # cls head
     from .cls_head import ClsHead
 
-    support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead']
+    support_dict = [
+        'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead'
+    ]
 
     module_name = config.pop('name')
     assert module_name in support_dict, Exception('head only support {}'.format(
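build_head follows the same registry pattern. A sketch of constructing the new head directly from a dict; in a full model in_channels is injected by the architecture builder rather than the YAML, so the values here are purely illustrative:

from ppocr.modeling.heads import build_head

head = build_head({
    'name': 'AttentionHead',
    'in_channels': 96,   # illustrative: output size of the SequenceEncoder neck
    'out_channels': 39,  # illustrative: alphabet size incl. 'sos'/'eos'
    'hidden_size': 96,
})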
ppocr/modeling/heads/rec_att_head.py (new file, mode 100644)

# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
from paddle.jit import to_static


class AttentionHead(nn.Layer):
    def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
        super(AttentionHead, self).__init__()
        self.input_size = in_channels
        self.hidden_size = hidden_size
        self.num_classes = out_channels

        self.attention_cell = AttentionGRUCell(
            in_channels, hidden_size, out_channels, use_gru=False)
        self.generator = nn.Linear(hidden_size, out_channels)

    def _char_to_onehot(self, input_char, onehot_dim):
        input_ont_hot = F.one_hot(input_char, onehot_dim)
        return input_ont_hot

    def forward(self, inputs, targets=None, batch_max_length=25):
        batch_size = inputs.shape[0]
        num_steps = batch_max_length

        hidden = paddle.zeros((batch_size, self.hidden_size))
        output_hiddens = []

        if targets is not None:
            # training: teacher forcing with the ground-truth characters
            for i in range(num_steps):
                char_onehots = self._char_to_onehot(
                    targets[:, i], onehot_dim=self.num_classes)
                (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
                                                               char_onehots)
                output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
            output = paddle.concat(output_hiddens, axis=1)
            probs = self.generator(output)
        else:
            # inference: greedy decoding, feeding each step's argmax back in
            targets = paddle.zeros(shape=[batch_size], dtype="int32")
            probs = None

            for i in range(num_steps):
                char_onehots = self._char_to_onehot(
                    targets, onehot_dim=self.num_classes)
                (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
                                                               char_onehots)
                probs_step = self.generator(outputs)
                probs = paddle.unsqueeze(
                    probs_step, axis=1) if probs is None else paddle.concat(
                        [probs, paddle.unsqueeze(probs_step, axis=1)], axis=1)
                next_input = probs_step.argmax(axis=1)
                targets = next_input

        return probs


class AttentionGRUCell(nn.Layer):
    def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
        super(AttentionGRUCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias_attr=False)

        self.rnn = nn.GRUCell(
            input_size=input_size + num_embeddings, hidden_size=hidden_size)

        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        batch_H_proj = self.i2h(batch_H)
        prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)

        res = paddle.add(batch_H_proj, prev_hidden_proj)
        res = paddle.tanh(res)
        e = self.score(res)

        alpha = F.softmax(e, axis=1)
        alpha = paddle.transpose(alpha, [0, 2, 1])
        context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
        concat_context = paddle.concat([context, char_onehots], 1)

        cur_hidden = self.rnn(concat_context, prev_hidden)

        return cur_hidden, alpha


class AttentionLSTM(nn.Layer):
    def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
        super(AttentionLSTM, self).__init__()
        self.input_size = in_channels
        self.hidden_size = hidden_size
        self.num_classes = out_channels

        self.attention_cell = AttentionLSTMCell(
            in_channels, hidden_size, out_channels, use_gru=False)
        self.generator = nn.Linear(hidden_size, out_channels)

    def _char_to_onehot(self, input_char, onehot_dim):
        input_ont_hot = F.one_hot(input_char, onehot_dim)
        return input_ont_hot

    def forward(self, inputs, targets=None, batch_max_length=25):
        batch_size = inputs.shape[0]
        num_steps = batch_max_length

        hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
            (batch_size, self.hidden_size)))
        output_hiddens = []

        if targets is not None:
            for i in range(num_steps):
                # one-hot vectors for the i-th char
                char_onehots = self._char_to_onehot(
                    targets[:, i], onehot_dim=self.num_classes)
                hidden, alpha = self.attention_cell(hidden, inputs,
                                                    char_onehots)
                hidden = (hidden[1][0], hidden[1][1])
                output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
            output = paddle.concat(output_hiddens, axis=1)
            probs = self.generator(output)
        else:
            targets = paddle.zeros(shape=[batch_size], dtype="int32")
            probs = None

            for i in range(num_steps):
                char_onehots = self._char_to_onehot(
                    targets, onehot_dim=self.num_classes)
                hidden, alpha = self.attention_cell(hidden, inputs,
                                                    char_onehots)
                probs_step = self.generator(hidden[0])
                hidden = (hidden[1][0], hidden[1][1])
                probs = paddle.unsqueeze(
                    probs_step, axis=1) if probs is None else paddle.concat(
                        [probs, paddle.unsqueeze(probs_step, axis=1)], axis=1)
                next_input = probs_step.argmax(axis=1)
                targets = next_input

        return probs


class AttentionLSTMCell(nn.Layer):
    def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
        super(AttentionLSTMCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias_attr=False)
        if not use_gru:
            self.rnn = nn.LSTMCell(
                input_size=input_size + num_embeddings, hidden_size=hidden_size)
        else:
            self.rnn = nn.GRUCell(
                input_size=input_size + num_embeddings, hidden_size=hidden_size)

        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        batch_H_proj = self.i2h(batch_H)
        prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
        res = paddle.add(batch_H_proj, prev_hidden_proj)
        res = paddle.tanh(res)
        e = self.score(res)

        alpha = F.softmax(e, axis=1)
        alpha = paddle.transpose(alpha, [0, 2, 1])
        context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
        concat_context = paddle.concat([context, char_onehots], 1)
        cur_hidden = self.rnn(concat_context, prev_hidden)

        return cur_hidden, alpha


if __name__ == '__main__':
    paddle.disable_static()

    # quick smoke test of the teacher-forcing branch
    model = AttentionHead(100, 200, 10)

    x = np.random.uniform(-1, 1, [2, 10, 100]).astype(np.float32)
    y = np.random.randint(0, 10, [2, 21]).astype(np.int32)

    xp = paddle.to_tensor(x)
    yp = paddle.to_tensor(y)

    res = model(inputs=xp, targets=yp, batch_max_length=20)
    print("res: ", res.shape)
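The smoke test above exercises the teacher-forcing branch; passing targets=None instead takes the greedy-decoding branch, feeding each step's argmax back as the next input. A hedged continuation of the same test:

model.eval()
with paddle.no_grad():
    probs = model(inputs=xp, targets=None, batch_max_length=20)
print(probs.shape)  # [2, 20, 200]: per-step class scores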
ppocr/postprocess/__init__.py

@@ -30,7 +30,8 @@ def build_post_process(config, global_config=None):
     from .cls_postprocess import ClsPostProcess
 
     support_dict = [
         'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
-        'AttnLabelDecode', 'ClsPostProcess'
+        'AttnLabelDecode', 'ClsPostProcess',
+        'AttnLabelDecode'
     ]
 
     config = copy.deepcopy(config)
ppocr/postprocess/rec_postprocess.py

@@ -133,16 +133,52 @@ class AttnLabelDecode(BaseRecLabelDecode):
                  **kwargs):
         super(AttnLabelDecode, self).__init__(character_dict_path,
                                               character_type, use_space_char)
-        self.beg_str = "sos"
-        self.end_str = "eos"
 
     def add_special_char(self, dict_character):
-        dict_character = [self.beg_str, self.end_str] + dict_character
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = dict_character
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
         return dict_character
 
-    def __call__(self, text):
+    def __call__(self, preds, label=None, *args, **kwargs):
+        """
         text = self.decode(text)
-        return text
+        if label is None:
+            return text
+        else:
+            label = self.decode(label, is_remove_duplicate=False)
+            return text, label
+        """
+        if isinstance(preds, paddle.Tensor):
+            preds = preds.numpy()
+
+        preds_idx = preds.argmax(axis=2)
+        preds_prob = preds.max(axis=2)
+        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
+        if label is None:
+            return text
+        label = self.decode(label, is_remove_duplicate=True)
+        return text, label
+
+    def encoder(self, labels, labels_length):
+        """
+        Encode labels read from the LMDB dataset. For example, 'you' stored as
+        [35, 25, 31, 0, 0, ...] is encoded to [0, 35, 25, 31, 37, 0, ...],
+        i.e. 'sos' + 'you' + 'eos' plus padding.
+        """
+        if isinstance(labels, paddle.Tensor):
+            labels = labels.numpy()
+        # add start token 'sos' and end token 'eos'
+        batch_max_length = labels.shape[1] + 2
+        new_labels = np.zeros(
+            [labels.shape[0], batch_max_length]).astype(np.int64)
+        for i in range(labels.shape[0]):
+            new_labels[i, 1:1 + labels_length[i]] = labels[i, :labels_length[i]]
+            # new_labels[i, 0] stays 0, the 'sos' token;
+            # append the end character 'eos' token after the label
+            new_labels[i, labels_length[i] + 1] = len(self.character) - 1
+        return new_labels
 
     def get_ignored_tokens(self):
         beg_idx = self.get_beg_end_flag_idx("beg")
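After this change AttnLabelDecode consumes the head's raw [N, num_steps, num_classes] output rather than pre-decoded indices. A small numpy sketch of the extraction it performs (alphabet size is a placeholder):

import numpy as np

preds = np.random.rand(2, 25, 39).astype(np.float32)  # stand-in head output
preds_idx = preds.argmax(axis=2)  # [2, 25] best class id per step
preds_prob = preds.max(axis=2)    # [2, 25] score of that class
# self.decode(preds_idx, preds_prob, ...) then maps ids to characters,
# skipping the ignored 'sos'/'eos' tokens (see get_ignored_tokens).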