Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
807dd106
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
807dd106
编写于
7月 01, 2022
作者:
文幕地方
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pre-commit
上级
dc7bfe8a
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
25 addition
and
24 deletion
+25
-24
ppocr/modeling/backbones/vqa_layoutlm.py
ppocr/modeling/backbones/vqa_layoutlm.py
+8
-8
ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
+2
-1
ppstructure/utility.py
ppstructure/utility.py
+1
-2
tools/export_model.py
tools/export_model.py
+2
-1
tools/infer/utility.py
tools/infer/utility.py
+1
-1
tools/infer_vqa_token_ser.py
tools/infer_vqa_token_ser.py
+9
-9
tools/infer_vqa_token_ser_re.py
tools/infer_vqa_token_ser_re.py
+2
-2
未找到文件。
ppocr/modeling/backbones/vqa_layoutlm.py
浏览文件 @
807dd106
...
@@ -121,14 +121,14 @@ class LayoutXLMForSer(NLPBaseModel):
...
@@ -121,14 +121,14 @@ class LayoutXLMForSer(NLPBaseModel):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
model
(
x
=
self
.
model
(
input_ids
=
x
[
0
],
input_ids
=
x
[
0
],
bbox
=
x
[
1
],
bbox
=
x
[
1
],
attention_mask
=
x
[
2
],
attention_mask
=
x
[
2
],
token_type_ids
=
x
[
3
],
token_type_ids
=
x
[
3
],
image
=
x
[
4
],
image
=
x
[
4
],
position_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
,
head_mask
=
None
,
labels
=
None
)
labels
=
None
)
if
not
self
.
training
:
if
not
self
.
training
:
return
x
return
x
return
x
[
0
]
return
x
[
0
]
...
...
ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
浏览文件 @
807dd106
...
@@ -68,7 +68,8 @@ class VQASerTokenLayoutLMPostProcess(object):
...
@@ -68,7 +68,8 @@ class VQASerTokenLayoutLMPostProcess(object):
def
_infer
(
self
,
preds
,
segment_offset_ids
,
ocr_infos
):
def
_infer
(
self
,
preds
,
segment_offset_ids
,
ocr_infos
):
results
=
[]
results
=
[]
for
pred
,
segment_offset_id
,
ocr_info
in
zip
(
preds
,
segment_offset_ids
,
ocr_infos
):
for
pred
,
segment_offset_id
,
ocr_info
in
zip
(
preds
,
segment_offset_ids
,
ocr_infos
):
pred
=
np
.
argmax
(
pred
,
axis
=
1
)
pred
=
np
.
argmax
(
pred
,
axis
=
1
)
pred
=
[
self
.
id2label_map
[
idx
]
for
idx
in
pred
]
pred
=
[
self
.
id2label_map
[
idx
]
for
idx
in
pred
]
...
...
ppstructure/utility.py
浏览文件 @
807dd106
...
@@ -40,7 +40,6 @@ def init_args():
...
@@ -40,7 +40,6 @@ def init_args():
type
=
ast
.
literal_eval
,
type
=
ast
.
literal_eval
,
default
=
None
,
default
=
None
,
help
=
'label map according to ppstructure/layout/README_ch.md'
)
help
=
'label map according to ppstructure/layout/README_ch.md'
)
# params for vqa
# params for vqa
parser
.
add_argument
(
"--vqa_algorithm"
,
type
=
str
,
default
=
'LayoutXLM'
)
parser
.
add_argument
(
"--vqa_algorithm"
,
type
=
str
,
default
=
'LayoutXLM'
)
parser
.
add_argument
(
"--ser_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--ser_model_dir"
,
type
=
str
)
...
@@ -73,7 +72,7 @@ def init_args():
...
@@ -73,7 +72,7 @@ def init_args():
"--recovery"
,
"--recovery"
,
type
=
bool
,
type
=
bool
,
default
=
False
,
default
=
False
,
help
=
'Whether to enable layout of recovery'
)
help
=
'Whether to enable layout of recovery'
)
return
parser
return
parser
...
...
tools/export_model.py
浏览文件 @
807dd106
...
@@ -97,8 +97,9 @@ def export_single_model(model,
...
@@ -97,8 +97,9 @@ def export_single_model(model,
shape
=
[
None
,
1
,
32
,
100
],
dtype
=
"float32"
),
shape
=
[
None
,
1
,
32
,
100
],
dtype
=
"float32"
),
]
]
model
=
to_static
(
model
,
input_spec
=
other_shape
)
model
=
to_static
(
model
,
input_spec
=
other_shape
)
elif
arch_config
[
"algorithm"
]
in
[
"LayoutLM"
,
"LayoutLMv2"
,
"LayoutXLM"
]:
elif
arch_config
[
"algorithm"
]
in
[
"LayoutLM"
,
"LayoutLMv2"
,
"LayoutXLM"
]:
input_spec
=
[
input_spec
=
[
paddle
.
static
.
InputSpec
(
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
512
],
dtype
=
"int64"
),
# input_ids
shape
=
[
None
,
512
],
dtype
=
"int64"
),
# input_ids
paddle
.
static
.
InputSpec
(
paddle
.
static
.
InputSpec
(
...
...
tools/infer/utility.py
浏览文件 @
807dd106
...
@@ -318,7 +318,7 @@ def create_predictor(args, mode, logger):
...
@@ -318,7 +318,7 @@ def create_predictor(args, mode, logger):
# create predictor
# create predictor
predictor
=
inference
.
create_predictor
(
config
)
predictor
=
inference
.
create_predictor
(
config
)
input_names
=
predictor
.
get_input_names
()
input_names
=
predictor
.
get_input_names
()
if
mode
in
[
'ser'
,
're'
]:
if
mode
in
[
'ser'
,
're'
]:
input_tensor
=
[]
input_tensor
=
[]
for
name
in
input_names
:
for
name
in
input_names
:
input_tensor
.
append
(
predictor
.
get_input_handle
(
name
))
input_tensor
.
append
(
predictor
.
get_input_handle
(
name
))
...
...
tools/infer_vqa_token_ser.py
浏览文件 @
807dd106
...
@@ -44,7 +44,7 @@ def to_tensor(data):
...
@@ -44,7 +44,7 @@ def to_tensor(data):
from
collections
import
defaultdict
from
collections
import
defaultdict
data_dict
=
defaultdict
(
list
)
data_dict
=
defaultdict
(
list
)
to_tensor_idxs
=
[]
to_tensor_idxs
=
[]
for
idx
,
v
in
enumerate
(
data
):
for
idx
,
v
in
enumerate
(
data
):
if
isinstance
(
v
,
(
np
.
ndarray
,
paddle
.
Tensor
,
numbers
.
Number
)):
if
isinstance
(
v
,
(
np
.
ndarray
,
paddle
.
Tensor
,
numbers
.
Number
)):
if
idx
not
in
to_tensor_idxs
:
if
idx
not
in
to_tensor_idxs
:
...
@@ -72,7 +72,10 @@ class SerPredictor(object):
...
@@ -72,7 +72,10 @@ class SerPredictor(object):
from
paddleocr
import
PaddleOCR
from
paddleocr
import
PaddleOCR
self
.
ocr_engine
=
PaddleOCR
(
use_angle_cls
=
False
,
show_log
=
False
,
use_gpu
=
global_config
[
'use_gpu'
])
self
.
ocr_engine
=
PaddleOCR
(
use_angle_cls
=
False
,
show_log
=
False
,
use_gpu
=
global_config
[
'use_gpu'
])
# create data ops
# create data ops
transforms
=
[]
transforms
=
[]
...
@@ -82,8 +85,8 @@ class SerPredictor(object):
...
@@ -82,8 +85,8 @@ class SerPredictor(object):
op
[
op_name
][
'ocr_engine'
]
=
self
.
ocr_engine
op
[
op_name
][
'ocr_engine'
]
=
self
.
ocr_engine
elif
op_name
==
'KeepKeys'
:
elif
op_name
==
'KeepKeys'
:
op
[
op_name
][
'keep_keys'
]
=
[
op
[
op_name
][
'keep_keys'
]
=
[
'input_ids'
,
'bbox'
,
'attention_mask'
,
'token_type_ids'
,
'image'
,
'labels'
,
'input_ids'
,
'bbox'
,
'attention_mask'
,
'token_type_ids'
,
'segment_offset_id'
,
'ocr_info'
,
'
image'
,
'labels'
,
'
segment_offset_id'
,
'ocr_info'
,
'entities'
'entities'
]
]
...
@@ -103,11 +106,9 @@ class SerPredictor(object):
...
@@ -103,11 +106,9 @@ class SerPredictor(object):
preds
=
self
.
model
(
batch
)
preds
=
self
.
model
(
batch
)
if
self
.
algorithm
in
[
'LayoutLMv2'
,
'LayoutXLM'
]:
if
self
.
algorithm
in
[
'LayoutLMv2'
,
'LayoutXLM'
]:
preds
=
preds
[
0
]
preds
=
preds
[
0
]
post_result
=
self
.
post_process_class
(
post_result
=
self
.
post_process_class
(
preds
,
preds
,
segment_offset_ids
=
batch
[
6
],
ocr_infos
=
batch
[
7
])
segment_offset_ids
=
batch
[
6
],
ocr_infos
=
batch
[
7
])
return
post_result
,
batch
return
post_result
,
batch
...
@@ -154,4 +155,3 @@ if __name__ == '__main__':
...
@@ -154,4 +155,3 @@ if __name__ == '__main__':
logger
.
info
(
"process: [{}/{}], save result to {}"
.
format
(
logger
.
info
(
"process: [{}/{}], save result to {}"
.
format
(
idx
,
len
(
infer_imgs
),
save_img_path
))
idx
,
len
(
infer_imgs
),
save_img_path
))
tools/infer_vqa_token_ser_re.py
浏览文件 @
807dd106
...
@@ -192,6 +192,6 @@ if __name__ == '__main__':
...
@@ -192,6 +192,6 @@ if __name__ == '__main__':
},
ensure_ascii
=
False
)
+
"
\n
"
)
},
ensure_ascii
=
False
)
+
"
\n
"
)
img_res
=
draw_re_results
(
img_path
,
result
)
img_res
=
draw_re_results
(
img_path
,
result
)
cv2
.
imwrite
(
save_img_path
,
img_res
)
cv2
.
imwrite
(
save_img_path
,
img_res
)
logger
.
info
(
"process: [{}/{}], save result to {}"
.
format
(
logger
.
info
(
"process: [{}/{}], save result to {}"
.
format
(
idx
,
len
(
infer_imgs
),
save_img_path
))
idx
,
len
(
infer_imgs
),
save_img_path
))
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录