Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
dc7bfe8a
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dc7bfe8a
编写于
7月 01, 2022
作者:
文幕地方
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix
上级
ce21ad83
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
25 addition
and
21 deletion
+25
-21
ppstructure/infer.sh
ppstructure/infer.sh
+0
-4
ppstructure/vqa/predict_vqa_token_ser.py
ppstructure/vqa/predict_vqa_token_ser.py
+25
-17
未找到文件。
ppstructure/infer.sh
已删除
100644 → 0
浏览文件 @
ce21ad83
python3.7 vqa/predict_vqa_token_ser.py
--vqa_algorithm
=
LayoutXLM
--ser_model_dir
=
../models/ser_LayoutXLM_xfun_zh/infer
--ser_dict_path
=
../train_data/XFUND/class_list_xfun.txt
--image_dir
=
docs/vqa/input/zh_val_42.jpg
python3.7 tools/infer_vqa_token_ser_re.py
-c
configs/vqa/re/layoutxlm.yml
-o
Architecture.Backbone.checkpoints
=
models/re_LayoutXLM_xfun_zh/ Global.infer_img
=
ppstructure/docs/vqa/input/zh_val_21.jpg
-c_ser
configs/vqa/ser/layoutxlm.yml
-o_ser
Architecture.Backbone.checkpoints
=
models/ser_LayoutXLM_xfun_zh/
\ No newline at end of file
ppstructure/vqa/predict_vqa_token_ser.py
浏览文件 @
dc7bfe8a
...
...
@@ -16,7 +16,7 @@ import sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
...
...
@@ -50,18 +50,18 @@ class SerPredictor(object):
'ocr_engine'
:
self
.
ocr_engine
}
},
{
'VQATokenPad'
:{
'max_seq_len'
:
512
,
'VQATokenPad'
:
{
'max_seq_len'
:
512
,
'return_attention_mask'
:
True
}
},
{
'VQASerTokenChunk'
:{
'max_seq_len'
:
512
,
'VQASerTokenChunk'
:
{
'max_seq_len'
:
512
,
'return_attention_mask'
:
True
}
},
{
'Resize'
:{
'size'
:
[
224
,
224
]
'Resize'
:
{
'size'
:
[
224
,
224
]
}
},
{
'NormalizeImage'
:
{
...
...
@@ -75,8 +75,8 @@ class SerPredictor(object):
},
{
'KeepKeys'
:
{
'keep_keys'
:
[
'input_ids'
,
'bbox'
,
'attention_mask'
,
'token_type_ids'
,
'image'
,
'labels'
,
'segment_offset_id'
,
'ocr_info'
,
'input_ids'
,
'bbox'
,
'attention_mask'
,
'token_type_ids'
,
'
image'
,
'labels'
,
'
segment_offset_id'
,
'ocr_info'
,
'entities'
]
}
...
...
@@ -86,7 +86,8 @@ class SerPredictor(object):
"class_path"
:
args
.
ser_dict_path
,
}
self
.
preprocess_op
=
create_operators
(
pre_process_list
,
{
'infer_mode'
:
True
})
self
.
preprocess_op
=
create_operators
(
pre_process_list
,
{
'infer_mode'
:
True
})
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
\
utility
.
create_predictor
(
args
,
'ser'
,
logger
)
...
...
@@ -113,11 +114,9 @@ class SerPredictor(object):
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
preds
=
outputs
[
0
]
post_result
=
self
.
postprocess_op
(
preds
,
segment_offset_ids
=
[
data
[
6
]],
ocr_infos
=
[
data
[
7
]])
preds
,
segment_offset_ids
=
[
data
[
6
]],
ocr_infos
=
[
data
[
7
]])
elapse
=
time
.
time
()
-
starttime
return
post_result
,
elapse
...
...
@@ -136,17 +135,25 @@ def main(args):
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
img
=
img
[:,
:,
::
-
1
]
img
=
img
[:,
:,
::
-
1
]
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
ser_res
,
elapse
=
ser_predictor
(
img
)
ser_res
=
ser_res
[
0
]
res_str
=
'{}
\t
{}
\n
'
.
format
(
image_file
,
json
.
dumps
({
"ocr_info"
:
ser_res
,},
ensure_ascii
=
False
))
res_str
=
'{}
\t
{}
\n
'
.
format
(
image_file
,
json
.
dumps
(
{
"ocr_info"
:
ser_res
,
},
ensure_ascii
=
False
))
f_w
.
write
(
res_str
)
img_res
=
draw_ser_results
(
image_file
,
ser_res
,
font_path
=
"../doc/fonts/simfang.ttf"
,)
img_res
=
draw_ser_results
(
image_file
,
ser_res
,
font_path
=
"../doc/fonts/simfang.ttf"
,
)
img_save_path
=
os
.
path
.
join
(
args
.
output
,
os
.
path
.
basename
(
image_file
))
...
...
@@ -157,5 +164,6 @@ def main(args):
count
+=
1
logger
.
info
(
"Predict time of {}: {}"
.
format
(
image_file
,
elapse
))
if
__name__
==
"__main__"
:
main
(
parse_args
())
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录