Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
807dd106
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
807dd106
编写于
7月 01, 2022
作者:
文幕地方
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pre-commit
上级
dc7bfe8a
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
23 addition
and
22 deletion
+23
-22
ppocr/modeling/backbones/vqa_layoutlm.py
ppocr/modeling/backbones/vqa_layoutlm.py
+8
-8
ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
+2
-1
ppstructure/utility.py
ppstructure/utility.py
+1
-2
tools/export_model.py
tools/export_model.py
+2
-1
tools/infer/utility.py
tools/infer/utility.py
+1
-1
tools/infer_vqa_token_ser.py
tools/infer_vqa_token_ser.py
+9
-9
未找到文件。
ppocr/modeling/backbones/vqa_layoutlm.py
浏览文件 @
807dd106
ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
浏览文件 @
807dd106
...
@@ -68,7 +68,8 @@ class VQASerTokenLayoutLMPostProcess(object):
...
@@ -68,7 +68,8 @@ class VQASerTokenLayoutLMPostProcess(object):
def
_infer
(
self
,
preds
,
segment_offset_ids
,
ocr_infos
):
def
_infer
(
self
,
preds
,
segment_offset_ids
,
ocr_infos
):
results
=
[]
results
=
[]
for
pred
,
segment_offset_id
,
ocr_info
in
zip
(
preds
,
segment_offset_ids
,
ocr_infos
):
for
pred
,
segment_offset_id
,
ocr_info
in
zip
(
preds
,
segment_offset_ids
,
ocr_infos
):
pred
=
np
.
argmax
(
pred
,
axis
=
1
)
pred
=
np
.
argmax
(
pred
,
axis
=
1
)
pred
=
[
self
.
id2label_map
[
idx
]
for
idx
in
pred
]
pred
=
[
self
.
id2label_map
[
idx
]
for
idx
in
pred
]
...
...
ppstructure/utility.py
浏览文件 @
807dd106
...
@@ -40,7 +40,6 @@ def init_args():
...
@@ -40,7 +40,6 @@ def init_args():
type
=
ast
.
literal_eval
,
type
=
ast
.
literal_eval
,
default
=
None
,
default
=
None
,
help
=
'label map according to ppstructure/layout/README_ch.md'
)
help
=
'label map according to ppstructure/layout/README_ch.md'
)
# params for vqa
# params for vqa
parser
.
add_argument
(
"--vqa_algorithm"
,
type
=
str
,
default
=
'LayoutXLM'
)
parser
.
add_argument
(
"--vqa_algorithm"
,
type
=
str
,
default
=
'LayoutXLM'
)
parser
.
add_argument
(
"--ser_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--ser_model_dir"
,
type
=
str
)
...
...
tools/export_model.py
浏览文件 @
807dd106
...
@@ -97,8 +97,9 @@ def export_single_model(model,
...
@@ -97,8 +97,9 @@ def export_single_model(model,
shape
=
[
None
,
1
,
32
,
100
],
dtype
=
"float32"
),
shape
=
[
None
,
1
,
32
,
100
],
dtype
=
"float32"
),
]
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
model
=
to_static
(
model
,
input_spec
=
other_shape
)
elif
arch_config
[
"algorithm"
]
in
[
"LayoutLM"
,
"LayoutLMv2"
,
"LayoutXLM"
]:
elif
arch_config
[
"algorithm"
]
in
[
"LayoutLM"
,
"LayoutLMv2"
,
"LayoutXLM"
]:
input_spec
=
[
input_spec
=
[
paddle
.
static
.
InputSpec
(
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
512
],
dtype
=
"int64"
),
# input_ids
shape
=
[
None
,
512
],
dtype
=
"int64"
),
# input_ids
paddle
.
static
.
InputSpec
(
paddle
.
static
.
InputSpec
(
...
...
tools/infer/utility.py
浏览文件 @
807dd106
...
@@ -318,7 +318,7 @@ def create_predictor(args, mode, logger):
...
@@ -318,7 +318,7 @@ def create_predictor(args, mode, logger):
# create predictor
# create predictor
predictor
=
inference
.
create_predictor
(
config
)
predictor
=
inference
.
create_predictor
(
config
)
input_names
=
predictor
.
get_input_names
()
input_names
=
predictor
.
get_input_names
()
if
mode
in
[
'ser'
,
're'
]:
if
mode
in
[
'ser'
,
're'
]:
input_tensor
=
[]
input_tensor
=
[]
for
name
in
input_names
:
for
name
in
input_names
:
input_tensor
.
append
(
predictor
.
get_input_handle
(
name
))
input_tensor
.
append
(
predictor
.
get_input_handle
(
name
))
...
...
tools/infer_vqa_token_ser.py
浏览文件 @
807dd106
...
@@ -72,7 +72,10 @@ class SerPredictor(object):
...
@@ -72,7 +72,10 @@ class SerPredictor(object):
from
paddleocr
import
PaddleOCR
from
paddleocr
import
PaddleOCR
self
.
ocr_engine
=
PaddleOCR
(
use_angle_cls
=
False
,
show_log
=
False
,
use_gpu
=
global_config
[
'use_gpu'
])
self
.
ocr_engine
=
PaddleOCR
(
use_angle_cls
=
False
,
show_log
=
False
,
use_gpu
=
global_config
[
'use_gpu'
])
# create data ops
# create data ops
transforms
=
[]
transforms
=
[]
...
@@ -82,8 +85,8 @@ class SerPredictor(object):
...
@@ -82,8 +85,8 @@ class SerPredictor(object):
op
[
op_name
][
'ocr_engine'
]
=
self
.
ocr_engine
op
[
op_name
][
'ocr_engine'
]
=
self
.
ocr_engine
elif
op_name
==
'KeepKeys'
:
elif
op_name
==
'KeepKeys'
:
op
[
op_name
][
'keep_keys'
]
=
[
op
[
op_name
][
'keep_keys'
]
=
[
'input_ids'
,
'bbox'
,
'attention_mask'
,
'token_type_ids'
,
'image'
,
'labels'
,
'input_ids'
,
'bbox'
,
'attention_mask'
,
'token_type_ids'
,
'segment_offset_id'
,
'ocr_info'
,
'
image'
,
'labels'
,
'
segment_offset_id'
,
'ocr_info'
,
'entities'
'entities'
]
]
...
@@ -105,9 +108,7 @@ class SerPredictor(object):
...
@@ -105,9 +108,7 @@ class SerPredictor(object):
preds
=
preds
[
0
]
preds
=
preds
[
0
]
post_result
=
self
.
post_process_class
(
post_result
=
self
.
post_process_class
(
preds
,
preds
,
segment_offset_ids
=
batch
[
6
],
ocr_infos
=
batch
[
7
])
segment_offset_ids
=
batch
[
6
],
ocr_infos
=
batch
[
7
])
return
post_result
,
batch
return
post_result
,
batch
...
@@ -154,4 +155,3 @@ if __name__ == '__main__':
...
@@ -154,4 +155,3 @@ if __name__ == '__main__':
logger
.
info
(
"process: [{}/{}], save result to {}"
.
format
(
logger
.
info
(
"process: [{}/{}], save result to {}"
.
format
(
idx
,
len
(
infer_imgs
),
save_img_path
))
idx
,
len
(
infer_imgs
),
save_img_path
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录