Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
jobily
PaddleOCR
提交
9b2c0e48
P
PaddleOCR
项目概览
jobily
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
9b2c0e48
编写于
11月 30, 2020
作者:
Z
zhoujun
提交者:
GitHub
11月 30, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1235 from WenmuZhou/dygraph_rc
修复ips计算过少的问题
上级
c4fcd143
1c43e8bb
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
21 addition
and
45 deletion
+21
-45
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+1
-11
tools/infer/predict_cls.py
tools/infer/predict_cls.py
+9
-9
tools/infer/predict_det.py
tools/infer/predict_det.py
+5
-5
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+5
-4
tools/program.py
tools/program.py
+1
-15
tools/train.py
tools/train.py
+0
-1
未找到文件。
ppocr/data/imaug/label_ops.py
浏览文件 @
9b2c0e48
...
...
@@ -123,7 +123,7 @@ class BaseRecLabelEncode(object):
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
if
len
(
text
)
>
self
.
max_text_len
:
if
len
(
text
)
==
0
or
len
(
text
)
>
self
.
max_text_len
:
return
None
if
self
.
character_type
==
"en"
:
text
=
text
.
lower
()
...
...
@@ -138,9 +138,6 @@ class BaseRecLabelEncode(object):
return
None
return
text_list
def
get_ignored_tokens
(
self
):
return
[
0
]
# for ctc blank
class
CTCLabelEncode
(
BaseRecLabelEncode
):
""" Convert between text-label and text-index """
...
...
@@ -160,8 +157,6 @@ class CTCLabelEncode(BaseRecLabelEncode):
text
=
self
.
encode
(
text
)
if
text
is
None
:
return
None
if
len
(
text
)
>
self
.
max_text_len
:
return
None
data
[
'length'
]
=
np
.
array
(
len
(
text
))
text
=
text
+
[
0
]
*
(
self
.
max_text_len
-
len
(
text
))
data
[
'label'
]
=
np
.
array
(
text
)
...
...
@@ -195,11 +190,6 @@ class AttnLabelEncode(BaseRecLabelEncode):
text
=
self
.
encode
(
text
)
return
text
def
get_ignored_tokens
(
self
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
):
if
beg_or_end
==
"beg"
:
idx
=
np
.
array
(
self
.
dict
[
self
.
beg_str
])
...
...
tools/infer/predict_cls.py
浏览文件 @
9b2c0e48
...
...
@@ -82,7 +82,7 @@ class TextClassifier(object):
cls_res
=
[[
''
,
0.0
]]
*
img_num
batch_num
=
self
.
cls_batch_num
predict_tim
e
=
0
elaps
e
=
0
for
beg_img_no
in
range
(
0
,
img_num
,
batch_num
):
end_img_no
=
min
(
img_num
,
beg_img_no
+
batch_num
)
norm_img_batch
=
[]
...
...
@@ -107,14 +107,14 @@ class TextClassifier(object):
self
.
predictor
.
run
([
norm_img_batch
])
prob_out
=
self
.
output_tensors
[
0
].
copy_to_cpu
()
cls_res
=
self
.
postprocess_op
(
prob_out
)
elapse
=
time
.
time
()
-
starttime
elapse
+
=
time
.
time
()
-
starttime
for
rno
in
range
(
len
(
cls_res
)):
label
,
score
=
cls_res
[
rno
]
cls_res
[
indices
[
beg_img_no
+
rno
]]
=
[
label
,
score
]
if
'180'
in
label
and
score
>
self
.
cls_thresh
:
img_list
[
indices
[
beg_img_no
+
rno
]]
=
cv2
.
rotate
(
img_list
[
indices
[
beg_img_no
+
rno
]],
1
)
return
img_list
,
cls_res
,
predict_tim
e
return
img_list
,
cls_res
,
elaps
e
def
main
(
args
):
...
...
@@ -143,10 +143,10 @@ def main(args):
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
)
exit
()
for
ino
in
range
(
len
(
img_list
)):
print
(
"Predicts of %s:%s"
%
(
valid_image_file_list
[
ino
],
cls_res
[
ino
]))
print
(
"Total predict time for %d images, cost: %.3f"
%
(
len
(
img_list
),
predict_time
))
print
(
"Predicts of {}:{}"
.
format
(
valid_image_file_list
[
ino
],
cls_res
[
ino
]))
print
(
"Total predict time for {} images, cost: {:.3f}"
.
format
(
len
(
img_list
),
predict_time
))
if
__name__
==
"__main__"
:
main
(
utility
.
parse_args
())
if
__name__
==
"__main__"
:
main
(
utility
.
parse_args
())
tools/infer/predict_det.py
浏览文件 @
9b2c0e48
...
...
@@ -174,15 +174,15 @@ if __name__ == "__main__":
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
dt_boxes
,
elapse
=
text_detector
(
img
)
if
count
>
0
:
total_time
+=
elapse
count
+=
1
print
(
"Predict time of
%s:"
%
image_file
,
elapse
)
print
(
"Predict time of
{}: {}"
.
format
(
image_file
,
elapse
)
)
src_im
=
utility
.
draw_text_det_res
(
dt_boxes
,
image_file
)
img_name_pure
=
image_file
.
split
(
"/"
)[
-
1
]
cv2
.
imwrite
(
os
.
path
.
join
(
draw_img_save
,
"det_res_%s"
%
img_name_pure
),
src_im
)
img_name_pure
=
os
.
path
.
split
(
image_file
)[
-
1
]
img_path
=
os
.
path
.
join
(
draw_img_save
,
"det_res_{}"
.
format
(
img_name_pure
))
cv2
.
imwrite
(
img_path
,
src_im
)
if
count
>
1
:
print
(
"Avg Time:"
,
total_time
/
(
count
-
1
))
tools/infer/predict_rec.py
浏览文件 @
9b2c0e48
...
...
@@ -115,7 +115,7 @@ class TextRecognizer(object):
rec_result
=
self
.
postprocess_op
(
preds
)
for
rno
in
range
(
len
(
rec_result
)):
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
rec_result
[
rno
]
elapse
=
time
.
time
()
-
starttime
elapse
+
=
time
.
time
()
-
starttime
return
rec_res
,
elapse
...
...
@@ -145,9 +145,10 @@ def main(args):
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
)
exit
()
for
ino
in
range
(
len
(
img_list
)):
print
(
"Predicts of %s:%s"
%
(
valid_image_file_list
[
ino
],
rec_res
[
ino
]))
print
(
"Total predict time for %d images, cost: %.3f"
%
(
len
(
img_list
),
predict_time
))
print
(
"Predicts of {}:{}"
.
format
(
valid_image_file_list
[
ino
],
rec_res
[
ino
]))
print
(
"Total predict time for {} images, cost: {:.3f}"
.
format
(
len
(
img_list
),
predict_time
))
if
__name__
==
"__main__"
:
...
...
tools/program.py
浏览文件 @
9b2c0e48
...
...
@@ -236,7 +236,6 @@ def train(config,
train_batch_cost
=
0.0
train_reader_cost
=
0.0
batch_sum
=
0
batch_start
=
time
.
time
()
# eval
if
global_step
>
start_eval_step
and
\
(
global_step
-
start_eval_step
)
%
eval_batch_step
==
0
and
dist
.
get_rank
()
==
0
:
...
...
@@ -275,6 +274,7 @@ def train(config,
best_model_dict
[
main_indicator
],
global_step
)
global_step
+=
1
batch_start
=
time
.
time
()
if
dist
.
get_rank
()
==
0
:
save_model
(
model
,
...
...
@@ -333,20 +333,6 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
return
metirc
def
save_inference_mode
(
model
,
config
,
logger
):
model
.
eval
()
save_path
=
'{}/infer/{}'
.
format
(
config
[
'Global'
][
'save_model_dir'
],
config
[
'Architecture'
][
'model_type'
])
if
config
[
'Architecture'
][
'model_type'
]
==
'rec'
:
input_shape
=
[
None
,
3
,
32
,
None
]
jit_model
=
paddle
.
jit
.
to_static
(
model
,
input_spec
=
[
paddle
.
static
.
InputSpec
(
input_shape
)])
paddle
.
jit
.
save
(
jit_model
,
save_path
)
logger
.
info
(
'inference model save to {}'
.
format
(
save_path
))
model
.
train
()
def
preprocess
():
FLAGS
=
ArgsParser
().
parse_args
()
config
=
load_config
(
FLAGS
.
config
)
...
...
tools/train.py
浏览文件 @
9b2c0e48
...
...
@@ -89,7 +89,6 @@ def main(config, device, logger, vdl_writer):
program
.
train
(
config
,
train_dataloader
,
valid_dataloader
,
device
,
model
,
loss_class
,
optimizer
,
lr_scheduler
,
post_process_class
,
eval_class
,
pre_best_model_dict
,
logger
,
vdl_writer
)
program
.
save_inference_mode
(
model
,
config
,
logger
)
def
test_reader
(
config
,
device
,
logger
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录