Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
c708041e
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c708041e
编写于
11月 12, 2020
作者:
Z
zhoujun
提交者:
GitHub
11月 12, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
CRNN导出 (#1159)
* 识别模型导出 * 识别模型inference
上级
882ad395
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
37 addition
and
62 deletion
+37
-62
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+18
-58
tools/program.py
tools/program.py
+15
-1
tools/train.py
tools/train.py
+4
-3
未找到文件。
tools/infer/predict_rec.py
浏览文件 @
c708041e
...
@@ -26,34 +26,27 @@ import time
...
@@ -26,34 +26,27 @@ import time
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
tools.infer.utility
as
utility
import
tools.infer.utility
as
utility
from
ppocr.
utils.utility
import
initial_logger
from
ppocr.
postprocess
import
build_post_process
logger
=
initial_logger
()
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.character
import
CharacterOps
class
TextRecognizer
(
object
):
class
TextRecognizer
(
object
):
def
__init__
(
self
,
args
):
def
__init__
(
self
,
args
):
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
utility
.
create_predictor
(
args
,
mode
=
"rec"
)
self
.
rec_image_shape
=
[
int
(
v
)
for
v
in
args
.
rec_image_shape
.
split
(
","
)]
self
.
rec_image_shape
=
[
int
(
v
)
for
v
in
args
.
rec_image_shape
.
split
(
","
)]
self
.
character_type
=
args
.
rec_char_type
self
.
character_type
=
args
.
rec_char_type
self
.
rec_batch_num
=
args
.
rec_batch_num
self
.
rec_batch_num
=
args
.
rec_batch_num
self
.
rec_algorithm
=
args
.
rec_algorithm
self
.
rec_algorithm
=
args
.
rec_algorithm
self
.
use_zero_copy_run
=
args
.
use_zero_copy_run
self
.
use_zero_copy_run
=
args
.
use_zero_copy_run
char_ops_params
=
{
postprocess_params
=
{
'name'
:
'CTCLabelDecode'
,
"character_type"
:
args
.
rec_char_type
,
"character_type"
:
args
.
rec_char_type
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
"use_space_char"
:
args
.
use_space_char
,
"use_space_char"
:
args
.
use_space_char
"max_text_length"
:
args
.
max_text_length
}
}
if
self
.
rec_algorithm
!=
"RARE"
:
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
char_ops_params
[
'loss_type'
]
=
'ctc'
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
self
.
loss_type
=
'ctc'
utility
.
create_predictor
(
args
,
'rec'
,
logger
)
else
:
char_ops_params
[
'loss_type'
]
=
'attention'
self
.
loss_type
=
'attention'
self
.
char_ops
=
CharacterOps
(
char_ops_params
)
def
resize_norm_img
(
self
,
img
,
max_wh_ratio
):
def
resize_norm_img
(
self
,
img
,
max_wh_ratio
):
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
...
@@ -112,48 +105,14 @@ class TextRecognizer(object):
...
@@ -112,48 +105,14 @@ class TextRecognizer(object):
else
:
else
:
norm_img_batch
=
fluid
.
core
.
PaddleTensor
(
norm_img_batch
)
norm_img_batch
=
fluid
.
core
.
PaddleTensor
(
norm_img_batch
)
self
.
predictor
.
run
([
norm_img_batch
])
self
.
predictor
.
run
([
norm_img_batch
])
outputs
=
[]
if
self
.
loss_type
==
"ctc"
:
for
output_tensor
in
self
.
output_tensors
:
rec_idx_batch
=
self
.
output_tensors
[
0
]
.
copy_to_cpu
()
output
=
output_tensor
.
copy_to_cpu
()
rec_idx_lod
=
self
.
output_tensors
[
0
].
lod
()[
0
]
outputs
.
append
(
output
)
predict_batch
=
self
.
output_tensors
[
1
].
copy_to_cpu
()
preds
=
outputs
[
0
]
predict_lod
=
self
.
output_tensors
[
1
].
lod
()[
0
]
rec_res
=
self
.
postprocess_op
(
preds
)
elapse
=
time
.
time
()
-
starttime
elapse
=
time
.
time
()
-
starttime
predict_time
+=
elapse
return
rec_res
,
elapse
for
rno
in
range
(
len
(
rec_idx_lod
)
-
1
):
beg
=
rec_idx_lod
[
rno
]
end
=
rec_idx_lod
[
rno
+
1
]
rec_idx_tmp
=
rec_idx_batch
[
beg
:
end
,
0
]
preds_text
=
self
.
char_ops
.
decode
(
rec_idx_tmp
)
beg
=
predict_lod
[
rno
]
end
=
predict_lod
[
rno
+
1
]
probs
=
predict_batch
[
beg
:
end
,
:]
ind
=
np
.
argmax
(
probs
,
axis
=
1
)
blank
=
probs
.
shape
[
1
]
valid_ind
=
np
.
where
(
ind
!=
(
blank
-
1
))[
0
]
if
len
(
valid_ind
)
==
0
:
continue
score
=
np
.
mean
(
probs
[
valid_ind
,
ind
[
valid_ind
]])
# rec_res.append([preds_text, score])
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
[
preds_text
,
score
]
else
:
rec_idx_batch
=
self
.
output_tensors
[
0
].
copy_to_cpu
()
predict_batch
=
self
.
output_tensors
[
1
].
copy_to_cpu
()
elapse
=
time
.
time
()
-
starttime
predict_time
+=
elapse
for
rno
in
range
(
len
(
rec_idx_batch
)):
end_pos
=
np
.
where
(
rec_idx_batch
[
rno
,
:]
==
1
)[
0
]
if
len
(
end_pos
)
<=
1
:
preds
=
rec_idx_batch
[
rno
,
1
:]
score
=
np
.
mean
(
predict_batch
[
rno
,
1
:])
else
:
preds
=
rec_idx_batch
[
rno
,
1
:
end_pos
[
1
]]
score
=
np
.
mean
(
predict_batch
[
rno
,
1
:
end_pos
[
1
]])
preds_text
=
self
.
char_ops
.
decode
(
preds
)
# rec_res.append([preds_text, score])
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
[
preds_text
,
score
]
return
rec_res
,
predict_time
def
main
(
args
):
def
main
(
args
):
...
@@ -183,9 +142,10 @@ def main(args):
...
@@ -183,9 +142,10 @@ def main(args):
exit
()
exit
()
for
ino
in
range
(
len
(
img_list
)):
for
ino
in
range
(
len
(
img_list
)):
print
(
"Predicts of %s:%s"
%
(
valid_image_file_list
[
ino
],
rec_res
[
ino
]))
print
(
"Predicts of %s:%s"
%
(
valid_image_file_list
[
ino
],
rec_res
[
ino
]))
print
(
"Total predict time for %d images
:
%.3f"
%
print
(
"Total predict time for %d images
, cost:
%.3f"
%
(
len
(
img_list
),
predict_time
))
(
len
(
img_list
),
predict_time
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
logger
=
get_logger
()
main
(
utility
.
parse_args
())
main
(
utility
.
parse_args
())
tools/program.py
浏览文件 @
c708041e
...
@@ -323,6 +323,20 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
...
@@ -323,6 +323,20 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
return
metirc
return
metirc
def
save_inference_mode
(
model
,
config
,
logger
):
model
.
eval
()
save_path
=
'{}/infer/{}'
.
format
(
config
[
'Global'
][
'save_model_dir'
],
config
[
'Architecture'
][
'model_type'
])
if
config
[
'Architecture'
][
'model_type'
]
==
'rec'
:
input_shape
=
[
None
,
3
,
32
,
None
]
jit_model
=
paddle
.
jit
.
to_static
(
model
,
input_spec
=
[
paddle
.
static
.
InputSpec
(
input_shape
)])
paddle
.
jit
.
save
(
jit_model
,
save_path
)
logger
.
info
(
'inference model save to {}'
.
format
(
save_path
))
model
.
train
()
def
preprocess
():
def
preprocess
():
FLAGS
=
ArgsParser
().
parse_args
()
FLAGS
=
ArgsParser
().
parse_args
()
config
=
load_config
(
FLAGS
.
config
)
config
=
load_config
(
FLAGS
.
config
)
...
@@ -334,7 +348,7 @@ def preprocess():
...
@@ -334,7 +348,7 @@ def preprocess():
alg
=
config
[
'Architecture'
][
'algorithm'
]
alg
=
config
[
'Architecture'
][
'algorithm'
]
assert
alg
in
[
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
]
]
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
...
...
tools/train.py
浏览文件 @
c708041e
...
@@ -89,6 +89,7 @@ def main(config, device, logger, vdl_writer):
...
@@ -89,6 +89,7 @@ def main(config, device, logger, vdl_writer):
program
.
train
(
config
,
train_dataloader
,
valid_dataloader
,
device
,
model
,
program
.
train
(
config
,
train_dataloader
,
valid_dataloader
,
device
,
model
,
loss_class
,
optimizer
,
lr_scheduler
,
post_process_class
,
loss_class
,
optimizer
,
lr_scheduler
,
post_process_class
,
eval_class
,
pre_best_model_dict
,
logger
,
vdl_writer
)
eval_class
,
pre_best_model_dict
,
logger
,
vdl_writer
)
program
.
save_inference_mode
(
model
,
config
,
logger
)
def
test_reader
(
config
,
device
,
logger
):
def
test_reader
(
config
,
device
,
logger
):
...
@@ -102,8 +103,8 @@ def test_reader(config, device, logger):
...
@@ -102,8 +103,8 @@ def test_reader(config, device, logger):
if
count
%
1
==
0
:
if
count
%
1
==
0
:
batch_time
=
time
.
time
()
-
starttime
batch_time
=
time
.
time
()
-
starttime
starttime
=
time
.
time
()
starttime
=
time
.
time
()
logger
.
info
(
"reader: {}, {}, {}"
.
format
(
count
,
logger
.
info
(
"reader: {}, {}, {}"
.
format
(
len
(
data
),
batch_time
))
count
,
len
(
data
[
0
]
),
batch_time
))
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
info
(
e
)
logger
.
info
(
e
)
logger
.
info
(
"finish reader: {}, Success!"
.
format
(
count
))
logger
.
info
(
"finish reader: {}, Success!"
.
format
(
count
))
...
@@ -112,4 +113,4 @@ def test_reader(config, device, logger):
...
@@ -112,4 +113,4 @@ def test_reader(config, device, logger):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
main
(
config
,
device
,
logger
,
vdl_writer
)
main
(
config
,
device
,
logger
,
vdl_writer
)
#
test_reader(config, device, logger)
#
test_reader(config, device, logger)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录