Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
66029dd8
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
66029dd8
编写于
10月 11, 2021
作者:
L
LDOUBLEV
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix kie infer and eval bug
上级
30e8dd8e
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
16 addition
and
14 deletion
+16
-14
ppocr/modeling/backbones/kie_unet_sdmgr.py
ppocr/modeling/backbones/kie_unet_sdmgr.py
+4
-4
ppocr/modeling/heads/kie_sdmgr_head.py
ppocr/modeling/heads/kie_sdmgr_head.py
+1
-1
tools/eval.py
tools/eval.py
+2
-2
tools/infer_kie.py
tools/infer_kie.py
+2
-4
tools/program.py
tools/program.py
+7
-3
未找到文件。
ppocr/modeling/backbones/kie_unet_sdmgr.py
浏览文件 @
66029dd8
...
...
@@ -167,10 +167,10 @@ class Kie_backbone(nn.Layer):
gt_bboxes
[
i
,
:
num
,
...],
dtype
=
'float32'
))
return
img
,
temp_relations
,
temp_texts
,
temp_gt_bboxes
def
forward
(
self
,
i
mages
,
i
nputs
):
img
=
i
mages
relations
,
texts
,
gt_bboxes
,
tag
,
img_size
=
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
inputs
[
4
],
inputs
[
-
1
]
def
forward
(
self
,
inputs
):
img
=
i
nputs
[
0
]
relations
,
texts
,
gt_bboxes
,
tag
,
img_size
=
inputs
[
1
],
inputs
[
2
],
inputs
[
3
],
inputs
[
5
],
inputs
[
-
1
]
img
,
relations
,
texts
,
gt_bboxes
=
self
.
pre_process
(
img
,
relations
,
texts
,
gt_bboxes
,
tag
,
img_size
)
x
=
self
.
img_feat
(
img
)
...
...
ppocr/modeling/heads/kie_sdmgr_head.py
浏览文件 @
66029dd8
...
...
@@ -49,7 +49,7 @@ class SDMGRHead(nn.Layer):
self
.
node_cls
=
nn
.
Linear
(
node_embed
,
num_classes
)
self
.
edge_cls
=
nn
.
Linear
(
edge_embed
,
2
)
def
forward
(
self
,
input
):
def
forward
(
self
,
input
,
targets
):
relations
,
texts
,
x
=
input
node_nums
,
char_nums
=
[],
[]
for
text
in
texts
:
...
...
tools/eval.py
浏览文件 @
66029dd8
...
...
@@ -54,7 +54,7 @@ def main():
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
model
=
build_model
(
config
[
'Architecture'
])
extra_input
=
config
[
'Architecture'
][
'algorithm'
]
in
[
"SRN"
,
"SAR"
]
extra_input
=
config
[
'Architecture'
][
'algorithm'
]
in
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
]
if
"model_type"
in
config
[
'Architecture'
].
keys
():
model_type
=
config
[
'Architecture'
][
'model_type'
]
else
:
...
...
@@ -68,7 +68,7 @@ def main():
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
logger
.
info
(
f
"extra_inputs:
{
extra_input
}
"
)
# start eval
metric
=
program
.
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
,
model_type
,
extra_input
)
...
...
tools/infer_kie.py
浏览文件 @
66029dd8
...
...
@@ -80,8 +80,7 @@ def draw_kie_result(batch, node, idx_to_cls, count):
vis_img
=
np
.
ones
((
h
,
w
*
3
,
3
),
dtype
=
np
.
uint8
)
*
255
vis_img
[:,
:
w
]
=
img
vis_img
[:,
w
:]
=
pred_img
save_kie_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/kie_results/"
save_kie_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/kie_results/"
if
not
os
.
path
.
exists
(
save_kie_path
):
os
.
makedirs
(
save_kie_path
)
save_path
=
os
.
path
.
join
(
save_kie_path
,
str
(
count
)
+
".png"
)
...
...
@@ -129,8 +128,7 @@ def main():
batch_pred
[
i
]
=
paddle
.
to_tensor
(
np
.
expand_dims
(
batch
[
i
],
axis
=
0
))
node
,
edge
=
model
(
batch
[
0
],
batch
[
1
:])
node
,
edge
=
model
(
batch_pred
)
node
=
F
.
softmax
(
node
,
-
1
)
draw_kie_result
(
batch
,
node
,
idx_to_cls
,
index
)
logger
.
info
(
"success!"
)
...
...
tools/program.py
浏览文件 @
66029dd8
...
...
@@ -196,7 +196,7 @@ def train(config,
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
extra_input
=
config
[
'Architecture'
][
'algorithm'
]
in
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
,
"SDMGR"
]
'algorithm'
]
in
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
]
try
:
model_type
=
config
[
'Architecture'
][
'model_type'
]
except
:
...
...
@@ -228,6 +228,8 @@ def train(config,
model_average
=
True
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
if
model_type
==
"kie"
:
preds
=
model
(
batch
)
else
:
preds
=
model
(
images
)
loss
=
loss_class
(
preds
,
batch
)
...
...
@@ -249,7 +251,7 @@ def train(config,
if
cal_metric_during_train
:
# only rec and cls need
batch
=
[
item
.
numpy
()
for
item
in
batch
]
if
model_type
==
'table'
:
if
model_type
in
[
'table'
,
'kie'
]
:
eval_class
(
preds
,
batch
)
else
:
post_result
=
post_process_class
(
preds
,
batch
[
1
])
...
...
@@ -377,13 +379,15 @@ def eval(model,
start
=
time
.
time
()
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
if
model_type
==
"kie"
:
preds
=
model
(
batch
)
else
:
preds
=
model
(
images
)
batch
=
[
item
.
numpy
()
for
item
in
batch
]
# Obtain usable results from post-processing methods
total_time
+=
time
.
time
()
-
start
# Evaluate the results of the current batch
if
model_type
==
'table'
:
if
model_type
in
[
'table'
,
'kie'
]
:
eval_class
(
preds
,
batch
)
else
:
post_result
=
post_process_class
(
preds
,
batch
[
1
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录