Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
d9c28128
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d9c28128
编写于
10月 09, 2021
作者:
L
LDOUBLEV
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix multi-inputs
上级
4e0fcd6e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
11 addition
and
17 deletion
+11
-17
ppocr/modeling/backbones/kie_unet_sdmgr.py
ppocr/modeling/backbones/kie_unet_sdmgr.py
+4
-12
tools/infer_kie.py
tools/infer_kie.py
+4
-2
tools/program.py
tools/program.py
+3
-3
未找到文件。
ppocr/modeling/backbones/kie_unet_sdmgr.py
浏览文件 @
d9c28128
...
@@ -167,20 +167,12 @@ class Kie_backbone(nn.Layer):
...
@@ -167,20 +167,12 @@ class Kie_backbone(nn.Layer):
gt_bboxes
[
i
,
:
num
,
...],
dtype
=
'float32'
))
gt_bboxes
[
i
,
:
num
,
...],
dtype
=
'float32'
))
return
img
,
temp_relations
,
temp_texts
,
temp_gt_bboxes
return
img
,
temp_relations
,
temp_texts
,
temp_gt_bboxes
def
forward
(
self
,
inputs
):
def
forward
(
self
,
images
,
inputs
):
img
,
relations
,
texts
,
gt_bboxes
,
tag
,
img_size
=
inputs
[
0
],
inputs
[
img
=
images
1
],
inputs
[
2
],
inputs
[
3
],
inputs
[
5
],
inputs
[
-
1
]
relations
,
texts
,
gt_bboxes
,
tag
,
img_size
=
inputs
[
0
],
inputs
[
1
],
inputs
[
2
],
inputs
[
4
],
inputs
[
-
1
]
img
,
relations
,
texts
,
gt_bboxes
=
self
.
pre_process
(
img
,
relations
,
texts
,
gt_bboxes
=
self
.
pre_process
(
img
,
relations
,
texts
,
gt_bboxes
,
tag
,
img_size
)
img
,
relations
,
texts
,
gt_bboxes
,
tag
,
img_size
)
# for i in range(4):
# img_t = (img[i].numpy().transpose([1, 2, 0]) * 255.0).astype('uint8')
# img_t = img_t.copy()
# gt_bboxes_t = gt_bboxes[i].cpu().numpy()
# box = gt_bboxes_t.astype(np.int32).reshape((-1, 1, 2))
# cv2.polylines(img_t, [box], True, color=(255, 255, 0), thickness=1)
# cv2.imwrite("/Users/hongyongjie/project/PaddleOCR/output/{}.png".format(i), img_t)
# # cv2.imwrite("/Users/hongyongjie/project/PaddleOCR/output/{}.png".format(i), img_t * 255.0)
# exit()
x
=
self
.
img_feat
(
img
)
x
=
self
.
img_feat
(
img
)
boxes
,
rois_num
=
self
.
bbox2roi
(
gt_bboxes
)
boxes
,
rois_num
=
self
.
bbox2roi
(
gt_bboxes
)
feats
=
paddle
.
fluid
.
layers
.
roi_align
(
feats
=
paddle
.
fluid
.
layers
.
roi_align
(
...
...
tools/infer_kie.py
浏览文件 @
d9c28128
...
@@ -80,7 +80,8 @@ def draw_kie_result(batch, node, idx_to_cls, count):
...
@@ -80,7 +80,8 @@ def draw_kie_result(batch, node, idx_to_cls, count):
vis_img
=
np
.
ones
((
h
,
w
*
3
,
3
),
dtype
=
np
.
uint8
)
*
255
vis_img
=
np
.
ones
((
h
,
w
*
3
,
3
),
dtype
=
np
.
uint8
)
*
255
vis_img
[:,
:
w
]
=
img
vis_img
[:,
:
w
]
=
img
vis_img
[:,
w
:]
=
pred_img
vis_img
[:,
w
:]
=
pred_img
save_kie_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/kie_results/"
save_kie_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/kie_results/"
if
not
os
.
path
.
exists
(
save_kie_path
):
if
not
os
.
path
.
exists
(
save_kie_path
):
os
.
makedirs
(
save_kie_path
)
os
.
makedirs
(
save_kie_path
)
save_path
=
os
.
path
.
join
(
save_kie_path
,
str
(
count
)
+
".png"
)
save_path
=
os
.
path
.
join
(
save_kie_path
,
str
(
count
)
+
".png"
)
...
@@ -128,7 +129,8 @@ def main():
...
@@ -128,7 +129,8 @@ def main():
batch_pred
[
i
]
=
paddle
.
to_tensor
(
batch_pred
[
i
]
=
paddle
.
to_tensor
(
np
.
expand_dims
(
np
.
expand_dims
(
batch
[
i
],
axis
=
0
))
batch
[
i
],
axis
=
0
))
node
,
edge
=
model
(
batch_pred
)
node
,
edge
=
model
(
batch
[
0
],
batch
[
1
:])
node
=
F
.
softmax
(
node
,
-
1
)
node
=
F
.
softmax
(
node
,
-
1
)
draw_kie_result
(
batch
,
node
,
idx_to_cls
,
index
)
draw_kie_result
(
batch
,
node
,
idx_to_cls
,
index
)
logger
.
info
(
"success!"
)
logger
.
info
(
"success!"
)
...
...
tools/program.py
浏览文件 @
d9c28128
...
@@ -197,7 +197,7 @@ def train(config,
...
@@ -197,7 +197,7 @@ def train(config,
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
extra_input
=
config
[
'Architecture'
][
extra_input
=
config
[
'Architecture'
][
'algorithm'
]
in
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
]
'algorithm'
]
in
[
"SRN"
,
"NRTR"
,
"SAR"
,
"SEED"
,
"SDMGR"
]
try
:
try
:
model_type
=
config
[
'Architecture'
][
'model_type'
]
model_type
=
config
[
'Architecture'
][
'model_type'
]
except
:
except
:
...
@@ -230,7 +230,7 @@ def train(config,
...
@@ -230,7 +230,7 @@ def train(config,
if
model_type
==
'table'
or
extra_input
:
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
else
:
preds
=
model
(
batch
)
preds
=
model
(
images
)
loss
=
loss_class
(
preds
,
batch
)
loss
=
loss_class
(
preds
,
batch
)
avg_loss
=
loss
[
'loss'
]
avg_loss
=
loss
[
'loss'
]
avg_loss
.
backward
()
avg_loss
.
backward
()
...
@@ -379,7 +379,7 @@ def eval(model,
...
@@ -379,7 +379,7 @@ def eval(model,
if
model_type
==
'table'
or
extra_input
:
if
model_type
==
'table'
or
extra_input
:
preds
=
model
(
images
,
data
=
batch
[
1
:])
preds
=
model
(
images
,
data
=
batch
[
1
:])
else
:
else
:
preds
=
model
(
batch
)
preds
=
model
(
images
)
batch
=
[
item
.
numpy
()
for
item
in
batch
]
batch
=
[
item
.
numpy
()
for
item
in
batch
]
# Obtain usable results from post-processing methods
# Obtain usable results from post-processing methods
total_time
+=
time
.
time
()
-
start
total_time
+=
time
.
time
()
-
start
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录