Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
fc512a84
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fc512a84
编写于
9月 25, 2020
作者:
T
tink2123
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add anno for rec
上级
8b64f4c2
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
101 addition
and
22 deletion
+101
-22
ppocr/modeling/architectures/rec_model.py
ppocr/modeling/architectures/rec_model.py
+17
-1
ppocr/modeling/heads/rec_ctc_head.py
ppocr/modeling/heads/rec_ctc_head.py
+6
-0
ppocr/modeling/losses/rec_attention_loss.py
ppocr/modeling/losses/rec_attention_loss.py
+1
-0
ppocr/modeling/losses/rec_ctc_loss.py
ppocr/modeling/losses/rec_ctc_loss.py
+1
-0
ppocr/utils/character.py
ppocr/utils/character.py
+49
-7
tools/eval_utils/eval_rec_utils.py
tools/eval_utils/eval_rec_utils.py
+3
-1
tools/program.py
tools/program.py
+24
-13
未找到文件。
ppocr/modeling/architectures/rec_model.py
浏览文件 @
fc512a84
...
...
@@ -25,6 +25,12 @@ from copy import deepcopy
class
RecModel
(
object
):
"""
Rec model architecture
Args:
params(object): Params from yaml file and settings from command line
"""
def
__init__
(
self
,
params
):
super
(
RecModel
,
self
).
__init__
()
global_params
=
params
[
'Global'
]
...
...
@@ -64,6 +70,12 @@ class RecModel(object):
self
.
num_heads
=
None
def
create_feed
(
self
,
mode
):
"""
Create feed dict and DataLoader object
Args:
mode(str): runtime mode, can be "train", "eval" or "test"
Return: image, labels, loader
"""
image_shape
=
deepcopy
(
self
.
image_shape
)
image_shape
.
insert
(
0
,
-
1
)
if
mode
==
"train"
:
...
...
@@ -189,9 +201,12 @@ class RecModel(object):
inputs
=
image
else
:
inputs
=
self
.
tps
(
image
)
# backbone
conv_feas
=
self
.
backbone
(
inputs
)
# predict
predicts
=
self
.
head
(
conv_feas
,
labels
,
mode
)
decoded_out
=
predicts
[
'decoded_out'
]
# loss
if
mode
==
"train"
:
loss
=
self
.
loss
(
predicts
,
labels
)
if
self
.
loss_type
==
"attention"
:
...
...
@@ -211,7 +226,7 @@ class RecModel(object):
outputs
=
{
'total_loss'
:
loss
,
'decoded_out'
:
\
decoded_out
,
'label'
:
label
}
return
loader
,
outputs
# export_model
elif
mode
==
"export"
:
predict
=
predicts
[
'predict'
]
if
self
.
loss_type
==
"ctc"
:
...
...
@@ -225,6 +240,7 @@ class RecModel(object):
]
return
[
image
,
{
'decoded_out'
:
decoded_out
,
'predicts'
:
predict
}]
# eval or test
else
:
predict
=
predicts
[
'predict'
]
if
self
.
loss_type
==
"ctc"
:
...
...
ppocr/modeling/heads/rec_ctc_head.py
浏览文件 @
fc512a84
...
...
@@ -27,6 +27,12 @@ import numpy as np
class
CTCPredict
(
object
):
"""
CTC predict
Args:
params(object): Params from yaml file and settings from command line
"""
def
__init__
(
self
,
params
):
super
(
CTCPredict
,
self
).
__init__
()
self
.
char_num
=
params
[
'char_num'
]
...
...
ppocr/modeling/losses/rec_attention_loss.py
浏览文件 @
fc512a84
...
...
@@ -33,6 +33,7 @@ class AttentionLoss(object):
predict
=
predicts
[
'predict'
]
label_out
=
labels
[
'label_out'
]
label_out
=
fluid
.
layers
.
cast
(
x
=
label_out
,
dtype
=
'int64'
)
# calculate attention loss
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
predict
,
label
=
label_out
)
sum_cost
=
fluid
.
layers
.
reduce_sum
(
cost
)
return
sum_cost
ppocr/modeling/losses/rec_ctc_loss.py
浏览文件 @
fc512a84
...
...
@@ -30,6 +30,7 @@ class CTCLoss(object):
def
__call__
(
self
,
predicts
,
labels
):
predict
=
predicts
[
'predict'
]
label
=
labels
[
'label'
]
# calculate ctc loss
cost
=
fluid
.
layers
.
warpctc
(
input
=
predict
,
label
=
label
,
blank
=
self
.
char_num
,
norm_by_times
=
True
)
sum_cost
=
fluid
.
layers
.
reduce_sum
(
cost
)
...
...
ppocr/utils/character.py
浏览文件 @
fc512a84
...
...
@@ -20,15 +20,21 @@ import sys
class
CharacterOps
(
object
):
""" Convert between text-label and text-index """
"""
Convert between text-label and text-index
Args:
config: config from yaml file
"""
def
__init__
(
self
,
config
):
self
.
character_type
=
config
[
'character_type'
]
self
.
loss_type
=
config
[
'loss_type'
]
self
.
max_text_len
=
config
[
'max_text_length'
]
# use the default dictionary(36 char)
if
self
.
character_type
==
"en"
:
self
.
character_str
=
"0123456789abcdefghijklmnopqrstuvwxyz"
dict_character
=
list
(
self
.
character_str
)
# use the custom dictionary
elif
self
.
character_type
in
[
"ch"
,
'japan'
,
'korean'
,
'french'
,
'german'
]:
...
...
@@ -55,25 +61,27 @@ class CharacterOps(object):
"Nonsupport type of the character: {}"
.
format
(
self
.
character_str
)
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
# add start and end str for attention
if
self
.
loss_type
==
"attention"
:
dict_character
=
[
self
.
beg_str
,
self
.
end_str
]
+
dict_character
elif
self
.
loss_type
==
"srn"
:
dict_character
=
dict_character
+
[
self
.
beg_str
,
self
.
end_str
]
# create char dict
self
.
dict
=
{}
for
i
,
char
in
enumerate
(
dict_character
):
self
.
dict
[
char
]
=
i
self
.
character
=
dict_character
def
encode
(
self
,
text
):
"""convert text-label into text-index.
input:
"""
convert text-label into text-index.
Args:
text: text labels of each image. [batch_size]
output:
Return:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
# Ignore capital
if
self
.
character_type
==
"en"
:
text
=
text
.
lower
()
...
...
@@ -86,7 +94,15 @@ class CharacterOps(object):
return
text
def
decode
(
self
,
text_index
,
is_remove_duplicate
=
False
):
""" convert text-index into text-label. """
"""
convert text-index into text-label.
Args:
text_index: text index for each image
is_remove_duplicate: Whether to remove duplicate characters,
The default is False
Return:
text: text label
"""
char_list
=
[]
char_num
=
self
.
get_char_num
()
...
...
@@ -108,6 +124,9 @@ class CharacterOps(object):
return
text
def
get_char_num
(
self
):
"""
Get character num
"""
return
len
(
self
.
character
)
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
...
...
@@ -132,6 +151,21 @@ def cal_predicts_accuracy(char_ops,
labels
,
labels_lod
,
is_remove_duplicate
=
False
):
"""
Calculate predicts accrarcy
Args:
char_ops: CharacterOps
preds: preds result,text index
preds_lod: lod tensor of preds
labels: label of input image, text index
labels_lod: lod tensor of label
is_remove_duplicate: Whether to remove duplicate characters,
The default is False
Return:
acc: The accuracy of test set
acc_num: The correct number of samples predicted
img_num: The total sample number of the test set
"""
acc_num
=
0
img_num
=
0
for
ino
in
range
(
len
(
labels_lod
)
-
1
):
...
...
@@ -189,6 +223,14 @@ def cal_predicts_accuracy_srn(char_ops,
def
convert_rec_attention_infer_res
(
preds
):
"""
Convert recognition attention predict result with lod information
Args:
preds: the output of the model
Return:
convert_ids: A 1-D Tensor represents all the predicted results.
target_lod: The lod information of the predicted results
"""
img_num
=
preds
.
shape
[
0
]
target_lod
=
[
0
]
convert_ids
=
[]
...
...
tools/eval_utils/eval_rec_utils.py
浏览文件 @
fc512a84
...
...
@@ -122,7 +122,9 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
def
test_rec_benchmark
(
exe
,
config
,
eval_info_dict
):
" Evaluate lmdb dataset "
"""
eval rec benchmark
"""
eval_data_list
=
[
'IIIT5k_3000'
,
'SVT'
,
'IC03_860'
,
'IC03_867'
,
\
'IC13_857'
,
'IC13_1015'
,
'IC15_1811'
,
'IC15_2077'
,
'SVTP'
,
'CUTE80'
]
eval_data_dir
=
config
[
'TestReader'
][
'lmdb_sets_dir'
]
...
...
tools/program.py
浏览文件 @
fc512a84
...
...
@@ -150,19 +150,20 @@ def check_gpu(use_gpu):
def
build
(
config
,
main_prog
,
startup_prog
,
mode
):
"""
Build a program using a model and an optimizer
1. create feeds
2. create a dataloader
3. create a model
4. create fetchs
5. create an optimizer
1. create a dataloader
2. create a model
3. create fetchs
4. create an optimizer
Args:
config(dict): config
main_prog(): main program
startup_prog(): startup program
is_train(bool
): train or valid
mode(str
): train or valid
Returns:
dataloader(): a bridge between the model and the data
fetchs(dict): dict of model outputs(included loss and measures)
fetch_name_list(dict): dict of model outputs(included loss and measures)
fetch_varname_list(list): list of outputs' varname
opt_loss_name(str): name of loss
"""
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
...
...
@@ -257,9 +258,14 @@ def train_eval_det_run(config,
train_info_dict
,
eval_info_dict
,
is_slim
=
None
):
'''
main program of evaluation for detection
'''
"""
Feed data to the model and fetch the measures and loss for detection
Args:
config: config
exe:
train_info_dict: information dict for training
eval_info_dict: information dict for evaluation
"""
train_batch_id
=
0
log_smooth_window
=
config
[
'Global'
][
'log_smooth_window'
]
epoch_num
=
config
[
'Global'
][
'epoch_num'
]
...
...
@@ -376,9 +382,14 @@ def train_eval_rec_run(config,
train_info_dict
,
eval_info_dict
,
is_slim
=
None
):
'''
main program of evaluation for recognition
'''
"""
Feed data to the model and fetch the measures and loss for recognition
Args:
config: config
exe:
train_info_dict: information dict for training
eval_info_dict: information dict for evaluation
"""
train_batch_id
=
0
log_smooth_window
=
config
[
'Global'
][
'log_smooth_window'
]
epoch_num
=
config
[
'Global'
][
'epoch_num'
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录