Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
ddaa2c25
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 2 年 前同步成功
通知
1557
Star
32965
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ddaa2c25
编写于
8月 08, 2022
作者:
文幕地方
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add SLANet
上级
342522ab
变更
35
显示空白变更内容
内联
并排
Showing
35 changed file
with
2781 addition
and
357 deletion
+2781
-357
.gitignore
.gitignore
+28
-0
configs/table/SLANet.yml
configs/table/SLANet.yml
+141
-0
configs/table/table_master.yml
configs/table/table_master.yml
+11
-6
configs/table/table_mv3.yml
configs/table/table_mv3.yml
+8
-4
deploy/hubserving/ocr_system/module.py
deploy/hubserving/ocr_system/module.py
+1
-1
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+12
-9
ppocr/data/imaug/table_ops.py
ppocr/data/imaug/table_ops.py
+1
-1
ppocr/losses/__init__.py
ppocr/losses/__init__.py
+2
-2
ppocr/losses/table_att_loss.py
ppocr/losses/table_att_loss.py
+46
-72
ppocr/metrics/table_metric.py
ppocr/metrics/table_metric.py
+8
-4
ppocr/modeling/backbones/__init__.py
ppocr/modeling/backbones/__init__.py
+4
-1
ppocr/modeling/backbones/det_pp_lcnet.py
ppocr/modeling/backbones/det_pp_lcnet.py
+271
-0
ppocr/modeling/heads/__init__.py
ppocr/modeling/heads/__init__.py
+3
-2
ppocr/modeling/heads/table_att_head.py
ppocr/modeling/heads/table_att_head.py
+134
-2
ppocr/modeling/heads/table_master_head.py
ppocr/modeling/heads/table_master_head.py
+4
-4
ppocr/modeling/necks/__init__.py
ppocr/modeling/necks/__init__.py
+2
-1
ppocr/modeling/necks/csp_pan.py
ppocr/modeling/necks/csp_pan.py
+325
-0
ppocr/postprocess/table_postprocess.py
ppocr/postprocess/table_postprocess.py
+6
-5
ppocr/utils/visual.py
ppocr/utils/visual.py
+2
-6
ppstructure/layout/picodet_postprocess.py
ppstructure/layout/picodet_postprocess.py
+227
-0
ppstructure/layout/predict_layout.py
ppstructure/layout/predict_layout.py
+155
-0
ppstructure/predict_system.py
ppstructure/predict_system.py
+49
-39
ppstructure/table/eval_table.py
ppstructure/table/eval_table.py
+64
-28
ppstructure/table/matcher.py
ppstructure/table/matcher.py
+173
-28
ppstructure/table/predict_structure.py
ppstructure/table/predict_structure.py
+3
-4
ppstructure/table/predict_table.py
ppstructure/table/predict_table.py
+58
-120
ppstructure/table/table_master_match.py
ppstructure/table/table_master_match.py
+1009
-0
ppstructure/utility.py
ppstructure/utility.py
+2
-1
test_tipc/configs/en_table_structure/table_mv3.yml
test_tipc/configs/en_table_structure/table_mv3.yml
+0
-2
test_tipc/configs/table_master/table_master.yml
test_tipc/configs/table_master/table_master.yml
+2
-4
tools/infer/predict_system.py
tools/infer/predict_system.py
+18
-7
tools/infer/utility.py
tools/infer/utility.py
+2
-0
tools/infer_table.py
tools/infer_table.py
+1
-2
tools/program.py
tools/program.py
+3
-1
tools/train.py
tools/train.py
+6
-1
未找到文件。
.gitignore
浏览文件 @
ddaa2c25
...
@@ -32,3 +32,31 @@ paddleocr.egg-info/
...
@@ -32,3 +32,31 @@ paddleocr.egg-info/
/deploy/android_demo/app/cache/
/deploy/android_demo/app/cache/
test_tipc/web/models/
test_tipc/web/models/
test_tipc/web/node_modules/
test_tipc/web/node_modules/
en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdiparams
en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdiparams.info
en_ppocr_mobile_v2.0_table_structure_infer/._inference.pdmodel
en_ppocr_mobile_v2.0_table_structure_infer/inference.pdiparams
en_ppocr_mobile_v2.0_table_structure_infer/inference.pdiparams.info
en_ppocr_mobile_v2.0_table_structure_infer/inference.pdmodel
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdiparams
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdiparams.info
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/._inference.pdmodel
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdiparams
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdiparams.info
ppstructure/layout/en_ppocr_mobile_v2.0_table_det_infer/inference.pdmodel
.gitignore
.gitignore
ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdiparams
ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdiparams.info
ppstructure/layout/en_ppocr_mobile_v2.0_table_rec_infer/inference.pdmodel
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/infer_cfg.yml
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdiparams
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdiparams.info
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape/inference.pdmodel
.gitignore
ppstructure/layout/table/inference.pdiparams
ppstructure/layout/table/inference.pdiparams.info
ppstructure/layout/table/inference.pdmodel
ppstructure/layout/picodet_lcnet_x2_5_640_publayernet_shape.tar
._en_ppocr_mobile_v2.0_table_structure_infer
en_ppocr_mobile_v2.0_table_structure_infer.tar
configs/table/SLANet.yml
0 → 100644
浏览文件 @
ddaa2c25
Global
:
use_gpu
:
true
epoch_num
:
400
log_smooth_window
:
20
print_batch_step
:
20
save_model_dir
:
./output/SLANet
save_epoch_step
:
400
# evaluation is run every 1000 iterations after the 0th iteration
eval_batch_step
:
[
0
,
1000
]
cal_metric_during_train
:
True
pretrained_model
:
checkpoints
:
/ssd1/zhoujun20/table/ch/PaddleOCR/output/en/table_lcnet_1_0_csp_pan_headsv3_smooth_l1_pretrain_ssld_weight81_sync_bn/best_accuracy.pdparams
save_inference_dir
:
./output/SLANet/infer
use_visualdl
:
False
infer_img
:
doc/table/table.jpg
# for data or label process
character_dict_path
:
ppocr/utils/dict/table_structure_dict.txt
character_type
:
en
max_text_length
:
&max_text_length
500
box_format
:
&box_format
'
xyxy'
# 'xywh', 'xyxy', 'xyxyxyxy'
infer_mode
:
False
use_sync_bn
:
True
save_res_path
:
'
output/infer'
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
clip_norm
:
5.0
lr
:
# name: Piecewise
learning_rate
:
0.001
# decay_epochs : [10, 20]
# values : [0.002, 0.0002, 0.0001]
# warmup_epoch: 0
regularizer
:
name
:
'
L2'
factor
:
0.00000
Architecture
:
model_type
:
table
algorithm
:
SLANet
Backbone
:
name
:
PPLCNet
scale
:
1.0
pretrained
:
true
use_ssld
:
true
Neck
:
name
:
CSPPAN
out_channels
:
96
Head
:
name
:
SLAHead
hidden_size
:
256
max_text_length
:
*max_text_length
loc_reg_num
:
&loc_reg_num
4
Loss
:
name
:
SLANetLoss
structure_weight
:
1.0
loc_weight
:
2.0
loc_loss
:
smooth_l1
PostProcess
:
name
:
TableLabelDecode
Metric
:
name
:
TableMetric
main_indicator
:
acc
compute_bbox_metric
:
False
loc_reg_num
:
*loc_reg_num
box_format
:
*box_format
Train
:
dataset
:
name
:
PubTabDataSet
data_dir
:
/home/zhoujun20/table/PubTabNe/pubtabnet/train/
label_file_list
:
[
/home/zhoujun20/table/PubTabNe/pubtabnet/PubTabNet_2.0.0_train.jsonl
]
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
TableLabelEncode
:
learn_empty_box
:
False
merge_no_span_structure
:
False
replace_empty_cell_token
:
False
loc_reg_num
:
*loc_reg_num
max_text_length
:
*max_text_length
-
TableBoxEncode
:
box_format
:
*box_format
-
ResizeTableImage
:
max_len
:
488
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
hwc'
-
PaddingTableImage
:
size
:
[
488
,
488
]
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
structure'
,
'
bboxes'
,
'
bbox_masks'
,
'
shape'
]
loader
:
shuffle
:
True
batch_size_per_card
:
48
drop_last
:
True
num_workers
:
1
Eval
:
dataset
:
name
:
PubTabDataSet
data_dir
:
/home/zhoujun20/table/PubTabNe/pubtabnet/val/
label_file_list
:
[
/home/zhoujun20/table/PubTabNe/pubtabnet/PubTabNet_2.0.0_val.jsonl
]
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
TableLabelEncode
:
learn_empty_box
:
False
merge_no_span_structure
:
False
replace_empty_cell_token
:
False
loc_reg_num
:
*loc_reg_num
max_text_length
:
*max_text_length
-
TableBoxEncode
:
box_format
:
*box_format
-
ResizeTableImage
:
max_len
:
488
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
hwc'
-
PaddingTableImage
:
size
:
[
488
,
488
]
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
structure'
,
'
bboxes'
,
'
bbox_masks'
,
'
shape'
]
loader
:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
48
num_workers
:
1
configs/table/table_master.yml
浏览文件 @
ddaa2c25
...
@@ -15,9 +15,8 @@ Global:
...
@@ -15,9 +15,8 @@ Global:
save_res_path
:
./output/table_master
save_res_path
:
./output/table_master
character_dict_path
:
ppocr/utils/dict/table_master_structure_dict.txt
character_dict_path
:
ppocr/utils/dict/table_master_structure_dict.txt
infer_mode
:
false
infer_mode
:
false
max_text_length
:
500
max_text_length
:
&max_text_length
500
process_total_num
:
0
box_format
:
&box_format
'
xywh'
# 'xywh', 'xyxy', 'xyxyxyxy'
process_cut_num
:
0
Optimizer
:
Optimizer
:
...
@@ -52,7 +51,8 @@ Architecture:
...
@@ -52,7 +51,8 @@ Architecture:
headers
:
8
headers
:
8
dropout
:
0
dropout
:
0
d_ff
:
2024
d_ff
:
2024
max_text_length
:
500
max_text_length
:
*max_text_length
loc_reg_num
:
&loc_reg_num
4
Loss
:
Loss
:
name
:
TableMasterLoss
name
:
TableMasterLoss
...
@@ -66,6 +66,7 @@ Metric:
...
@@ -66,6 +66,7 @@ Metric:
name
:
TableMetric
name
:
TableMetric
main_indicator
:
acc
main_indicator
:
acc
compute_bbox_metric
:
False
compute_bbox_metric
:
False
box_format
:
*box_format
Train
:
Train
:
dataset
:
dataset
:
...
@@ -80,13 +81,15 @@ Train:
...
@@ -80,13 +81,15 @@ Train:
learn_empty_box
:
False
learn_empty_box
:
False
merge_no_span_structure
:
True
merge_no_span_structure
:
True
replace_empty_cell_token
:
True
replace_empty_cell_token
:
True
loc_reg_num
:
*loc_reg_num
max_text_length
:
*max_text_length
-
ResizeTableImage
:
-
ResizeTableImage
:
max_len
:
480
max_len
:
480
resize_bboxes
:
True
resize_bboxes
:
True
-
PaddingTableImage
:
-
PaddingTableImage
:
size
:
[
480
,
480
]
size
:
[
480
,
480
]
-
TableBoxEncode
:
-
TableBoxEncode
:
use_xywh
:
True
box_format
:
*box_format
-
NormalizeImage
:
-
NormalizeImage
:
scale
:
1./255.
scale
:
1./255.
mean
:
[
0.5
,
0.5
,
0.5
]
mean
:
[
0.5
,
0.5
,
0.5
]
...
@@ -114,13 +117,15 @@ Eval:
...
@@ -114,13 +117,15 @@ Eval:
learn_empty_box
:
False
learn_empty_box
:
False
merge_no_span_structure
:
True
merge_no_span_structure
:
True
replace_empty_cell_token
:
True
replace_empty_cell_token
:
True
loc_reg_num
:
*loc_reg_num
max_text_length
:
*max_text_length
-
ResizeTableImage
:
-
ResizeTableImage
:
max_len
:
480
max_len
:
480
resize_bboxes
:
True
resize_bboxes
:
True
-
PaddingTableImage
:
-
PaddingTableImage
:
size
:
[
480
,
480
]
size
:
[
480
,
480
]
-
TableBoxEncode
:
-
TableBoxEncode
:
use_xywh
:
True
box_format
:
*box_format
-
NormalizeImage
:
-
NormalizeImage
:
scale
:
1./255.
scale
:
1./255.
mean
:
[
0.5
,
0.5
,
0.5
]
mean
:
[
0.5
,
0.5
,
0.5
]
...
...
configs/table/table_mv3.yml
浏览文件 @
ddaa2c25
...
@@ -17,10 +17,9 @@ Global:
...
@@ -17,10 +17,9 @@ Global:
# for data or label process
# for data or label process
character_dict_path
:
ppocr/utils/dict/table_structure_dict.txt
character_dict_path
:
ppocr/utils/dict/table_structure_dict.txt
character_type
:
en
character_type
:
en
max_text_length
:
800
max_text_length
:
&max_text_length
800
box_format
:
&box_format
'
xyxy'
# 'xywh', 'xyxy', 'xyxyxyxy'
infer_mode
:
False
infer_mode
:
False
process_total_num
:
0
process_cut_num
:
0
Optimizer
:
Optimizer
:
name
:
Adam
name
:
Adam
...
@@ -44,7 +43,8 @@ Architecture:
...
@@ -44,7 +43,8 @@ Architecture:
name
:
TableAttentionHead
name
:
TableAttentionHead
hidden_size
:
256
hidden_size
:
256
loc_type
:
2
loc_type
:
2
max_text_length
:
800
max_text_length
:
*max_text_length
loc_reg_num
:
&loc_reg_num
4
Loss
:
Loss
:
name
:
TableAttentionLoss
name
:
TableAttentionLoss
...
@@ -72,6 +72,8 @@ Train:
...
@@ -72,6 +72,8 @@ Train:
learn_empty_box
:
False
learn_empty_box
:
False
merge_no_span_structure
:
False
merge_no_span_structure
:
False
replace_empty_cell_token
:
False
replace_empty_cell_token
:
False
loc_reg_num
:
*loc_reg_num
max_text_length
:
*max_text_length
-
TableBoxEncode
:
-
TableBoxEncode
:
-
ResizeTableImage
:
-
ResizeTableImage
:
max_len
:
488
max_len
:
488
...
@@ -104,6 +106,8 @@ Eval:
...
@@ -104,6 +106,8 @@ Eval:
learn_empty_box
:
False
learn_empty_box
:
False
merge_no_span_structure
:
False
merge_no_span_structure
:
False
replace_empty_cell_token
:
False
replace_empty_cell_token
:
False
loc_reg_num
:
*loc_reg_num
max_text_length
:
*max_text_length
-
TableBoxEncode
:
-
TableBoxEncode
:
-
ResizeTableImage
:
-
ResizeTableImage
:
max_len
:
488
max_len
:
488
...
...
deploy/hubserving/ocr_system/module.py
浏览文件 @
ddaa2c25
...
@@ -118,7 +118,7 @@ class OCRSystem(hub.Module):
...
@@ -118,7 +118,7 @@ class OCRSystem(hub.Module):
all_results
.
append
([])
all_results
.
append
([])
continue
continue
starttime
=
time
.
time
()
starttime
=
time
.
time
()
dt_boxes
,
rec_res
=
self
.
text_sys
(
img
)
dt_boxes
,
rec_res
,
_
=
self
.
text_sys
(
img
)
elapse
=
time
.
time
()
-
starttime
elapse
=
time
.
time
()
-
starttime
logger
.
info
(
"Predict time: {}"
.
format
(
elapse
))
logger
.
info
(
"Predict time: {}"
.
format
(
elapse
))
...
...
ppocr/data/imaug/label_ops.py
浏览文件 @
ddaa2c25
...
@@ -571,7 +571,7 @@ class TableLabelEncode(AttnLabelEncode):
...
@@ -571,7 +571,7 @@ class TableLabelEncode(AttnLabelEncode):
replace_empty_cell_token
=
False
,
replace_empty_cell_token
=
False
,
merge_no_span_structure
=
False
,
merge_no_span_structure
=
False
,
learn_empty_box
=
False
,
learn_empty_box
=
False
,
point_num
=
2
,
loc_reg_num
=
4
,
**
kwargs
):
**
kwargs
):
self
.
max_text_len
=
max_text_length
self
.
max_text_len
=
max_text_length
self
.
lower
=
False
self
.
lower
=
False
...
@@ -593,7 +593,7 @@ class TableLabelEncode(AttnLabelEncode):
...
@@ -593,7 +593,7 @@ class TableLabelEncode(AttnLabelEncode):
self
.
idx2char
=
{
v
:
k
for
k
,
v
in
self
.
dict
.
items
()}
self
.
idx2char
=
{
v
:
k
for
k
,
v
in
self
.
dict
.
items
()}
self
.
character
=
dict_character
self
.
character
=
dict_character
self
.
point_num
=
point
_num
self
.
loc_reg_num
=
loc_reg
_num
self
.
pad_idx
=
self
.
dict
[
self
.
beg_str
]
self
.
pad_idx
=
self
.
dict
[
self
.
beg_str
]
self
.
start_idx
=
self
.
dict
[
self
.
beg_str
]
self
.
start_idx
=
self
.
dict
[
self
.
beg_str
]
self
.
end_idx
=
self
.
dict
[
self
.
end_str
]
self
.
end_idx
=
self
.
dict
[
self
.
end_str
]
...
@@ -649,7 +649,7 @@ class TableLabelEncode(AttnLabelEncode):
...
@@ -649,7 +649,7 @@ class TableLabelEncode(AttnLabelEncode):
# encode box
# encode box
bboxes
=
np
.
zeros
(
bboxes
=
np
.
zeros
(
(
self
.
_max_text_len
,
self
.
point_num
*
2
),
dtype
=
np
.
float32
)
(
self
.
_max_text_len
,
self
.
loc_reg_num
),
dtype
=
np
.
float32
)
bbox_masks
=
np
.
zeros
((
self
.
_max_text_len
,
1
),
dtype
=
np
.
float32
)
bbox_masks
=
np
.
zeros
((
self
.
_max_text_len
,
1
),
dtype
=
np
.
float32
)
bbox_idx
=
0
bbox_idx
=
0
...
@@ -714,11 +714,11 @@ class TableMasterLabelEncode(TableLabelEncode):
...
@@ -714,11 +714,11 @@ class TableMasterLabelEncode(TableLabelEncode):
replace_empty_cell_token
=
False
,
replace_empty_cell_token
=
False
,
merge_no_span_structure
=
False
,
merge_no_span_structure
=
False
,
learn_empty_box
=
False
,
learn_empty_box
=
False
,
point_num
=
2
,
loc_reg_num
=
4
,
**
kwargs
):
**
kwargs
):
super
(
TableMasterLabelEncode
,
self
).
__init__
(
super
(
TableMasterLabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
replace_empty_cell_token
,
max_text_length
,
character_dict_path
,
replace_empty_cell_token
,
merge_no_span_structure
,
learn_empty_box
,
point
_num
,
**
kwargs
)
merge_no_span_structure
,
learn_empty_box
,
loc_reg
_num
,
**
kwargs
)
self
.
pad_idx
=
self
.
dict
[
self
.
pad_str
]
self
.
pad_idx
=
self
.
dict
[
self
.
pad_str
]
self
.
unknown_idx
=
self
.
dict
[
self
.
unknown_str
]
self
.
unknown_idx
=
self
.
dict
[
self
.
unknown_str
]
...
@@ -739,13 +739,14 @@ class TableMasterLabelEncode(TableLabelEncode):
...
@@ -739,13 +739,14 @@ class TableMasterLabelEncode(TableLabelEncode):
class
TableBoxEncode
(
object
):
class
TableBoxEncode
(
object
):
def
__init__
(
self
,
use_xywh
=
False
,
**
kwargs
):
def
__init__
(
self
,
box_format
=
'xyxy'
,
**
kwargs
):
self
.
use_xywh
=
use_xywh
assert
box_format
in
[
'xywh'
,
'xyxy'
,
'xyxyxyxy'
]
self
.
box_format
=
box_format
def
__call__
(
self
,
data
):
def
__call__
(
self
,
data
):
img_height
,
img_width
=
data
[
'image'
].
shape
[:
2
]
img_height
,
img_width
=
data
[
'image'
].
shape
[:
2
]
bboxes
=
data
[
'bboxes'
]
bboxes
=
data
[
'bboxes'
]
if
self
.
use_xywh
and
bboxes
.
shape
[
1
]
==
4
:
if
self
.
box_format
==
'xywh'
and
bboxes
.
shape
[
1
]
==
4
:
bboxes
=
self
.
xyxy2xywh
(
bboxes
)
bboxes
=
self
.
xyxy2xywh
(
bboxes
)
bboxes
[:,
0
::
2
]
/=
img_width
bboxes
[:,
0
::
2
]
/=
img_width
bboxes
[:,
1
::
2
]
/=
img_height
bboxes
[:,
1
::
2
]
/=
img_height
...
@@ -1217,6 +1218,7 @@ class ABINetLabelEncode(BaseRecLabelEncode):
...
@@ -1217,6 +1218,7 @@ class ABINetLabelEncode(BaseRecLabelEncode):
dict_character
=
[
'</s>'
]
+
dict_character
dict_character
=
[
'</s>'
]
+
dict_character
return
dict_character
return
dict_character
class
SPINLabelEncode
(
AttnLabelEncode
):
class
SPINLabelEncode
(
AttnLabelEncode
):
""" Convert between text-label and text-index """
""" Convert between text-label and text-index """
...
@@ -1229,6 +1231,7 @@ class SPINLabelEncode(AttnLabelEncode):
...
@@ -1229,6 +1231,7 @@ class SPINLabelEncode(AttnLabelEncode):
super
(
SPINLabelEncode
,
self
).
__init__
(
super
(
SPINLabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
use_space_char
)
max_text_length
,
character_dict_path
,
use_space_char
)
self
.
lower
=
lower
self
.
lower
=
lower
def
add_special_char
(
self
,
dict_character
):
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
"sos"
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
self
.
end_str
=
"eos"
...
...
ppocr/data/imaug/table_ops.py
浏览文件 @
ddaa2c25
...
@@ -206,7 +206,7 @@ class ResizeTableImage(object):
...
@@ -206,7 +206,7 @@ class ResizeTableImage(object):
data
[
'bboxes'
]
=
data
[
'bboxes'
]
*
ratio
data
[
'bboxes'
]
=
data
[
'bboxes'
]
*
ratio
data
[
'image'
]
=
resize_img
data
[
'image'
]
=
resize_img
data
[
'src_img'
]
=
img
data
[
'src_img'
]
=
img
data
[
'shape'
]
=
np
.
array
([
resize_h
,
resize_w
,
ratio
,
ratio
])
data
[
'shape'
]
=
np
.
array
([
height
,
width
,
ratio
,
ratio
])
data
[
'max_len'
]
=
self
.
max_len
data
[
'max_len'
]
=
self
.
max_len
return
data
return
data
...
...
ppocr/losses/__init__.py
浏览文件 @
ddaa2c25
...
@@ -51,7 +51,7 @@ from .basic_loss import DistanceLoss
...
@@ -51,7 +51,7 @@ from .basic_loss import DistanceLoss
from
.combined_loss
import
CombinedLoss
from
.combined_loss
import
CombinedLoss
# table loss
# table loss
from
.table_att_loss
import
TableAttentionLoss
from
.table_att_loss
import
TableAttentionLoss
,
SLANetLoss
from
.table_master_loss
import
TableMasterLoss
from
.table_master_loss
import
TableMasterLoss
# vqa token loss
# vqa token loss
from
.vqa_token_layoutlm_loss
import
VQASerTokenLayoutLMLoss
from
.vqa_token_layoutlm_loss
import
VQASerTokenLayoutLMLoss
...
@@ -63,7 +63,7 @@ def build_loss(config):
...
@@ -63,7 +63,7 @@ def build_loss(config):
'ClsLoss'
,
'AttentionLoss'
,
'SRNLoss'
,
'PGLoss'
,
'CombinedLoss'
,
'ClsLoss'
,
'AttentionLoss'
,
'SRNLoss'
,
'PGLoss'
,
'CombinedLoss'
,
'CELoss'
,
'TableAttentionLoss'
,
'SARLoss'
,
'AsterLoss'
,
'SDMGRLoss'
,
'CELoss'
,
'TableAttentionLoss'
,
'SARLoss'
,
'AsterLoss'
,
'SDMGRLoss'
,
'VQASerTokenLayoutLMLoss'
,
'LossFromOutput'
,
'PRENLoss'
,
'MultiLoss'
,
'VQASerTokenLayoutLMLoss'
,
'LossFromOutput'
,
'PRENLoss'
,
'MultiLoss'
,
'TableMasterLoss'
,
'SPINAttentionLoss'
'TableMasterLoss'
,
'SPINAttentionLoss'
,
'SLANetLoss'
]
]
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'name'
)
module_name
=
config
.
pop
(
'name'
)
...
...
ppocr/losses/table_att_loss.py
浏览文件 @
ddaa2c25
...
@@ -22,65 +22,11 @@ from paddle.nn import functional as F
...
@@ -22,65 +22,11 @@ from paddle.nn import functional as F
class
TableAttentionLoss
(
nn
.
Layer
):
class
TableAttentionLoss
(
nn
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
structure_weight
,
loc_weight
,
**
kwargs
):
structure_weight
,
loc_weight
,
use_giou
=
False
,
giou_weight
=
1.0
,
**
kwargs
):
super
(
TableAttentionLoss
,
self
).
__init__
()
super
(
TableAttentionLoss
,
self
).
__init__
()
self
.
loss_func
=
nn
.
CrossEntropyLoss
(
weight
=
None
,
reduction
=
'none'
)
self
.
loss_func
=
nn
.
CrossEntropyLoss
(
weight
=
None
,
reduction
=
'none'
)
self
.
structure_weight
=
structure_weight
self
.
structure_weight
=
structure_weight
self
.
loc_weight
=
loc_weight
self
.
loc_weight
=
loc_weight
self
.
use_giou
=
use_giou
self
.
giou_weight
=
giou_weight
def
giou_loss
(
self
,
preds
,
bbox
,
eps
=
1e-7
,
reduction
=
'mean'
):
'''
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
:param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
:return: loss
'''
ix1
=
paddle
.
maximum
(
preds
[:,
0
],
bbox
[:,
0
])
iy1
=
paddle
.
maximum
(
preds
[:,
1
],
bbox
[:,
1
])
ix2
=
paddle
.
minimum
(
preds
[:,
2
],
bbox
[:,
2
])
iy2
=
paddle
.
minimum
(
preds
[:,
3
],
bbox
[:,
3
])
iw
=
paddle
.
clip
(
ix2
-
ix1
+
1e-3
,
0.
,
1e10
)
ih
=
paddle
.
clip
(
iy2
-
iy1
+
1e-3
,
0.
,
1e10
)
# overlap
inters
=
iw
*
ih
# union
uni
=
(
preds
[:,
2
]
-
preds
[:,
0
]
+
1e-3
)
*
(
preds
[:,
3
]
-
preds
[:,
1
]
+
1e-3
)
+
(
bbox
[:,
2
]
-
bbox
[:,
0
]
+
1e-3
)
*
(
bbox
[:,
3
]
-
bbox
[:,
1
]
+
1e-3
)
-
inters
+
eps
# ious
ious
=
inters
/
uni
ex1
=
paddle
.
minimum
(
preds
[:,
0
],
bbox
[:,
0
])
ey1
=
paddle
.
minimum
(
preds
[:,
1
],
bbox
[:,
1
])
ex2
=
paddle
.
maximum
(
preds
[:,
2
],
bbox
[:,
2
])
ey2
=
paddle
.
maximum
(
preds
[:,
3
],
bbox
[:,
3
])
ew
=
paddle
.
clip
(
ex2
-
ex1
+
1e-3
,
0.
,
1e10
)
eh
=
paddle
.
clip
(
ey2
-
ey1
+
1e-3
,
0.
,
1e10
)
# enclose erea
enclose
=
ew
*
eh
+
eps
giou
=
ious
-
(
enclose
-
uni
)
/
enclose
loss
=
1
-
giou
if
reduction
==
'mean'
:
loss
=
paddle
.
mean
(
loss
)
elif
reduction
==
'sum'
:
loss
=
paddle
.
sum
(
loss
)
else
:
raise
NotImplementedError
return
loss
def
forward
(
self
,
predicts
,
batch
):
def
forward
(
self
,
predicts
,
batch
):
structure_probs
=
predicts
[
'structure_probs'
]
structure_probs
=
predicts
[
'structure_probs'
]
...
@@ -100,17 +46,45 @@ class TableAttentionLoss(nn.Layer):
...
@@ -100,17 +46,45 @@ class TableAttentionLoss(nn.Layer):
loc_targets_mask
=
loc_targets_mask
[:,
1
:,
:]
loc_targets_mask
=
loc_targets_mask
[:,
1
:,
:]
loc_loss
=
F
.
mse_loss
(
loc_preds
*
loc_targets_mask
,
loc_loss
=
F
.
mse_loss
(
loc_preds
*
loc_targets_mask
,
loc_targets
)
*
self
.
loc_weight
loc_targets
)
*
self
.
loc_weight
if
self
.
use_giou
:
loc_loss_giou
=
self
.
giou_loss
(
loc_preds
*
loc_targets_mask
,
total_loss
=
structure_loss
+
loc_loss
loc_targets
)
*
self
.
giou_weight
total_loss
=
structure_loss
+
loc_loss
+
loc_loss_giou
return
{
return
{
'loss'
:
total_loss
,
'loss'
:
total_loss
,
"structure_loss"
:
structure_loss
,
"structure_loss"
:
structure_loss
,
"loc_loss"
:
loc_loss
,
"loc_loss"
:
loc_loss
"loc_loss_giou"
:
loc_loss_giou
}
}
else
:
class
SLANetLoss
(
nn
.
Layer
):
def
__init__
(
self
,
structure_weight
,
loc_weight
,
loc_loss
=
'mse'
,
**
kwargs
):
super
(
SLANetLoss
,
self
).
__init__
()
self
.
loss_func
=
nn
.
CrossEntropyLoss
(
weight
=
None
,
reduction
=
'mean'
)
self
.
structure_weight
=
structure_weight
self
.
loc_weight
=
loc_weight
self
.
loc_loss
=
loc_loss
self
.
eps
=
1e-12
def
forward
(
self
,
predicts
,
batch
):
structure_probs
=
predicts
[
'structure_probs'
]
structure_targets
=
batch
[
1
].
astype
(
"int64"
)
structure_targets
=
structure_targets
[:,
1
:]
structure_loss
=
self
.
loss_func
(
structure_probs
,
structure_targets
)
structure_loss
=
paddle
.
mean
(
structure_loss
)
*
self
.
structure_weight
loc_preds
=
predicts
[
'loc_preds'
]
loc_targets
=
batch
[
2
].
astype
(
"float32"
)
loc_targets_mask
=
batch
[
3
].
astype
(
"float32"
)
loc_targets
=
loc_targets
[:,
1
:,
:]
loc_targets_mask
=
loc_targets_mask
[:,
1
:,
:]
loc_loss
=
F
.
smooth_l1_loss
(
loc_preds
*
loc_targets_mask
,
loc_targets
*
loc_targets_mask
,
reduction
=
'sum'
)
*
self
.
loc_weight
loc_loss
=
loc_loss
/
(
loc_targets_mask
.
sum
()
+
self
.
eps
)
total_loss
=
structure_loss
+
loc_loss
total_loss
=
structure_loss
+
loc_loss
return
{
return
{
'loss'
:
total_loss
,
'loss'
:
total_loss
,
...
...
ppocr/metrics/table_metric.py
浏览文件 @
ddaa2c25
...
@@ -59,7 +59,7 @@ class TableMetric(object):
...
@@ -59,7 +59,7 @@ class TableMetric(object):
def
__init__
(
self
,
def
__init__
(
self
,
main_indicator
=
'acc'
,
main_indicator
=
'acc'
,
compute_bbox_metric
=
False
,
compute_bbox_metric
=
False
,
point_num
=
2
,
box_format
=
'xyxy'
,
**
kwargs
):
**
kwargs
):
"""
"""
...
@@ -70,7 +70,7 @@ class TableMetric(object):
...
@@ -70,7 +70,7 @@ class TableMetric(object):
self
.
structure_metric
=
TableStructureMetric
()
self
.
structure_metric
=
TableStructureMetric
()
self
.
bbox_metric
=
DetMetric
()
if
compute_bbox_metric
else
None
self
.
bbox_metric
=
DetMetric
()
if
compute_bbox_metric
else
None
self
.
main_indicator
=
main_indicator
self
.
main_indicator
=
main_indicator
self
.
point_num
=
point_num
self
.
box_format
=
box_format
self
.
reset
()
self
.
reset
()
def
__call__
(
self
,
pred_label
,
batch
=
None
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
pred_label
,
batch
=
None
,
*
args
,
**
kwargs
):
...
@@ -129,10 +129,14 @@ class TableMetric(object):
...
@@ -129,10 +129,14 @@ class TableMetric(object):
self
.
bbox_metric
.
reset
()
self
.
bbox_metric
.
reset
()
def
format_box
(
self
,
box
):
def
format_box
(
self
,
box
):
if
self
.
point_num
==
2
:
if
self
.
box_format
==
'xyxy'
:
x1
,
y1
,
x2
,
y2
=
box
x1
,
y1
,
x2
,
y2
=
box
box
=
[[
x1
,
y1
],
[
x2
,
y1
],
[
x2
,
y2
],
[
x1
,
y2
]]
box
=
[[
x1
,
y1
],
[
x2
,
y1
],
[
x2
,
y2
],
[
x1
,
y2
]]
elif
self
.
point_num
==
4
:
elif
self
.
box_format
==
'xywh'
:
x
,
y
,
w
,
h
=
box
x1
,
y1
,
x2
,
y2
=
x
-
w
//
2
,
y
-
h
//
2
,
x
+
w
//
2
,
y
+
h
//
2
box
=
[[
x1
,
y1
],
[
x2
,
y1
],
[
x2
,
y2
],
[
x1
,
y2
]]
elif
self
.
box_format
==
'xyxyxyxy'
:
x1
,
y1
,
x2
,
y2
,
x3
,
y3
,
x4
,
y4
=
box
x1
,
y1
,
x2
,
y2
,
x3
,
y3
,
x4
,
y4
=
box
box
=
[[
x1
,
y1
],
[
x2
,
y2
],
[
x3
,
y3
],
[
x4
,
y4
]]
box
=
[[
x1
,
y1
],
[
x2
,
y2
],
[
x3
,
y3
],
[
x4
,
y4
]]
return
box
return
box
ppocr/modeling/backbones/__init__.py
浏览文件 @
ddaa2c25
...
@@ -21,7 +21,10 @@ def build_backbone(config, model_type):
...
@@ -21,7 +21,10 @@ def build_backbone(config, model_type):
from
.det_resnet
import
ResNet
from
.det_resnet
import
ResNet
from
.det_resnet_vd
import
ResNet_vd
from
.det_resnet_vd
import
ResNet_vd
from
.det_resnet_vd_sast
import
ResNet_SAST
from
.det_resnet_vd_sast
import
ResNet_SAST
support_dict
=
[
"MobileNetV3"
,
"ResNet"
,
"ResNet_vd"
,
"ResNet_SAST"
]
from
.det_pp_lcnet
import
PPLCNet
support_dict
=
[
"MobileNetV3"
,
"ResNet"
,
"ResNet_vd"
,
"ResNet_SAST"
,
"PPLCNet"
]
if
model_type
==
"table"
:
if
model_type
==
"table"
:
from
.table_master_resnet
import
TableResNetExtra
from
.table_master_resnet
import
TableResNetExtra
support_dict
.
append
(
'TableResNetExtra'
)
support_dict
.
append
(
'TableResNetExtra'
)
...
...
ppocr/modeling/backbones/det_pp_lcnet.py
0 → 100644
浏览文件 @
ddaa2c25
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
import
os
import
paddle
import
paddle.nn
as
nn
from
paddle
import
ParamAttr
from
paddle.nn
import
AdaptiveAvgPool2D
,
BatchNorm
,
Conv2D
,
Dropout
,
Linear
from
paddle.regularizer
import
L2Decay
from
paddle.nn.initializer
import
KaimingNormal
from
paddle.utils.download
import
get_path_from_url
MODEL_URLS
=
{
"PPLCNet_x0.25"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_25_pretrained.pdparams"
,
"PPLCNet_x0.35"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_35_pretrained.pdparams"
,
"PPLCNet_x0.5"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_5_pretrained.pdparams"
,
"PPLCNet_x0.75"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x0_75_pretrained.pdparams"
,
"PPLCNet_x1.0"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_0_pretrained.pdparams"
,
"PPLCNet_x1.5"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x1_5_pretrained.pdparams"
,
"PPLCNet_x2.0"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_0_pretrained.pdparams"
,
"PPLCNet_x2.5"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_pretrained.pdparams"
}
MODEL_STAGES_PATTERN
=
{
"PPLCNet"
:
[
"blocks2"
,
"blocks3"
,
"blocks4"
,
"blocks5"
,
"blocks6"
]
}
__all__
=
list
(
MODEL_URLS
.
keys
())
# Each element(list) represents a depthwise block, which is composed of k, in_c, out_c, s, use_se.
# k: kernel_size
# in_c: input channel number in depthwise block
# out_c: output channel number in depthwise block
# s: stride in depthwise block
# use_se: whether to use SE block
NET_CONFIG
=
{
"blocks2"
:
# k, in_c, out_c, s, use_se
[[
3
,
16
,
32
,
1
,
False
]],
"blocks3"
:
[[
3
,
32
,
64
,
2
,
False
],
[
3
,
64
,
64
,
1
,
False
]],
"blocks4"
:
[[
3
,
64
,
128
,
2
,
False
],
[
3
,
128
,
128
,
1
,
False
]],
"blocks5"
:
[[
3
,
128
,
256
,
2
,
False
],
[
5
,
256
,
256
,
1
,
False
],
[
5
,
256
,
256
,
1
,
False
],
[
5
,
256
,
256
,
1
,
False
],
[
5
,
256
,
256
,
1
,
False
],
[
5
,
256
,
256
,
1
,
False
]],
"blocks6"
:
[[
5
,
256
,
512
,
2
,
True
],
[
5
,
512
,
512
,
1
,
True
]]
}
def
make_divisible
(
v
,
divisor
=
8
,
min_value
=
None
):
if
min_value
is
None
:
min_value
=
divisor
new_v
=
max
(
min_value
,
int
(
v
+
divisor
/
2
)
//
divisor
*
divisor
)
if
new_v
<
0.9
*
v
:
new_v
+=
divisor
return
new_v
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
num_channels
,
filter_size
,
num_filters
,
stride
,
num_groups
=
1
):
super
().
__init__
()
self
.
conv
=
Conv2D
(
in_channels
=
num_channels
,
out_channels
=
num_filters
,
kernel_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
num_groups
,
weight_attr
=
ParamAttr
(
initializer
=
KaimingNormal
()),
bias_attr
=
False
)
self
.
bn
=
BatchNorm
(
num_filters
,
param_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)))
self
.
hardswish
=
nn
.
Hardswish
()
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
x
=
self
.
hardswish
(
x
)
return
x
class
DepthwiseSeparable
(
nn
.
Layer
):
def
__init__
(
self
,
num_channels
,
num_filters
,
stride
,
dw_size
=
3
,
use_se
=
False
):
super
().
__init__
()
self
.
use_se
=
use_se
self
.
dw_conv
=
ConvBNLayer
(
num_channels
=
num_channels
,
num_filters
=
num_channels
,
filter_size
=
dw_size
,
stride
=
stride
,
num_groups
=
num_channels
)
if
use_se
:
self
.
se
=
SEModule
(
num_channels
)
self
.
pw_conv
=
ConvBNLayer
(
num_channels
=
num_channels
,
filter_size
=
1
,
num_filters
=
num_filters
,
stride
=
1
)
def
forward
(
self
,
x
):
x
=
self
.
dw_conv
(
x
)
if
self
.
use_se
:
x
=
self
.
se
(
x
)
x
=
self
.
pw_conv
(
x
)
return
x
class
SEModule
(
nn
.
Layer
):
def
__init__
(
self
,
channel
,
reduction
=
4
):
super
().
__init__
()
self
.
avg_pool
=
AdaptiveAvgPool2D
(
1
)
self
.
conv1
=
Conv2D
(
in_channels
=
channel
,
out_channels
=
channel
//
reduction
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
relu
=
nn
.
ReLU
()
self
.
conv2
=
Conv2D
(
in_channels
=
channel
//
reduction
,
out_channels
=
channel
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
hardsigmoid
=
nn
.
Hardsigmoid
()
def
forward
(
self
,
x
):
identity
=
x
x
=
self
.
avg_pool
(
x
)
x
=
self
.
conv1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
conv2
(
x
)
x
=
self
.
hardsigmoid
(
x
)
x
=
paddle
.
multiply
(
x
=
identity
,
y
=
x
)
return
x
class
PPLCNet
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
=
3
,
scale
=
1.0
,
pretrained
=
False
,
use_ssld
=
False
):
super
().
__init__
()
self
.
out_channels
=
[
int
(
NET_CONFIG
[
"blocks3"
][
-
1
][
2
]
*
scale
),
int
(
NET_CONFIG
[
"blocks4"
][
-
1
][
2
]
*
scale
),
int
(
NET_CONFIG
[
"blocks5"
][
-
1
][
2
]
*
scale
),
int
(
NET_CONFIG
[
"blocks6"
][
-
1
][
2
]
*
scale
)
]
self
.
scale
=
scale
self
.
conv1
=
ConvBNLayer
(
num_channels
=
in_channels
,
filter_size
=
3
,
num_filters
=
make_divisible
(
16
*
scale
),
stride
=
2
)
self
.
blocks2
=
nn
.
Sequential
(
*
[
DepthwiseSeparable
(
num_channels
=
make_divisible
(
in_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
stride
=
s
,
use_se
=
se
)
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks2"
])
])
self
.
blocks3
=
nn
.
Sequential
(
*
[
DepthwiseSeparable
(
num_channels
=
make_divisible
(
in_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
stride
=
s
,
use_se
=
se
)
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks3"
])
])
self
.
blocks4
=
nn
.
Sequential
(
*
[
DepthwiseSeparable
(
num_channels
=
make_divisible
(
in_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
stride
=
s
,
use_se
=
se
)
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks4"
])
])
self
.
blocks5
=
nn
.
Sequential
(
*
[
DepthwiseSeparable
(
num_channels
=
make_divisible
(
in_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
stride
=
s
,
use_se
=
se
)
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks5"
])
])
self
.
blocks6
=
nn
.
Sequential
(
*
[
DepthwiseSeparable
(
num_channels
=
make_divisible
(
in_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
stride
=
s
,
use_se
=
se
)
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks6"
])
])
if
pretrained
:
self
.
_load_pretrained
(
MODEL_URLS
[
'PPLCNet_x{}'
.
format
(
scale
)],
use_ssld
=
use_ssld
)
def
forward
(
self
,
x
):
outs
=
[]
x
=
self
.
conv1
(
x
)
x
=
self
.
blocks2
(
x
)
x
=
self
.
blocks3
(
x
)
outs
.
append
(
x
)
x
=
self
.
blocks4
(
x
)
outs
.
append
(
x
)
x
=
self
.
blocks5
(
x
)
outs
.
append
(
x
)
x
=
self
.
blocks6
(
x
)
outs
.
append
(
x
)
return
outs
def
_load_pretrained
(
self
,
pretrained_url
,
use_ssld
=
False
):
if
use_ssld
:
pretrained_url
=
pretrained_url
.
replace
(
"_pretrained"
,
"_ssld_pretrained"
)
print
(
pretrained_url
)
local_weight_path
=
get_path_from_url
(
pretrained_url
,
os
.
path
.
expanduser
(
"~/.paddleclas/weights"
))
param_state_dict
=
paddle
.
load
(
local_weight_path
)
self
.
set_dict
(
param_state_dict
)
return
ppocr/modeling/heads/__init__.py
浏览文件 @
ddaa2c25
...
@@ -42,14 +42,15 @@ def build_head(config):
...
@@ -42,14 +42,15 @@ def build_head(config):
#kie head
#kie head
from
.kie_sdmgr_head
import
SDMGRHead
from
.kie_sdmgr_head
import
SDMGRHead
from
.table_att_head
import
TableAttentionHead
from
.table_att_head
import
TableAttentionHead
,
SLAHead
from
.table_master_head
import
TableMasterHead
from
.table_master_head
import
TableMasterHead
support_dict
=
[
support_dict
=
[
'DBHead'
,
'PSEHead'
,
'FCEHead'
,
'EASTHead'
,
'SASTHead'
,
'CTCHead'
,
'DBHead'
,
'PSEHead'
,
'FCEHead'
,
'EASTHead'
,
'SASTHead'
,
'CTCHead'
,
'ClsHead'
,
'AttentionHead'
,
'SRNHead'
,
'PGHead'
,
'Transformer'
,
'ClsHead'
,
'AttentionHead'
,
'SRNHead'
,
'PGHead'
,
'Transformer'
,
'TableAttentionHead'
,
'SARHead'
,
'AsterHead'
,
'SDMGRHead'
,
'PRENHead'
,
'TableAttentionHead'
,
'SARHead'
,
'AsterHead'
,
'SDMGRHead'
,
'PRENHead'
,
'MultiHead'
,
'ABINetHead'
,
'TableMasterHead'
,
'SPINAttentionHead'
'MultiHead'
,
'ABINetHead'
,
'TableMasterHead'
,
'SPINAttentionHead'
,
'SLAHead'
]
]
#table head
#table head
...
...
ppocr/modeling/heads/table_att_head.py
浏览文件 @
ddaa2c25
...
@@ -18,12 +18,26 @@ from __future__ import print_function
...
@@ -18,12 +18,26 @@ from __future__ import print_function
import
paddle
import
paddle
import
paddle.nn
as
nn
import
paddle.nn
as
nn
from
paddle
import
ParamAttr
import
paddle.nn.functional
as
F
import
paddle.nn.functional
as
F
import
numpy
as
np
import
numpy
as
np
from
.rec_att_head
import
AttentionGRUCell
from
.rec_att_head
import
AttentionGRUCell
def
get_para_bias_attr
(
l2_decay
,
k
):
if
l2_decay
>
0
:
regularizer
=
paddle
.
regularizer
.
L2Decay
(
l2_decay
)
stdv
=
1.0
/
math
.
sqrt
(
k
*
1.0
)
initializer
=
nn
.
initializer
.
Uniform
(
-
stdv
,
stdv
)
else
:
regularizer
=
None
initializer
=
None
weight_attr
=
ParamAttr
(
regularizer
=
regularizer
,
initializer
=
initializer
)
bias_attr
=
ParamAttr
(
regularizer
=
regularizer
,
initializer
=
initializer
)
return
[
weight_attr
,
bias_attr
]
class
TableAttentionHead
(
nn
.
Layer
):
class
TableAttentionHead
(
nn
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
in_channels
,
in_channels
,
...
@@ -32,7 +46,7 @@ class TableAttentionHead(nn.Layer):
...
@@ -32,7 +46,7 @@ class TableAttentionHead(nn.Layer):
in_max_len
=
488
,
in_max_len
=
488
,
max_text_length
=
800
,
max_text_length
=
800
,
out_channels
=
30
,
out_channels
=
30
,
point_num
=
2
,
loc_reg_num
=
4
,
**
kwargs
):
**
kwargs
):
super
(
TableAttentionHead
,
self
).
__init__
()
super
(
TableAttentionHead
,
self
).
__init__
()
self
.
input_size
=
in_channels
[
-
1
]
self
.
input_size
=
in_channels
[
-
1
]
...
@@ -56,7 +70,7 @@ class TableAttentionHead(nn.Layer):
...
@@ -56,7 +70,7 @@ class TableAttentionHead(nn.Layer):
else
:
else
:
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
self
.
max_text_length
+
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
self
.
max_text_length
+
1
)
self
.
loc_generator
=
nn
.
Linear
(
self
.
input_size
+
hidden_size
,
self
.
loc_generator
=
nn
.
Linear
(
self
.
input_size
+
hidden_size
,
point_num
*
2
)
loc_reg_num
)
def
_char_to_onehot
(
self
,
input_char
,
onehot_dim
):
def
_char_to_onehot
(
self
,
input_char
,
onehot_dim
):
input_ont_hot
=
F
.
one_hot
(
input_char
,
onehot_dim
)
input_ont_hot
=
F
.
one_hot
(
input_char
,
onehot_dim
)
...
@@ -129,3 +143,121 @@ class TableAttentionHead(nn.Layer):
...
@@ -129,3 +143,121 @@ class TableAttentionHead(nn.Layer):
loc_preds
=
self
.
loc_generator
(
loc_concat
)
loc_preds
=
self
.
loc_generator
(
loc_concat
)
loc_preds
=
F
.
sigmoid
(
loc_preds
)
loc_preds
=
F
.
sigmoid
(
loc_preds
)
return
{
'structure_probs'
:
structure_probs
,
'loc_preds'
:
loc_preds
}
return
{
'structure_probs'
:
structure_probs
,
'loc_preds'
:
loc_preds
}
class
SLAHead
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
hidden_size
,
out_channels
=
30
,
max_text_length
=
500
,
loc_reg_num
=
4
,
fc_decay
=
0.0
,
**
kwargs
):
"""
@param in_channels: input shape
@param hidden_size: hidden_size for RNN and Embedding
@param out_channels: num_classes to rec
@param max_text_length: max text pred
"""
super
().
__init__
()
in_channels
=
in_channels
[
-
1
]
self
.
hidden_size
=
hidden_size
self
.
max_text_length
=
max_text_length
self
.
emb
=
self
.
_char_to_onehot
self
.
num_embeddings
=
out_channels
# structure
self
.
structure_attention_cell
=
AttentionGRUCell
(
in_channels
,
hidden_size
,
self
.
num_embeddings
)
weight_attr
,
bias_attr
=
get_para_bias_attr
(
l2_decay
=
fc_decay
,
k
=
hidden_size
)
weight_attr1_1
,
bias_attr1_1
=
get_para_bias_attr
(
l2_decay
=
fc_decay
,
k
=
hidden_size
)
weight_attr1_2
,
bias_attr1_2
=
get_para_bias_attr
(
l2_decay
=
fc_decay
,
k
=
hidden_size
)
self
.
structure_generator
=
nn
.
Sequential
(
nn
.
Linear
(
self
.
hidden_size
,
self
.
hidden_size
,
weight_attr
=
weight_attr1_2
,
bias_attr
=
bias_attr1_2
),
nn
.
Linear
(
hidden_size
,
out_channels
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
))
# loc
weight_attr1
,
bias_attr1
=
get_para_bias_attr
(
l2_decay
=
fc_decay
,
k
=
self
.
hidden_size
)
weight_attr2
,
bias_attr2
=
get_para_bias_attr
(
l2_decay
=
fc_decay
,
k
=
self
.
hidden_size
)
self
.
loc_generator
=
nn
.
Sequential
(
nn
.
Linear
(
self
.
hidden_size
,
self
.
hidden_size
,
weight_attr
=
weight_attr1
,
bias_attr
=
bias_attr1
),
nn
.
Linear
(
self
.
hidden_size
,
loc_reg_num
,
weight_attr
=
weight_attr2
,
bias_attr
=
bias_attr2
),
nn
.
Sigmoid
())
def
forward
(
self
,
inputs
,
targets
=
None
):
fea
=
inputs
[
-
1
]
batch_size
=
fea
.
shape
[
0
]
# reshape
fea
=
paddle
.
reshape
(
fea
,
[
fea
.
shape
[
0
],
fea
.
shape
[
1
],
-
1
])
fea
=
fea
.
transpose
([
0
,
2
,
1
])
# (NTC)(batch, width, channels)
hidden
=
paddle
.
zeros
((
batch_size
,
self
.
hidden_size
))
structure_preds
=
[]
loc_preds
=
[]
if
self
.
training
and
targets
is
not
None
:
structure
=
targets
[
0
]
for
i
in
range
(
self
.
max_text_length
+
1
):
hidden
,
structure_step
,
loc_step
=
self
.
_decode
(
structure
[:,
i
],
fea
,
hidden
)
structure_preds
.
append
(
structure_step
)
loc_preds
.
append
(
loc_step
)
else
:
pre_chars
=
paddle
.
zeros
(
shape
=
[
batch_size
],
dtype
=
"int32"
)
max_text_length
=
paddle
.
to_tensor
(
self
.
max_text_length
)
# for export
loc_step
,
structure_step
=
None
,
None
for
i
in
range
(
max_text_length
+
1
):
hidden
,
structure_step
,
loc_step
=
self
.
_decode
(
pre_chars
,
fea
,
hidden
)
pre_chars
=
structure_step
.
argmax
(
axis
=
1
,
dtype
=
"int32"
)
structure_preds
.
append
(
structure_step
)
loc_preds
.
append
(
loc_step
)
structure_preds
=
paddle
.
stack
(
structure_preds
,
axis
=
1
)
loc_preds
=
paddle
.
stack
(
loc_preds
,
axis
=
1
)
if
not
self
.
training
:
structure_preds
=
F
.
softmax
(
structure_preds
)
return
{
'structure_probs'
:
structure_preds
,
'loc_preds'
:
loc_preds
}
def
_decode
(
self
,
pre_chars
,
features
,
hidden
):
"""
Predict table label and coordinates for each step
@param pre_chars: Table label in previous step
@param features:
@param hidden: hidden status in previous step
@return:
"""
emb_feature
=
self
.
emb
(
pre_chars
)
# output shape is b * self.hidden_size
(
output
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
hidden
,
features
,
emb_feature
)
# structure
structure_step
=
self
.
structure_generator
(
output
)
# loc
loc_step
=
self
.
loc_generator
(
output
)
return
hidden
,
structure_step
,
loc_step
def
_char_to_onehot
(
self
,
input_char
):
input_ont_hot
=
F
.
one_hot
(
input_char
,
self
.
num_embeddings
)
return
input_ont_hot
ppocr/modeling/heads/table_master_head.py
浏览文件 @
ddaa2c25
...
@@ -37,7 +37,7 @@ class TableMasterHead(nn.Layer):
...
@@ -37,7 +37,7 @@ class TableMasterHead(nn.Layer):
d_ff
=
2048
,
d_ff
=
2048
,
dropout
=
0
,
dropout
=
0
,
max_text_length
=
500
,
max_text_length
=
500
,
point_num
=
2
,
loc_reg_num
=
4
,
**
kwargs
):
**
kwargs
):
super
(
TableMasterHead
,
self
).
__init__
()
super
(
TableMasterHead
,
self
).
__init__
()
hidden_size
=
in_channels
[
-
1
]
hidden_size
=
in_channels
[
-
1
]
...
@@ -50,7 +50,7 @@ class TableMasterHead(nn.Layer):
...
@@ -50,7 +50,7 @@ class TableMasterHead(nn.Layer):
self
.
cls_fc
=
nn
.
Linear
(
hidden_size
,
out_channels
)
self
.
cls_fc
=
nn
.
Linear
(
hidden_size
,
out_channels
)
self
.
bbox_fc
=
nn
.
Sequential
(
self
.
bbox_fc
=
nn
.
Sequential
(
# nn.Linear(hidden_size, hidden_size),
# nn.Linear(hidden_size, hidden_size),
nn
.
Linear
(
hidden_size
,
point_num
*
2
),
nn
.
Linear
(
hidden_size
,
loc_reg_num
),
nn
.
Sigmoid
())
nn
.
Sigmoid
())
self
.
norm
=
nn
.
LayerNorm
(
hidden_size
)
self
.
norm
=
nn
.
LayerNorm
(
hidden_size
)
self
.
embedding
=
Embeddings
(
d_model
=
hidden_size
,
vocab
=
out_channels
)
self
.
embedding
=
Embeddings
(
d_model
=
hidden_size
,
vocab
=
out_channels
)
...
@@ -59,7 +59,7 @@ class TableMasterHead(nn.Layer):
...
@@ -59,7 +59,7 @@ class TableMasterHead(nn.Layer):
self
.
SOS
=
out_channels
-
3
self
.
SOS
=
out_channels
-
3
self
.
PAD
=
out_channels
-
1
self
.
PAD
=
out_channels
-
1
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
self
.
point_num
=
point
_num
self
.
loc_reg_num
=
loc_reg
_num
self
.
max_text_length
=
max_text_length
self
.
max_text_length
=
max_text_length
def
make_mask
(
self
,
tgt
):
def
make_mask
(
self
,
tgt
):
...
@@ -105,7 +105,7 @@ class TableMasterHead(nn.Layer):
...
@@ -105,7 +105,7 @@ class TableMasterHead(nn.Layer):
output
=
paddle
.
zeros
(
output
=
paddle
.
zeros
(
[
input
.
shape
[
0
],
self
.
max_text_length
+
1
,
self
.
out_channels
])
[
input
.
shape
[
0
],
self
.
max_text_length
+
1
,
self
.
out_channels
])
bbox_output
=
paddle
.
zeros
(
bbox_output
=
paddle
.
zeros
(
[
input
.
shape
[
0
],
self
.
max_text_length
+
1
,
self
.
point_num
*
2
])
[
input
.
shape
[
0
],
self
.
max_text_length
+
1
,
self
.
loc_reg_num
])
max_text_length
=
paddle
.
to_tensor
(
self
.
max_text_length
)
max_text_length
=
paddle
.
to_tensor
(
self
.
max_text_length
)
for
i
in
range
(
max_text_length
+
1
):
for
i
in
range
(
max_text_length
+
1
):
target_mask
=
self
.
make_mask
(
input
)
target_mask
=
self
.
make_mask
(
input
)
...
...
ppocr/modeling/necks/__init__.py
浏览文件 @
ddaa2c25
...
@@ -25,9 +25,10 @@ def build_neck(config):
...
@@ -25,9 +25,10 @@ def build_neck(config):
from
.fpn
import
FPN
from
.fpn
import
FPN
from
.fce_fpn
import
FCEFPN
from
.fce_fpn
import
FCEFPN
from
.pren_fpn
import
PRENFPN
from
.pren_fpn
import
PRENFPN
from
.csp_pan
import
CSPPAN
support_dict
=
[
support_dict
=
[
'FPN'
,
'FCEFPN'
,
'LKPAN'
,
'DBFPN'
,
'RSEFPN'
,
'EASTFPN'
,
'SASTFPN'
,
'FPN'
,
'FCEFPN'
,
'LKPAN'
,
'DBFPN'
,
'RSEFPN'
,
'EASTFPN'
,
'SASTFPN'
,
'SequenceEncoder'
,
'PGFPN'
,
'TableFPN'
,
'PRENFPN'
'SequenceEncoder'
,
'PGFPN'
,
'TableFPN'
,
'PRENFPN'
,
'CSPPAN'
]
]
module_name
=
config
.
pop
(
'name'
)
module_name
=
config
.
pop
(
'name'
)
...
...
ppocr/modeling/necks/csp_pan.py
0 → 100755
浏览文件 @
ddaa2c25
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The code is based on:
# https://github.com/PaddlePaddle/PaddleDetection/blob/release%2F2.3/ppdet/modeling/necks/csp_pan.py
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle
import
ParamAttr
__all__
=
[
'CSPPAN'
]
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
in_channel
=
96
,
out_channel
=
96
,
kernel_size
=
3
,
stride
=
1
,
groups
=
1
,
act
=
'leaky_relu'
):
super
(
ConvBNLayer
,
self
).
__init__
()
initializer
=
nn
.
initializer
.
KaimingUniform
()
self
.
act
=
act
assert
self
.
act
in
[
'leaky_relu'
,
"hard_swish"
]
self
.
conv
=
nn
.
Conv2D
(
in_channels
=
in_channel
,
out_channels
=
out_channel
,
kernel_size
=
kernel_size
,
groups
=
groups
,
padding
=
(
kernel_size
-
1
)
//
2
,
stride
=
stride
,
weight_attr
=
ParamAttr
(
initializer
=
initializer
),
bias_attr
=
False
)
self
.
bn
=
nn
.
BatchNorm2D
(
out_channel
)
def
forward
(
self
,
x
):
x
=
self
.
bn
(
self
.
conv
(
x
))
if
self
.
act
==
"leaky_relu"
:
x
=
F
.
leaky_relu
(
x
)
elif
self
.
act
==
"hard_swish"
:
x
=
F
.
hardswish
(
x
)
return
x
class
DPModule
(
nn
.
Layer
):
"""
Depth-wise and point-wise module.
Args:
in_channel (int): The input channels of this Module.
out_channel (int): The output channels of this Module.
kernel_size (int): The conv2d kernel size of this Module.
stride (int): The conv2d's stride of this Module.
act (str): The activation function of this Module,
Now support `leaky_relu` and `hard_swish`.
"""
def
__init__
(
self
,
in_channel
=
96
,
out_channel
=
96
,
kernel_size
=
3
,
stride
=
1
,
act
=
'leaky_relu'
):
super
(
DPModule
,
self
).
__init__
()
initializer
=
nn
.
initializer
.
KaimingUniform
()
self
.
act
=
act
self
.
dwconv
=
nn
.
Conv2D
(
in_channels
=
in_channel
,
out_channels
=
out_channel
,
kernel_size
=
kernel_size
,
groups
=
out_channel
,
padding
=
(
kernel_size
-
1
)
//
2
,
stride
=
stride
,
weight_attr
=
ParamAttr
(
initializer
=
initializer
),
bias_attr
=
False
)
self
.
bn1
=
nn
.
BatchNorm2D
(
out_channel
)
self
.
pwconv
=
nn
.
Conv2D
(
in_channels
=
out_channel
,
out_channels
=
out_channel
,
kernel_size
=
1
,
groups
=
1
,
padding
=
0
,
weight_attr
=
ParamAttr
(
initializer
=
initializer
),
bias_attr
=
False
)
self
.
bn2
=
nn
.
BatchNorm2D
(
out_channel
)
def
act_func
(
self
,
x
):
if
self
.
act
==
"leaky_relu"
:
x
=
F
.
leaky_relu
(
x
)
elif
self
.
act
==
"hard_swish"
:
x
=
F
.
hardswish
(
x
)
return
x
def
forward
(
self
,
x
):
x
=
self
.
act_func
(
self
.
bn1
(
self
.
dwconv
(
x
)))
x
=
self
.
act_func
(
self
.
bn2
(
self
.
pwconv
(
x
)))
return
x
class
DarknetBottleneck
(
nn
.
Layer
):
"""The basic bottleneck block used in Darknet.
Each Block consists of two ConvModules and the input is added to the
final output. Each ConvModule is composed of Conv, BN, and act.
The first convLayer has filter size of 1x1 and the second one has the
filter size of 3x3.
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
expansion (int): The kernel size of the convolution. Default: 0.5
add_identity (bool): Whether to add identity to the out.
Default: True
use_depthwise (bool): Whether to use depthwise separable convolution.
Default: False
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
=
3
,
expansion
=
0.5
,
add_identity
=
True
,
use_depthwise
=
False
,
act
=
"leaky_relu"
):
super
(
DarknetBottleneck
,
self
).
__init__
()
hidden_channels
=
int
(
out_channels
*
expansion
)
conv_func
=
DPModule
if
use_depthwise
else
ConvBNLayer
self
.
conv1
=
ConvBNLayer
(
in_channel
=
in_channels
,
out_channel
=
hidden_channels
,
kernel_size
=
1
,
act
=
act
)
self
.
conv2
=
conv_func
(
in_channel
=
hidden_channels
,
out_channel
=
out_channels
,
kernel_size
=
kernel_size
,
stride
=
1
,
act
=
act
)
self
.
add_identity
=
\
add_identity
and
in_channels
==
out_channels
def
forward
(
self
,
x
):
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
conv2
(
out
)
if
self
.
add_identity
:
return
out
+
identity
else
:
return
out
class
CSPLayer
(
nn
.
Layer
):
"""Cross Stage Partial Layer.
Args:
in_channels (int): The input channels of the CSP layer.
out_channels (int): The output channels of the CSP layer.
expand_ratio (float): Ratio to adjust the number of channels of the
hidden layer. Default: 0.5
num_blocks (int): Number of blocks. Default: 1
add_identity (bool): Whether to add identity in blocks.
Default: True
use_depthwise (bool): Whether to depthwise separable convolution in
blocks. Default: False
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
=
3
,
expand_ratio
=
0.5
,
num_blocks
=
1
,
add_identity
=
True
,
use_depthwise
=
False
,
act
=
"leaky_relu"
):
super
().
__init__
()
mid_channels
=
int
(
out_channels
*
expand_ratio
)
self
.
main_conv
=
ConvBNLayer
(
in_channels
,
mid_channels
,
1
,
act
=
act
)
self
.
short_conv
=
ConvBNLayer
(
in_channels
,
mid_channels
,
1
,
act
=
act
)
self
.
final_conv
=
ConvBNLayer
(
2
*
mid_channels
,
out_channels
,
1
,
act
=
act
)
self
.
blocks
=
nn
.
Sequential
(
*
[
DarknetBottleneck
(
mid_channels
,
mid_channels
,
kernel_size
,
1.0
,
add_identity
,
use_depthwise
,
act
=
act
)
for
_
in
range
(
num_blocks
)
])
def
forward
(
self
,
x
):
x_short
=
self
.
short_conv
(
x
)
x_main
=
self
.
main_conv
(
x
)
x_main
=
self
.
blocks
(
x_main
)
x_final
=
paddle
.
concat
((
x_main
,
x_short
),
axis
=
1
)
return
self
.
final_conv
(
x_final
)
class
Channel_T
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
=
[
116
,
232
,
464
],
out_channels
=
96
,
act
=
"leaky_relu"
):
super
(
Channel_T
,
self
).
__init__
()
self
.
convs
=
nn
.
LayerList
()
for
i
in
range
(
len
(
in_channels
)):
self
.
convs
.
append
(
ConvBNLayer
(
in_channels
[
i
],
out_channels
,
1
,
act
=
act
))
def
forward
(
self
,
x
):
outs
=
[
self
.
convs
[
i
](
x
[
i
])
for
i
in
range
(
len
(
x
))]
return
outs
class
CSPPAN
(
nn
.
Layer
):
"""Path Aggregation Network with CSP module.
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale)
kernel_size (int): The conv2d kernel size of this Module.
num_csp_blocks (int): Number of bottlenecks in CSPLayer. Default: 1
use_depthwise (bool): Whether to depthwise separable convolution in
blocks. Default: True
"""
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
=
5
,
num_csp_blocks
=
1
,
use_depthwise
=
True
,
act
=
'hard_swish'
):
super
(
CSPPAN
,
self
).
__init__
()
self
.
in_channels
=
in_channels
self
.
out_channels
=
[
out_channels
]
*
len
(
in_channels
)
conv_func
=
DPModule
if
use_depthwise
else
ConvBNLayer
self
.
conv_t
=
Channel_T
(
in_channels
,
out_channels
,
act
=
act
)
# build top-down blocks
self
.
upsample
=
nn
.
Upsample
(
scale_factor
=
2
,
mode
=
'nearest'
)
self
.
top_down_blocks
=
nn
.
LayerList
()
for
idx
in
range
(
len
(
in_channels
)
-
1
,
0
,
-
1
):
self
.
top_down_blocks
.
append
(
CSPLayer
(
out_channels
*
2
,
out_channels
,
kernel_size
=
kernel_size
,
num_blocks
=
num_csp_blocks
,
add_identity
=
False
,
use_depthwise
=
use_depthwise
,
act
=
act
))
# build bottom-up blocks
self
.
downsamples
=
nn
.
LayerList
()
self
.
bottom_up_blocks
=
nn
.
LayerList
()
for
idx
in
range
(
len
(
in_channels
)
-
1
):
self
.
downsamples
.
append
(
conv_func
(
out_channels
,
out_channels
,
kernel_size
=
kernel_size
,
stride
=
2
,
act
=
act
))
self
.
bottom_up_blocks
.
append
(
CSPLayer
(
out_channels
*
2
,
out_channels
,
kernel_size
=
kernel_size
,
num_blocks
=
num_csp_blocks
,
add_identity
=
False
,
use_depthwise
=
use_depthwise
,
act
=
act
))
def
forward
(
self
,
inputs
):
"""
Args:
inputs (tuple[Tensor]): input features.
Returns:
tuple[Tensor]: CSPPAN features.
"""
assert
len
(
inputs
)
==
len
(
self
.
in_channels
)
inputs
=
self
.
conv_t
(
inputs
)
# top-down path
inner_outs
=
[
inputs
[
-
1
]]
for
idx
in
range
(
len
(
self
.
in_channels
)
-
1
,
0
,
-
1
):
feat_heigh
=
inner_outs
[
0
]
feat_low
=
inputs
[
idx
-
1
]
upsample_feat
=
F
.
upsample
(
feat_heigh
,
size
=
feat_low
.
shape
[
2
:
4
],
mode
=
"nearest"
)
inner_out
=
self
.
top_down_blocks
[
len
(
self
.
in_channels
)
-
1
-
idx
](
paddle
.
concat
([
upsample_feat
,
feat_low
],
1
))
inner_outs
.
insert
(
0
,
inner_out
)
# bottom-up path
outs
=
[
inner_outs
[
0
]]
for
idx
in
range
(
len
(
self
.
in_channels
)
-
1
):
feat_low
=
outs
[
-
1
]
feat_height
=
inner_outs
[
idx
+
1
]
downsample_feat
=
self
.
downsamples
[
idx
](
feat_low
)
out
=
self
.
bottom_up_blocks
[
idx
](
paddle
.
concat
(
[
downsample_feat
,
feat_height
],
1
))
outs
.
append
(
out
)
return
tuple
(
outs
)
ppocr/postprocess/table_postprocess.py
浏览文件 @
ddaa2c25
...
@@ -23,7 +23,7 @@ class TableLabelDecode(AttnLabelDecode):
...
@@ -23,7 +23,7 @@ class TableLabelDecode(AttnLabelDecode):
def
__init__
(
self
,
character_dict_path
,
**
kwargs
):
def
__init__
(
self
,
character_dict_path
,
**
kwargs
):
super
(
TableLabelDecode
,
self
).
__init__
(
character_dict_path
)
super
(
TableLabelDecode
,
self
).
__init__
(
character_dict_path
)
self
.
td_token
=
[
'<td>'
,
'<td'
,
'<
eb></eb>'
,
'<
td></td>'
]
self
.
td_token
=
[
'<td>'
,
'<td'
,
'<td></td>'
]
def
__call__
(
self
,
preds
,
batch
=
None
):
def
__call__
(
self
,
preds
,
batch
=
None
):
structure_probs
=
preds
[
'structure_probs'
]
structure_probs
=
preds
[
'structure_probs'
]
...
@@ -114,10 +114,8 @@ class TableLabelDecode(AttnLabelDecode):
...
@@ -114,10 +114,8 @@ class TableLabelDecode(AttnLabelDecode):
def
_bbox_decode
(
self
,
bbox
,
shape
):
def
_bbox_decode
(
self
,
bbox
,
shape
):
h
,
w
,
ratio_h
,
ratio_w
,
pad_h
,
pad_w
=
shape
h
,
w
,
ratio_h
,
ratio_w
,
pad_h
,
pad_w
=
shape
src_h
=
h
/
ratio_h
bbox
[
0
::
2
]
*=
w
src_w
=
w
/
ratio_w
bbox
[
1
::
2
]
*=
h
bbox
[
0
::
2
]
*=
src_w
bbox
[
1
::
2
]
*=
src_h
return
bbox
return
bbox
...
@@ -157,4 +155,7 @@ class TableMasterLabelDecode(TableLabelDecode):
...
@@ -157,4 +155,7 @@ class TableMasterLabelDecode(TableLabelDecode):
bbox
[
1
::
2
]
*=
h
bbox
[
1
::
2
]
*=
h
bbox
[
0
::
2
]
/=
ratio_w
bbox
[
0
::
2
]
/=
ratio_w
bbox
[
1
::
2
]
/=
ratio_h
bbox
[
1
::
2
]
/=
ratio_h
x
,
y
,
w
,
h
=
bbox
x1
,
y1
,
x2
,
y2
=
x
-
w
//
2
,
y
-
h
//
2
,
x
+
w
//
2
,
y
+
h
//
2
bbox
=
np
.
array
([
x1
,
y1
,
x2
,
y2
])
return
bbox
return
bbox
ppocr/utils/visual.py
浏览文件 @
ddaa2c25
...
@@ -113,14 +113,10 @@ def draw_re_results(image,
...
@@ -113,14 +113,10 @@ def draw_re_results(image,
return
np
.
array
(
img_new
)
return
np
.
array
(
img_new
)
def
draw_rectangle
(
img_path
,
boxes
,
use_xywh
=
False
):
def
draw_rectangle
(
img_path
,
boxes
):
img
=
cv2
.
imread
(
img_path
)
img
=
cv2
.
imread
(
img_path
)
img_show
=
img
.
copy
()
img_show
=
img
.
copy
()
for
box
in
boxes
.
astype
(
int
):
for
box
in
boxes
.
astype
(
int
):
if
use_xywh
:
x
,
y
,
w
,
h
=
box
x1
,
y1
,
x2
,
y2
=
x
-
w
//
2
,
y
-
h
//
2
,
x
+
w
//
2
,
y
+
h
//
2
else
:
x1
,
y1
,
x2
,
y2
=
box
x1
,
y1
,
x2
,
y2
=
box
cv2
.
rectangle
(
img_show
,
(
x1
,
y1
),
(
x2
,
y2
),
(
255
,
0
,
0
),
2
)
cv2
.
rectangle
(
img_show
,
(
x1
,
y1
),
(
x2
,
y2
),
(
255
,
0
,
0
),
2
)
return
img_show
return
img_show
\ No newline at end of file
ppstructure/layout/picodet_postprocess.py
0 → 100644
浏览文件 @
ddaa2c25
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
from
scipy.special
import
softmax
def
hard_nms
(
box_scores
,
iou_threshold
,
top_k
=-
1
,
candidate_size
=
200
):
"""
Args:
box_scores (N, 5): boxes in corner-form and probabilities.
iou_threshold: intersection over union threshold.
top_k: keep top_k results. If k <= 0, keep all the results.
candidate_size: only consider the candidates with the highest scores.
Returns:
picked: a list of indexes of the kept boxes
"""
scores
=
box_scores
[:,
-
1
]
boxes
=
box_scores
[:,
:
-
1
]
picked
=
[]
indexes
=
np
.
argsort
(
scores
)
indexes
=
indexes
[
-
candidate_size
:]
while
len
(
indexes
)
>
0
:
current
=
indexes
[
-
1
]
picked
.
append
(
current
)
if
0
<
top_k
==
len
(
picked
)
or
len
(
indexes
)
==
1
:
break
current_box
=
boxes
[
current
,
:]
indexes
=
indexes
[:
-
1
]
rest_boxes
=
boxes
[
indexes
,
:]
iou
=
iou_of
(
rest_boxes
,
np
.
expand_dims
(
current_box
,
axis
=
0
),
)
indexes
=
indexes
[
iou
<=
iou_threshold
]
return
box_scores
[
picked
,
:]
def
iou_of
(
boxes0
,
boxes1
,
eps
=
1e-5
):
"""Return intersection-over-union (Jaccard index) of boxes.
Args:
boxes0 (N, 4): ground truth boxes.
boxes1 (N or 1, 4): predicted boxes.
eps: a small number to avoid 0 as denominator.
Returns:
iou (N): IoU values.
"""
overlap_left_top
=
np
.
maximum
(
boxes0
[...,
:
2
],
boxes1
[...,
:
2
])
overlap_right_bottom
=
np
.
minimum
(
boxes0
[...,
2
:],
boxes1
[...,
2
:])
overlap_area
=
area_of
(
overlap_left_top
,
overlap_right_bottom
)
area0
=
area_of
(
boxes0
[...,
:
2
],
boxes0
[...,
2
:])
area1
=
area_of
(
boxes1
[...,
:
2
],
boxes1
[...,
2
:])
return
overlap_area
/
(
area0
+
area1
-
overlap_area
+
eps
)
def
area_of
(
left_top
,
right_bottom
):
"""Compute the areas of rectangles given two corners.
Args:
left_top (N, 2): left top corner.
right_bottom (N, 2): right bottom corner.
Returns:
area (N): return the area.
"""
hw
=
np
.
clip
(
right_bottom
-
left_top
,
0.0
,
None
)
return
hw
[...,
0
]
*
hw
[...,
1
]
class
PicoDetPostProcess
(
object
):
"""
Args:
input_shape (int): network input image size
ori_shape (int): ori image shape of before padding
scale_factor (float): scale factor of ori image
enable_mkldnn (bool): whether to open MKLDNN
"""
def
__init__
(
self
,
input_shape
,
ori_shape
,
scale_factor
,
strides
=
[
8
,
16
,
32
,
64
],
score_threshold
=
0.4
,
nms_threshold
=
0.5
,
nms_top_k
=
1000
,
keep_top_k
=
100
):
self
.
ori_shape
=
ori_shape
self
.
input_shape
=
input_shape
self
.
scale_factor
=
scale_factor
self
.
strides
=
strides
self
.
score_threshold
=
score_threshold
self
.
nms_threshold
=
nms_threshold
self
.
nms_top_k
=
nms_top_k
self
.
keep_top_k
=
keep_top_k
def
warp_boxes
(
self
,
boxes
,
ori_shape
):
"""Apply transform to boxes
"""
width
,
height
=
ori_shape
[
1
],
ori_shape
[
0
]
n
=
len
(
boxes
)
if
n
:
# warp points
xy
=
np
.
ones
((
n
*
4
,
3
))
xy
[:,
:
2
]
=
boxes
[:,
[
0
,
1
,
2
,
3
,
0
,
3
,
2
,
1
]].
reshape
(
n
*
4
,
2
)
# x1y1, x2y2, x1y2, x2y1
# xy = xy @ M.T # transform
xy
=
(
xy
[:,
:
2
]
/
xy
[:,
2
:
3
]).
reshape
(
n
,
8
)
# rescale
# create new boxes
x
=
xy
[:,
[
0
,
2
,
4
,
6
]]
y
=
xy
[:,
[
1
,
3
,
5
,
7
]]
xy
=
np
.
concatenate
(
(
x
.
min
(
1
),
y
.
min
(
1
),
x
.
max
(
1
),
y
.
max
(
1
))).
reshape
(
4
,
n
).
T
# clip boxes
xy
[:,
[
0
,
2
]]
=
xy
[:,
[
0
,
2
]].
clip
(
0
,
width
)
xy
[:,
[
1
,
3
]]
=
xy
[:,
[
1
,
3
]].
clip
(
0
,
height
)
return
xy
.
astype
(
np
.
float32
)
else
:
return
boxes
def
__call__
(
self
,
scores
,
raw_boxes
):
batch_size
=
raw_boxes
[
0
].
shape
[
0
]
reg_max
=
int
(
raw_boxes
[
0
].
shape
[
-
1
]
/
4
-
1
)
out_boxes_num
=
[]
out_boxes_list
=
[]
for
batch_id
in
range
(
batch_size
):
# generate centers
decode_boxes
=
[]
select_scores
=
[]
for
stride
,
box_distribute
,
score
in
zip
(
self
.
strides
,
raw_boxes
,
scores
):
box_distribute
=
box_distribute
[
batch_id
]
score
=
score
[
batch_id
]
# centers
fm_h
=
self
.
input_shape
[
0
]
/
stride
fm_w
=
self
.
input_shape
[
1
]
/
stride
h_range
=
np
.
arange
(
fm_h
)
w_range
=
np
.
arange
(
fm_w
)
ww
,
hh
=
np
.
meshgrid
(
w_range
,
h_range
)
ct_row
=
(
hh
.
flatten
()
+
0.5
)
*
stride
ct_col
=
(
ww
.
flatten
()
+
0.5
)
*
stride
center
=
np
.
stack
((
ct_col
,
ct_row
,
ct_col
,
ct_row
),
axis
=
1
)
# box distribution to distance
reg_range
=
np
.
arange
(
reg_max
+
1
)
box_distance
=
box_distribute
.
reshape
((
-
1
,
reg_max
+
1
))
box_distance
=
softmax
(
box_distance
,
axis
=
1
)
box_distance
=
box_distance
*
np
.
expand_dims
(
reg_range
,
axis
=
0
)
box_distance
=
np
.
sum
(
box_distance
,
axis
=
1
).
reshape
((
-
1
,
4
))
box_distance
=
box_distance
*
stride
# top K candidate
topk_idx
=
np
.
argsort
(
score
.
max
(
axis
=
1
))[::
-
1
]
topk_idx
=
topk_idx
[:
self
.
nms_top_k
]
center
=
center
[
topk_idx
]
score
=
score
[
topk_idx
]
box_distance
=
box_distance
[
topk_idx
]
# decode box
decode_box
=
center
+
[
-
1
,
-
1
,
1
,
1
]
*
box_distance
select_scores
.
append
(
score
)
decode_boxes
.
append
(
decode_box
)
# nms
bboxes
=
np
.
concatenate
(
decode_boxes
,
axis
=
0
)
confidences
=
np
.
concatenate
(
select_scores
,
axis
=
0
)
picked_box_probs
=
[]
picked_labels
=
[]
for
class_index
in
range
(
0
,
confidences
.
shape
[
1
]):
probs
=
confidences
[:,
class_index
]
mask
=
probs
>
self
.
score_threshold
probs
=
probs
[
mask
]
if
probs
.
shape
[
0
]
==
0
:
continue
subset_boxes
=
bboxes
[
mask
,
:]
box_probs
=
np
.
concatenate
(
[
subset_boxes
,
probs
.
reshape
(
-
1
,
1
)],
axis
=
1
)
box_probs
=
hard_nms
(
box_probs
,
iou_threshold
=
self
.
nms_threshold
,
top_k
=
self
.
keep_top_k
,
)
picked_box_probs
.
append
(
box_probs
)
picked_labels
.
extend
([
class_index
]
*
box_probs
.
shape
[
0
])
if
len
(
picked_box_probs
)
==
0
:
out_boxes_list
.
append
(
np
.
empty
((
0
,
4
)))
out_boxes_num
.
append
(
0
)
else
:
picked_box_probs
=
np
.
concatenate
(
picked_box_probs
)
# resize output boxes
picked_box_probs
[:,
:
4
]
=
self
.
warp_boxes
(
picked_box_probs
[:,
:
4
],
self
.
ori_shape
[
batch_id
])
im_scale
=
np
.
concatenate
([
self
.
scale_factor
[
batch_id
][::
-
1
],
self
.
scale_factor
[
batch_id
][::
-
1
]
])
picked_box_probs
[:,
:
4
]
/=
im_scale
# clas score box
out_boxes_list
.
append
(
np
.
concatenate
(
[
np
.
expand_dims
(
np
.
array
(
picked_labels
),
axis
=-
1
),
np
.
expand_dims
(
picked_box_probs
[:,
4
],
axis
=-
1
),
picked_box_probs
[:,
:
4
]
],
axis
=
1
))
out_boxes_num
.
append
(
len
(
picked_labels
))
out_boxes_list
=
np
.
concatenate
(
out_boxes_list
,
axis
=
0
)
out_boxes_num
=
np
.
asarray
(
out_boxes_num
).
astype
(
np
.
int32
)
return
out_boxes_list
,
out_boxes_num
ppstructure/layout/predict_layout.py
0 → 100644
浏览文件 @
ddaa2c25
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
cv2
import
numpy
as
np
import
time
import
tools.infer.utility
as
utility
from
ppocr.data
import
create_operators
,
transform
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppstructure.utility
import
parse_args
from
picodet_postprocess
import
PicoDetPostProcess
logger
=
get_logger
()
class
LayoutPredictor
(
object
):
def
__init__
(
self
,
args
):
pre_process_list
=
[{
'Resize'
:
{
'size'
:
[
800
,
608
]
}
},
{
'NormalizeImage'
:
{
'std'
:
[
0.229
,
0.224
,
0.225
],
'mean'
:
[
0.485
,
0.456
,
0.406
],
'scale'
:
'1./255.'
,
'order'
:
'hwc'
}
},
{
'ToCHWImage'
:
None
},
{
'KeepKeys'
:
{
'keep_keys'
:
[
'image'
]
}
}]
# postprocess_params = {
# 'name': 'LayoutPostProcess',
# "character_dict_path": args.layout_dict_path,
# }
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
# self.postprocess_op = build_post_process(postprocess_params)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
\
utility
.
create_predictor
(
args
,
'layout'
,
logger
)
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
data
=
{
'image'
:
img
}
data
=
transform
(
data
,
self
.
preprocess_op
)
img
=
data
[
0
]
if
img
is
None
:
return
None
,
0
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
img
=
img
.
copy
()
preds
,
elapse
=
0
,
1
starttime
=
time
.
time
()
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
run
()
# outputs = []
# for output_tensor in self.output_tensors:
# output = output_tensor.copy_to_cpu()
# outputs.append(output)
np_score_list
,
np_boxes_list
=
[],
[]
output_names
=
self
.
predictor
.
get_output_names
()
num_outs
=
int
(
len
(
output_names
)
/
2
)
for
out_idx
in
range
(
num_outs
):
np_score_list
.
append
(
self
.
predictor
.
get_output_handle
(
output_names
[
out_idx
])
.
copy_to_cpu
())
np_boxes_list
.
append
(
self
.
predictor
.
get_output_handle
(
output_names
[
out_idx
+
num_outs
]).
copy_to_cpu
())
# result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
postprocessor
=
PicoDetPostProcess
(
(
800
,
608
),
[[
800.
,
608.
]],
np
.
array
([[
1.010101
,
0.99346405
]]),
strides
=
[
8
,
16
,
32
,
64
],
nms_threshold
=
0.5
)
np_boxes
,
np_boxes_num
=
postprocessor
(
np_score_list
,
np_boxes_list
)
result
=
dict
(
boxes
=
np_boxes
,
boxes_num
=
np_boxes_num
)
# print(result)
im_bboxes_num
=
result
[
'boxes_num'
][
0
]
# print('im_bboxes_num:',im_bboxes_num)
bboxs
=
result
[
'boxes'
][
0
:
0
+
im_bboxes_num
,
:]
threshold
=
0.5
expect_boxes
=
(
np_boxes
[:,
1
]
>
threshold
)
&
(
np_boxes
[:,
0
]
>
-
1
)
np_boxes
=
np_boxes
[
expect_boxes
,
:]
preds
=
[]
id2label
=
{
1
:
'text'
,
2
:
'title'
,
3
:
'list'
,
4
:
'table'
,
5
:
'figure'
}
for
dt
in
np_boxes
:
clsid
,
bbox
,
score
=
int
(
dt
[
0
]),
dt
[
2
:],
dt
[
1
]
label
=
id2label
[
clsid
+
1
]
result_di
=
{
'bbox'
:
bbox
,
'label'
:
label
}
preds
.
append
(
result_di
)
# print('result_di',result_di)
# print('clsid, bbox, score:',clsid, bbox, score)
elapse
=
time
.
time
()
-
starttime
return
preds
,
elapse
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
layout_predictor
=
LayoutPredictor
(
args
)
count
=
0
total_time
=
0
for
image_file
in
image_file_list
:
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
layout_res
,
elapse
=
layout_predictor
(
img
)
logger
.
info
(
"result: {}"
.
format
(
layout_res
))
if
count
>
0
:
total_time
+=
elapse
count
+=
1
logger
.
info
(
"Predict time of {}: {}"
.
format
(
image_file
,
elapse
))
if
__name__
==
"__main__"
:
main
(
parse_args
())
ppstructure/predict_system.py
浏览文件 @
ddaa2c25
...
@@ -18,7 +18,7 @@ import subprocess
...
@@ -18,7 +18,7 @@ import subprocess
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..
'
)))
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../
'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
cv2
import
cv2
...
@@ -32,6 +32,7 @@ from attrdict import AttrDict
...
@@ -32,6 +32,7 @@ from attrdict import AttrDict
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.logging
import
get_logger
from
tools.infer.predict_system
import
TextSystem
from
tools.infer.predict_system
import
TextSystem
from
ppstructure.layout.predict_layout
import
LayoutPredictor
from
ppstructure.table.predict_table
import
TableSystem
,
to_excel
from
ppstructure.table.predict_table
import
TableSystem
,
to_excel
from
ppstructure.utility
import
parse_args
,
draw_structure_result
from
ppstructure.utility
import
parse_args
,
draw_structure_result
from
ppstructure.recovery.recovery_to_doc
import
convert_info_docx
from
ppstructure.recovery.recovery_to_doc
import
convert_info_docx
...
@@ -51,28 +52,14 @@ class StructureSystem(object):
...
@@ -51,28 +52,14 @@ class StructureSystem(object):
"When args.layout is false, args.ocr is automatically set to false"
"When args.layout is false, args.ocr is automatically set to false"
)
)
args
.
drop_score
=
0
args
.
drop_score
=
0
# init layout and ocr model
# init model
self
.
layout_predictor
=
None
self
.
text_system
=
None
self
.
text_system
=
None
self
.
table_system
=
None
if
args
.
layout
:
if
args
.
layout
:
import
layoutparser
as
lp
self
.
layout_predictor
=
LayoutPredictor
(
args
)
config_path
=
None
model_path
=
None
if
os
.
path
.
isdir
(
args
.
layout_path_model
):
model_path
=
args
.
layout_path_model
else
:
config_path
=
args
.
layout_path_model
self
.
table_layout
=
lp
.
PaddleDetectionLayoutModel
(
config_path
=
config_path
,
model_path
=
model_path
,
label_map
=
args
.
layout_label_map
,
threshold
=
0.5
,
enable_mkldnn
=
args
.
enable_mkldnn
,
enforce_cpu
=
not
args
.
use_gpu
,
thread_num
=
args
.
cpu_threads
)
if
args
.
ocr
:
if
args
.
ocr
:
self
.
text_system
=
TextSystem
(
args
)
self
.
text_system
=
TextSystem
(
args
)
else
:
self
.
table_layout
=
None
if
args
.
table
:
if
args
.
table
:
if
self
.
text_system
is
not
None
:
if
self
.
text_system
is
not
None
:
self
.
table_system
=
TableSystem
(
self
.
table_system
=
TableSystem
(
...
@@ -80,38 +67,59 @@ class StructureSystem(object):
...
@@ -80,38 +67,59 @@ class StructureSystem(object):
self
.
text_system
.
text_recognizer
)
self
.
text_system
.
text_recognizer
)
else
:
else
:
self
.
table_system
=
TableSystem
(
args
)
self
.
table_system
=
TableSystem
(
args
)
else
:
self
.
table_system
=
None
elif
self
.
mode
==
'vqa'
:
elif
self
.
mode
==
'vqa'
:
raise
NotImplementedError
raise
NotImplementedError
def
__call__
(
self
,
img
,
return_ocr_result_in_table
=
False
):
def
__call__
(
self
,
img
,
return_ocr_result_in_table
=
False
):
time_dict
=
{
'layout'
:
0
,
'table'
:
0
,
'table_match'
:
0
,
'det'
:
0
,
'rec'
:
0
,
'vqa'
:
0
,
'all'
:
0
}
start
=
time
.
time
()
if
self
.
mode
==
'structure'
:
if
self
.
mode
==
'structure'
:
ori_im
=
img
.
copy
()
ori_im
=
img
.
copy
()
if
self
.
table_layout
is
not
None
:
if
self
.
layout_predictor
is
not
None
:
layout_res
=
self
.
table_layout
.
detect
(
img
[...,
::
-
1
])
layout_res
,
elapse
=
self
.
layout_predictor
(
img
)
time_dict
[
'layout'
]
+=
elapse
else
:
else
:
h
,
w
=
ori_im
.
shape
[:
2
]
h
,
w
=
ori_im
.
shape
[:
2
]
layout_res
=
[
AttrDict
(
coordinates
=
[
0
,
0
,
w
,
h
],
type
=
'T
able'
)]
layout_res
=
[
dict
(
bbox
=
None
,
label
=
't
able'
)]
res_list
=
[]
res_list
=
[]
for
region
in
layout_res
:
for
region
in
layout_res
:
res
=
''
res
=
''
x1
,
y1
,
x2
,
y2
=
region
.
coordinates
if
region
[
'bbox'
]
is
not
None
:
x1
,
y1
,
x2
,
y2
=
region
[
'bbox'
]
x1
,
y1
,
x2
,
y2
=
int
(
x1
),
int
(
y1
),
int
(
x2
),
int
(
y2
)
x1
,
y1
,
x2
,
y2
=
int
(
x1
),
int
(
y1
),
int
(
x2
),
int
(
y2
)
roi_img
=
ori_im
[
y1
:
y2
,
x1
:
x2
,
:]
roi_img
=
ori_im
[
y1
:
y2
,
x1
:
x2
,
:]
if
region
.
type
==
'Table'
:
else
:
x1
,
y1
,
x2
,
y2
=
0
,
0
,
w
,
h
roi_img
=
ori_im
if
region
[
'label'
]
==
'table'
:
if
self
.
table_system
is
not
None
:
if
self
.
table_system
is
not
None
:
res
=
self
.
table_system
(
roi_img
,
res
,
table_time_dict
=
self
.
table_system
(
return_ocr_result_in_table
)
roi_img
,
return_ocr_result_in_table
)
time_dict
[
'table'
]
+=
table_time_dict
[
'table'
]
time_dict
[
'table_match'
]
+=
table_time_dict
[
'match'
]
time_dict
[
'det'
]
+=
table_time_dict
[
'det'
]
time_dict
[
'rec'
]
+=
table_time_dict
[
'rec'
]
else
:
else
:
if
self
.
text_system
is
not
None
:
if
self
.
text_system
is
not
None
:
if
args
.
recovery
:
if
args
.
recovery
:
wht_im
=
np
.
ones
(
ori_im
.
shape
,
dtype
=
ori_im
.
dtype
)
wht_im
=
np
.
ones
(
ori_im
.
shape
,
dtype
=
ori_im
.
dtype
)
wht_im
[
y1
:
y2
,
x1
:
x2
,
:]
=
roi_img
wht_im
[
y1
:
y2
,
x1
:
x2
,
:]
=
roi_img
filter_boxes
,
filter_rec_res
=
self
.
text_system
(
wht_im
)
filter_boxes
,
filter_rec_res
,
ocr_time_dict
=
self
.
text_system
(
wht_im
)
else
:
else
:
filter_boxes
,
filter_rec_res
=
self
.
text_system
(
roi_img
)
filter_boxes
,
filter_rec_res
,
ocr_time_dict
=
self
.
text_system
(
roi_img
)
time_dict
[
'det'
]
+=
ocr_time_dict
[
'det'
]
time_dict
[
'rec'
]
+=
ocr_time_dict
[
'rec'
]
# remove style char
# remove style char
style_token
=
[
style_token
=
[
'<strike>'
,
'<strike>'
,
'<sup>'
,
'</sub>'
,
'<b>'
,
'<strike>'
,
'<strike>'
,
'<sup>'
,
'</sub>'
,
'<b>'
,
...
@@ -133,15 +141,17 @@ class StructureSystem(object):
...
@@ -133,15 +141,17 @@ class StructureSystem(object):
'text_region'
:
box
.
tolist
()
'text_region'
:
box
.
tolist
()
})
})
res_list
.
append
({
res_list
.
append
({
'type'
:
region
.
type
,
'type'
:
region
[
'label'
].
lower
()
,
'bbox'
:
[
x1
,
y1
,
x2
,
y2
],
'bbox'
:
[
x1
,
y1
,
x2
,
y2
],
'img'
:
roi_img
,
'img'
:
roi_img
,
'res'
:
res
'res'
:
res
})
})
return
res_list
end
=
time
.
time
()
time_dict
[
'all'
]
=
end
-
start
return
res_list
,
time_dict
elif
self
.
mode
==
'vqa'
:
elif
self
.
mode
==
'vqa'
:
raise
NotImplementedError
raise
NotImplementedError
return
None
return
None
,
None
def
save_structure_res
(
res
,
save_folder
,
img_name
):
def
save_structure_res
(
res
,
save_folder
,
img_name
):
...
@@ -156,12 +166,12 @@ def save_structure_res(res, save_folder, img_name):
...
@@ -156,12 +166,12 @@ def save_structure_res(res, save_folder, img_name):
roi_img
=
region
.
pop
(
'img'
)
roi_img
=
region
.
pop
(
'img'
)
f
.
write
(
'{}
\n
'
.
format
(
json
.
dumps
(
region
)))
f
.
write
(
'{}
\n
'
.
format
(
json
.
dumps
(
region
)))
if
region
[
'type'
]
==
'
T
able'
and
len
(
region
[
if
region
[
'type'
]
==
'
t
able'
and
len
(
region
[
'res'
])
>
0
and
'html'
in
region
[
'res'
]:
'res'
])
>
0
and
'html'
in
region
[
'res'
]:
excel_path
=
os
.
path
.
join
(
excel_save_folder
,
excel_path
=
os
.
path
.
join
(
excel_save_folder
,
'{}.xlsx'
.
format
(
region
[
'bbox'
]))
'{}.xlsx'
.
format
(
region
[
'bbox'
]))
to_excel
(
region
[
'res'
][
'html'
],
excel_path
)
to_excel
(
region
[
'res'
][
'html'
],
excel_path
)
elif
region
[
'type'
]
==
'
F
igure'
:
elif
region
[
'type'
]
==
'
f
igure'
:
img_path
=
os
.
path
.
join
(
excel_save_folder
,
img_path
=
os
.
path
.
join
(
excel_save_folder
,
'{}.jpg'
.
format
(
region
[
'bbox'
]))
'{}.jpg'
.
format
(
region
[
'bbox'
]))
cv2
.
imwrite
(
img_path
,
roi_img
)
cv2
.
imwrite
(
img_path
,
roi_img
)
...
@@ -188,7 +198,7 @@ def main(args):
...
@@ -188,7 +198,7 @@ def main(args):
logger
.
error
(
"error in loading image:{}"
.
format
(
image_file
))
logger
.
error
(
"error in loading image:{}"
.
format
(
image_file
))
continue
continue
starttime
=
time
.
time
()
starttime
=
time
.
time
()
res
=
structure_sys
(
img
)
res
,
time_dict
=
structure_sys
(
img
)
if
structure_sys
.
mode
==
'structure'
:
if
structure_sys
.
mode
==
'structure'
:
save_structure_res
(
res
,
save_folder
,
img_name
)
save_structure_res
(
res
,
save_folder
,
img_name
)
...
...
ppstructure/table/eval_table.py
浏览文件 @
ddaa2c25
...
@@ -13,12 +13,14 @@
...
@@ -13,12 +13,14 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
sys
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
import
cv2
import
cv2
import
json
import
pickle
import
paddle
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
ppstructure.table.table_metric
import
TEDS
from
ppstructure.table.table_metric
import
TEDS
from
ppstructure.table.predict_table
import
TableSystem
from
ppstructure.table.predict_table
import
TableSystem
...
@@ -33,40 +35,74 @@ def parse_args():
...
@@ -33,40 +35,74 @@ def parse_args():
parser
.
add_argument
(
"--gt_path"
,
type
=
str
)
parser
.
add_argument
(
"--gt_path"
,
type
=
str
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
main
(
gt_path
,
img_root
,
args
):
teds
=
TEDS
(
n_jobs
=
16
)
def
load_txt
(
txt_path
):
pred_html_dict
=
{}
if
not
os
.
path
.
exists
(
txt_path
):
return
pred_html_dict
with
open
(
txt_path
,
encoding
=
'utf-8'
)
as
f
:
lines
=
f
.
readlines
()
for
line
in
lines
:
line
=
line
.
strip
().
split
(
'
\t
'
)
img_name
,
pred_html
=
line
pred_html_dict
[
img_name
]
=
pred_html
return
pred_html_dict
def
load_result
(
path
):
data
=
{}
if
os
.
path
.
exists
(
path
):
data
=
pickle
.
load
(
open
(
path
,
'rb'
))
return
data
def
save_result
(
path
,
data
):
old_data
=
load_result
(
path
)
old_data
.
update
(
data
)
with
open
(
path
,
'wb'
)
as
f
:
pickle
.
dump
(
old_data
,
f
)
def
main
(
gt_path
,
img_root
,
args
):
os
.
makedirs
(
args
.
output
,
exist_ok
=
True
)
# init TableSystem
text_sys
=
TableSystem
(
args
)
text_sys
=
TableSystem
(
args
)
jsons_gt
=
json
.
load
(
open
(
gt_path
))
# gt
# load gt and preds html result
gt_html_dict
=
load_txt
(
gt_path
)
ocr_result
=
load_result
(
os
.
path
.
join
(
args
.
output
,
'ocr.pickle'
))
structure_result
=
load_result
(
os
.
path
.
join
(
args
.
output
,
'structure.pickle'
))
pred_htmls
=
[]
pred_htmls
=
[]
gt_htmls
=
[]
gt_htmls
=
[]
for
img_name
in
tqdm
(
jsons_gt
):
for
img_name
,
gt_html
in
tqdm
(
gt_html_dict
.
items
()):
# read image
img
=
cv2
.
imread
(
os
.
path
.
join
(
img_root
,
img_name
))
img
=
cv2
.
imread
(
os
.
path
.
join
(
img_root
,
img_name
))
# run ocr and save result
pred_html
=
text_sys
(
img
)
if
img_name
not
in
ocr_result
:
pred_htmls
.
append
(
pred_html
)
dt_boxes
,
rec_res
,
_
,
_
=
text_sys
.
_ocr
(
img
)
ocr_result
[
img_name
]
=
[
dt_boxes
,
rec_res
]
save_result
(
os
.
path
.
join
(
args
.
output
,
'ocr.pickle'
),
ocr_result
)
# run structure and save result
if
img_name
not
in
structure_result
:
structure_res
,
_
=
text_sys
.
_structure
(
img
)
structure_result
[
img_name
]
=
structure_res
save_result
(
os
.
path
.
join
(
args
.
output
,
'structure.pickle'
),
structure_result
)
dt_boxes
,
rec_res
=
ocr_result
[
img_name
]
structure_res
=
structure_result
[
img_name
]
# match ocr and structure
pred_html
=
text_sys
.
match
(
structure_res
,
dt_boxes
,
rec_res
)
gt_structures
,
gt_bboxes
,
gt_contents
=
jsons_gt
[
img_name
]
pred_htmls
.
append
(
pred_html
)
gt_html
,
gt
=
get_gt_html
(
gt_structures
,
gt_contents
)
gt_htmls
.
append
(
gt_html
)
gt_htmls
.
append
(
gt_html
)
scores
=
teds
.
batch_evaluate_html
(
gt_htmls
,
pred_htmls
)
logger
.
info
(
'teds:'
,
sum
(
scores
)
/
len
(
scores
))
def
get_gt_html
(
gt_structures
,
gt_contents
):
# compute teds
end_html
=
[]
teds
=
TEDS
(
n_jobs
=
16
)
td_index
=
0
scores
=
teds
.
batch_evaluate_html
(
gt_htmls
,
pred_htmls
)
for
tag
in
gt_structures
:
logger
.
info
(
'teds: {}'
.
format
(
sum
(
scores
)
/
len
(
scores
)))
if
'</td>'
in
tag
:
if
gt_contents
[
td_index
]
!=
[]:
end_html
.
extend
(
gt_contents
[
td_index
])
end_html
.
append
(
tag
)
td_index
+=
1
else
:
end_html
.
append
(
tag
)
return
''
.
join
(
end_html
),
end_html
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
args
=
parse_args
()
args
=
parse_args
()
main
(
args
.
gt_path
,
args
.
image_dir
,
args
)
main
(
args
.
gt_path
,
args
.
image_dir
,
args
)
ppstructure/table/matcher.py
浏览文件 @
ddaa2c25
import
json
import
json
from
ppstructure.table.table_master_match
import
deal_eb_token
,
deal_bb
def
distance
(
box_1
,
box_2
):
def
distance
(
box_1
,
box_2
):
x1
,
y1
,
x2
,
y2
=
box_1
x1
,
y1
,
x2
,
y2
=
box_1
x3
,
y3
,
x4
,
y4
=
box_2
x3
,
y3
,
x4
,
y4
=
box_2
dis
=
abs
(
x3
-
x1
)
+
abs
(
y3
-
y1
)
+
abs
(
x4
-
x2
)
+
abs
(
y4
-
y2
)
dis
=
abs
(
x3
-
x1
)
+
abs
(
y3
-
y1
)
+
abs
(
x4
-
x2
)
+
abs
(
y4
-
y2
)
dis_2
=
abs
(
x3
-
x1
)
+
abs
(
y3
-
y1
)
dis_2
=
abs
(
x3
-
x1
)
+
abs
(
y3
-
y1
)
dis_3
=
abs
(
x4
-
x2
)
+
abs
(
y4
-
y2
)
dis_3
=
abs
(
x4
-
x2
)
+
abs
(
y4
-
y2
)
return
dis
+
min
(
dis_2
,
dis_3
)
return
dis
+
min
(
dis_2
,
dis_3
)
def
compute_iou
(
rec1
,
rec2
):
def
compute_iou
(
rec1
,
rec2
):
"""
"""
computing IoU
computing IoU
...
@@ -33,8 +37,7 @@ def compute_iou(rec1, rec2):
...
@@ -33,8 +37,7 @@ def compute_iou(rec1, rec2):
return
0.0
return
0.0
else
:
else
:
intersect
=
(
right_line
-
left_line
)
*
(
bottom_line
-
top_line
)
intersect
=
(
right_line
-
left_line
)
*
(
bottom_line
-
top_line
)
return
(
intersect
/
(
sum_area
-
intersect
))
*
1.0
return
(
intersect
/
(
sum_area
-
intersect
))
*
1.0
def
matcher_merge
(
ocr_bboxes
,
pred_bboxes
):
def
matcher_merge
(
ocr_bboxes
,
pred_bboxes
):
...
@@ -45,15 +48,18 @@ def matcher_merge(ocr_bboxes, pred_bboxes):
...
@@ -45,15 +48,18 @@ def matcher_merge(ocr_bboxes, pred_bboxes):
distances
=
[]
distances
=
[]
for
j
,
pred_box
in
enumerate
(
pred_bboxes
):
for
j
,
pred_box
in
enumerate
(
pred_bboxes
):
# compute l1 distence and IOU between two boxes
# compute l1 distence and IOU between two boxes
distances
.
append
((
distance
(
gt_box
,
pred_box
),
1.
-
compute_iou
(
gt_box
,
pred_box
)))
distances
.
append
((
distance
(
gt_box
,
pred_box
),
1.
-
compute_iou
(
gt_box
,
pred_box
)))
sorted_distances
=
distances
.
copy
()
sorted_distances
=
distances
.
copy
()
# select nearest cell
# select nearest cell
sorted_distances
=
sorted
(
sorted_distances
,
key
=
lambda
item
:
(
item
[
1
],
item
[
0
]))
sorted_distances
=
sorted
(
sorted_distances
,
key
=
lambda
item
:
(
item
[
1
],
item
[
0
]))
if
distances
.
index
(
sorted_distances
[
0
])
not
in
matched
.
keys
():
if
distances
.
index
(
sorted_distances
[
0
])
not
in
matched
.
keys
():
matched
[
distances
.
index
(
sorted_distances
[
0
])]
=
[
i
]
matched
[
distances
.
index
(
sorted_distances
[
0
])]
=
[
i
]
else
:
else
:
matched
[
distances
.
index
(
sorted_distances
[
0
])].
append
(
i
)
matched
[
distances
.
index
(
sorted_distances
[
0
])].
append
(
i
)
return
matched
#, sum(ious) / len(ious)
return
matched
#, sum(ious) / len(ious)
def
complex_num
(
pred_bboxes
):
def
complex_num
(
pred_bboxes
):
complex_nums
=
[]
complex_nums
=
[]
...
@@ -67,6 +73,7 @@ def complex_num(pred_bboxes):
...
@@ -67,6 +73,7 @@ def complex_num(pred_bboxes):
complex_nums
.
append
(
temp_ious
[
distances
.
index
(
min
(
distances
))])
complex_nums
.
append
(
temp_ious
[
distances
.
index
(
min
(
distances
))])
return
sum
(
complex_nums
)
/
len
(
complex_nums
)
return
sum
(
complex_nums
)
/
len
(
complex_nums
)
def
get_rows
(
pred_bboxes
):
def
get_rows
(
pred_bboxes
):
pre_bbox
=
pred_bboxes
[
0
]
pre_bbox
=
pred_bboxes
[
0
]
res
=
[]
res
=
[]
...
@@ -81,6 +88,8 @@ def get_rows(pred_bboxes):
...
@@ -81,6 +88,8 @@ def get_rows(pred_bboxes):
for
i
in
range
(
step
):
for
i
in
range
(
step
):
pred_bboxes
.
pop
(
0
)
pred_bboxes
.
pop
(
0
)
return
res
,
pred_bboxes
return
res
,
pred_bboxes
def
refine_rows
(
pred_bboxes
):
# 微调整行的框,使在一条水平线上
def
refine_rows
(
pred_bboxes
):
# 微调整行的框,使在一条水平线上
ys_1
=
[]
ys_1
=
[]
ys_2
=
[]
ys_2
=
[]
...
@@ -96,11 +105,13 @@ def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
...
@@ -96,11 +105,13 @@ def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上
re_boxes
.
append
(
box
)
re_boxes
.
append
(
box
)
return
re_boxes
return
re_boxes
def
matcher_refine_row
(
gt_bboxes
,
pred_bboxes
):
def
matcher_refine_row
(
gt_bboxes
,
pred_bboxes
):
before_refine_pred_bboxes
=
pred_bboxes
.
copy
()
before_refine_pred_bboxes
=
pred_bboxes
.
copy
()
pred_bboxes
=
[]
pred_bboxes
=
[]
while
(
len
(
before_refine_pred_bboxes
)
!=
0
):
while
(
len
(
before_refine_pred_bboxes
)
!=
0
):
row_bboxes
,
before_refine_pred_bboxes
=
get_rows
(
before_refine_pred_bboxes
)
row_bboxes
,
before_refine_pred_bboxes
=
get_rows
(
before_refine_pred_bboxes
)
print
(
row_bboxes
)
print
(
row_bboxes
)
pred_bboxes
.
extend
(
refine_rows
(
row_bboxes
))
pred_bboxes
.
extend
(
refine_rows
(
row_bboxes
))
all_dis
=
[]
all_dis
=
[]
...
@@ -118,8 +129,7 @@ def matcher_refine_row(gt_bboxes, pred_bboxes):
...
@@ -118,8 +129,7 @@ def matcher_refine_row(gt_bboxes, pred_bboxes):
matched
[
distances
.
index
(
min
(
distances
))]
=
[
i
]
matched
[
distances
.
index
(
min
(
distances
))]
=
[
i
]
else
:
else
:
matched
[
distances
.
index
(
min
(
distances
))].
append
(
i
)
matched
[
distances
.
index
(
min
(
distances
))].
append
(
i
)
return
matched
#, sum(ious) / len(ious)
return
matched
#, sum(ious) / len(ious)
#先挑选出一行,再进行匹配
#先挑选出一行,再进行匹配
...
@@ -128,9 +138,9 @@ def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
...
@@ -128,9 +138,9 @@ def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
delete_gt_bboxes
=
gt_bboxes
.
copy
()
delete_gt_bboxes
=
gt_bboxes
.
copy
()
match_bboxes_ready
=
[]
match_bboxes_ready
=
[]
matched
=
{}
matched
=
{}
while
(
len
(
delete_gt_bboxes
)
!=
0
):
while
(
len
(
delete_gt_bboxes
)
!=
0
):
row_bboxes
,
delete_gt_bboxes
=
get_rows
(
delete_gt_bboxes
)
row_bboxes
,
delete_gt_bboxes
=
get_rows
(
delete_gt_bboxes
)
row_bboxes
=
sorted
(
row_bboxes
,
key
=
lambda
key
:
key
[
0
])
row_bboxes
=
sorted
(
row_bboxes
,
key
=
lambda
key
:
key
[
0
])
if
len
(
pred_bboxes_rows
)
>
0
:
if
len
(
pred_bboxes_rows
)
>
0
:
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
print
(
row_bboxes
)
print
(
row_bboxes
)
...
@@ -151,6 +161,7 @@ def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
...
@@ -151,6 +161,7 @@ def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
gt_box_index
+=
1
gt_box_index
+=
1
return
matched
return
matched
def
matcher_structure
(
gt_bboxes
,
pred_bboxes_rows
,
pred_bboxes
):
def
matcher_structure
(
gt_bboxes
,
pred_bboxes_rows
,
pred_bboxes
):
'''
'''
gt_bboxes: 排序后
gt_bboxes: 排序后
...
@@ -190,3 +201,137 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
...
@@ -190,3 +201,137 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
matched
[
index
].
append
(
i
)
matched
[
index
].
append
(
i
)
pre_bbox
=
gt_box
pre_bbox
=
gt_box
return
matched
return
matched
class
TableMatch
:
def
__init__
(
self
,
filter_ocr_result
=
False
,
use_master
=
False
):
self
.
filter_ocr_result
=
filter_ocr_result
self
.
use_master
=
use_master
def
__call__
(
self
,
structure_res
,
dt_boxes
,
rec_res
):
pred_structures
,
pred_bboxes
=
structure_res
if
self
.
filter_ocr_result
:
dt_boxes
,
rec_res
=
self
.
filter_ocr_result
(
pred_bboxes
,
dt_boxes
,
rec_res
)
matched_index
=
self
.
match_result
(
dt_boxes
,
pred_bboxes
)
if
self
.
use_master
:
pred_html
,
pred
=
self
.
get_pred_html_master
(
pred_structures
,
matched_index
,
rec_res
)
else
:
pred_html
,
pred
=
self
.
get_pred_html
(
pred_structures
,
matched_index
,
rec_res
)
return
pred_html
def
match_result
(
self
,
dt_boxes
,
pred_bboxes
):
matched
=
{}
for
i
,
gt_box
in
enumerate
(
dt_boxes
):
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
distances
=
[]
for
j
,
pred_box
in
enumerate
(
pred_bboxes
):
distances
.
append
((
distance
(
gt_box
,
pred_box
),
1.
-
compute_iou
(
gt_box
,
pred_box
)
))
# 获取两两cell之间的L1距离和 1- IOU
sorted_distances
=
distances
.
copy
()
# 根据距离和IOU挑选最"近"的cell
sorted_distances
=
sorted
(
sorted_distances
,
key
=
lambda
item
:
(
item
[
1
],
item
[
0
]))
if
distances
.
index
(
sorted_distances
[
0
])
not
in
matched
.
keys
():
matched
[
distances
.
index
(
sorted_distances
[
0
])]
=
[
i
]
else
:
matched
[
distances
.
index
(
sorted_distances
[
0
])].
append
(
i
)
return
matched
def
get_pred_html
(
self
,
pred_structures
,
matched_index
,
ocr_contents
):
end_html
=
[]
td_index
=
0
for
tag
in
pred_structures
:
if
'</td>'
in
tag
:
if
'<td></td>'
==
tag
:
end_html
.
extend
(
'<td>'
)
if
td_index
in
matched_index
.
keys
():
b_with
=
False
if
'<b>'
in
ocr_contents
[
matched_index
[
td_index
][
0
]]
and
len
(
matched_index
[
td_index
])
>
1
:
b_with
=
True
end_html
.
extend
(
'<b>'
)
for
i
,
td_index_index
in
enumerate
(
matched_index
[
td_index
]):
content
=
ocr_contents
[
td_index_index
][
0
]
if
len
(
matched_index
[
td_index
])
>
1
:
if
len
(
content
)
==
0
:
continue
if
content
[
0
]
==
' '
:
content
=
content
[
1
:]
if
'<b>'
in
content
:
content
=
content
[
3
:]
if
'</b>'
in
content
:
content
=
content
[:
-
4
]
if
len
(
content
)
==
0
:
continue
if
i
!=
len
(
matched_index
[
td_index
])
-
1
and
' '
!=
content
[
-
1
]:
content
+=
' '
end_html
.
extend
(
content
)
if
b_with
:
end_html
.
extend
(
'</b>'
)
if
'<td></td>'
==
tag
:
end_html
.
append
(
'</td>'
)
else
:
end_html
.
append
(
tag
)
td_index
+=
1
else
:
end_html
.
append
(
tag
)
return
''
.
join
(
end_html
),
end_html
def
get_pred_html_master
(
self
,
pred_structures
,
matched_index
,
ocr_contents
):
end_html
=
[]
td_index
=
0
for
token
in
pred_structures
:
if
'</td>'
in
token
:
txt
=
''
b_with
=
False
if
td_index
in
matched_index
.
keys
():
if
'<b>'
in
ocr_contents
[
matched_index
[
td_index
][
0
]]
and
len
(
matched_index
[
td_index
])
>
1
:
b_with
=
True
for
i
,
td_index_index
in
enumerate
(
matched_index
[
td_index
]):
content
=
ocr_contents
[
td_index_index
][
0
]
if
len
(
matched_index
[
td_index
])
>
1
:
if
len
(
content
)
==
0
:
continue
if
content
[
0
]
==
' '
:
content
=
content
[
1
:]
if
'<b>'
in
content
:
content
=
content
[
3
:]
if
'</b>'
in
content
:
content
=
content
[:
-
4
]
if
len
(
content
)
==
0
:
continue
if
i
!=
len
(
matched_index
[
td_index
])
-
1
and
' '
!=
content
[
-
1
]:
content
+=
' '
txt
+=
content
if
b_with
:
txt
=
'<b>{}</b>'
.
format
(
txt
)
if
'<td></td>'
==
token
:
token
=
'<td>{}</td>'
.
format
(
txt
)
else
:
token
=
'{}</td>'
.
format
(
txt
)
td_index
+=
1
token
=
deal_eb_token
(
token
)
end_html
.
append
(
token
)
html
=
''
.
join
(
end_html
)
html
=
deal_bb
(
html
)
return
html
,
end_html
def
filter_ocr_result
(
self
,
pred_bboxes
,
dt_boxes
,
rec_res
):
y1
=
pred_bboxes
[:,
1
::
2
].
min
()
new_dt_boxes
=
[]
new_rec_res
=
[]
for
box
,
rec
in
zip
(
dt_boxes
,
rec_res
):
if
np
.
max
(
box
[
1
::
2
])
<
y1
:
continue
new_dt_boxes
.
append
(
box
)
new_rec_res
.
append
(
rec
)
return
new_dt_boxes
,
new_rec_res
ppstructure/table/predict_structure.py
浏览文件 @
ddaa2c25
...
@@ -16,7 +16,7 @@ import sys
...
@@ -16,7 +16,7 @@ import sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
...
@@ -87,6 +87,7 @@ class TableStructurer(object):
...
@@ -87,6 +87,7 @@ class TableStructurer(object):
utility
.
create_predictor
(
args
,
'table'
,
logger
)
utility
.
create_predictor
(
args
,
'table'
,
logger
)
def
__call__
(
self
,
img
):
def
__call__
(
self
,
img
):
starttime
=
time
.
time
()
ori_im
=
img
.
copy
()
ori_im
=
img
.
copy
()
data
=
{
'image'
:
img
}
data
=
{
'image'
:
img
}
data
=
transform
(
data
,
self
.
preprocess_op
)
data
=
transform
(
data
,
self
.
preprocess_op
)
...
@@ -95,7 +96,6 @@ class TableStructurer(object):
...
@@ -95,7 +96,6 @@ class TableStructurer(object):
return
None
,
0
return
None
,
0
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
img
=
img
.
copy
()
img
=
img
.
copy
()
starttime
=
time
.
time
()
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
run
()
self
.
predictor
.
run
()
...
@@ -126,7 +126,6 @@ def main(args):
...
@@ -126,7 +126,6 @@ def main(args):
table_structurer
=
TableStructurer
(
args
)
table_structurer
=
TableStructurer
(
args
)
count
=
0
count
=
0
total_time
=
0
total_time
=
0
use_xywh
=
args
.
table_algorithm
in
[
'TableMaster'
]
os
.
makedirs
(
args
.
output
,
exist_ok
=
True
)
os
.
makedirs
(
args
.
output
,
exist_ok
=
True
)
with
open
(
with
open
(
os
.
path
.
join
(
args
.
output
,
'infer.txt'
),
mode
=
'w'
,
os
.
path
.
join
(
args
.
output
,
'infer.txt'
),
mode
=
'w'
,
...
@@ -146,7 +145,7 @@ def main(args):
...
@@ -146,7 +145,7 @@ def main(args):
f_w
.
write
(
"result: {}, {}
\n
"
.
format
(
structure_str_list
,
f_w
.
write
(
"result: {}, {}
\n
"
.
format
(
structure_str_list
,
bbox_list_str
))
bbox_list_str
))
img
=
draw_rectangle
(
image_file
,
bbox_list
,
use_xywh
)
img
=
draw_rectangle
(
image_file
,
bbox_list
)
img_save_path
=
os
.
path
.
join
(
args
.
output
,
img_save_path
=
os
.
path
.
join
(
args
.
output
,
os
.
path
.
basename
(
image_file
))
os
.
path
.
basename
(
image_file
))
cv2
.
imwrite
(
img_save_path
,
img
)
cv2
.
imwrite
(
img_save_path
,
img
)
...
...
ppstructure/table/predict_table.py
浏览文件 @
ddaa2c25
...
@@ -18,20 +18,23 @@ import subprocess
...
@@ -18,20 +18,23 @@ import subprocess
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
cv2
import
cv2
import
copy
import
copy
import
logging
import
numpy
as
np
import
numpy
as
np
import
time
import
time
import
tools.infer.predict_rec
as
predict_rec
import
tools.infer.predict_rec
as
predict_rec
import
tools.infer.predict_det
as
predict_det
import
tools.infer.predict_det
as
predict_det
import
tools.infer.utility
as
utility
import
tools.infer.utility
as
utility
from
tools.infer.predict_system
import
sorted_boxes
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.logging
import
get_logger
from
ppstructure.table.matcher
import
distance
,
compute_iou
from
ppstructure.table.matcher
import
TableMatch
from
ppstructure.table.table_master_match
import
TableMasterMatcher
from
ppstructure.utility
import
parse_args
from
ppstructure.utility
import
parse_args
import
ppstructure.table.predict_structure
as
predict_strture
import
ppstructure.table.predict_structure
as
predict_strture
...
@@ -55,11 +58,20 @@ def expand(pix, det_box, shape):
...
@@ -55,11 +58,20 @@ def expand(pix, det_box, shape):
class
TableSystem
(
object
):
class
TableSystem
(
object
):
def
__init__
(
self
,
args
,
text_detector
=
None
,
text_recognizer
=
None
):
def
__init__
(
self
,
args
,
text_detector
=
None
,
text_recognizer
=
None
):
if
not
args
.
show_log
:
logger
.
setLevel
(
logging
.
INFO
)
self
.
text_detector
=
predict_det
.
TextDetector
(
self
.
text_detector
=
predict_det
.
TextDetector
(
args
)
if
text_detector
is
None
else
text_detector
args
)
if
text_detector
is
None
else
text_detector
self
.
text_recognizer
=
predict_rec
.
TextRecognizer
(
self
.
text_recognizer
=
predict_rec
.
TextRecognizer
(
args
)
if
text_recognizer
is
None
else
text_recognizer
args
)
if
text_recognizer
is
None
else
text_recognizer
self
.
table_structurer
=
predict_strture
.
TableStructurer
(
args
)
self
.
table_structurer
=
predict_strture
.
TableStructurer
(
args
)
if
args
.
table_algorithm
in
[
'TableMaster'
]:
self
.
match
=
TableMasterMatcher
()
else
:
self
.
match
=
TableMatch
()
self
.
benchmark
=
args
.
benchmark
self
.
benchmark
=
args
.
benchmark
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
utility
.
create_predictor
(
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
utility
.
create_predictor
(
args
,
'table'
,
logger
)
args
,
'table'
,
logger
)
...
@@ -85,16 +97,47 @@ class TableSystem(object):
...
@@ -85,16 +97,47 @@ class TableSystem(object):
def
__call__
(
self
,
img
,
return_ocr_result_in_table
=
False
):
def
__call__
(
self
,
img
,
return_ocr_result_in_table
=
False
):
result
=
dict
()
result
=
dict
()
ori_im
=
img
.
copy
()
time_dict
=
{
'det'
:
0
,
'rec'
:
0
,
'table'
:
0
,
'all'
:
0
,
'match'
:
0
}
start
=
time
.
time
()
structure_res
,
elapse
=
self
.
_structure
(
copy
.
deepcopy
(
img
))
time_dict
[
'table'
]
=
elapse
dt_boxes
,
rec_res
,
det_elapse
,
rec_elapse
=
self
.
_ocr
(
copy
.
deepcopy
(
img
))
time_dict
[
'det'
]
=
det_elapse
time_dict
[
'rec'
]
=
rec_elapse
if
return_ocr_result_in_table
:
result
[
'boxes'
]
=
dt_boxes
#[x.tolist() for x in dt_boxes]
result
[
'rec_res'
]
=
rec_res
tic
=
time
.
time
()
pred_html
=
self
.
match
(
structure_res
,
dt_boxes
,
rec_res
)
toc
=
time
.
time
()
time_dict
[
'match'
]
=
toc
-
tic
# pred_html = self.match(1, 1, 1,img_name)
result
[
'html'
]
=
pred_html
if
self
.
benchmark
:
self
.
autolog
.
times
.
end
(
stamp
=
True
)
end
=
time
.
time
()
time_dict
[
'all'
]
=
end
-
start
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
return
result
,
time_dict
def
_structure
(
self
,
img
):
if
self
.
benchmark
:
if
self
.
benchmark
:
self
.
autolog
.
times
.
start
()
self
.
autolog
.
times
.
start
()
structure_res
,
elapse
=
self
.
table_structurer
(
copy
.
deepcopy
(
img
))
structure_res
,
elapse
=
self
.
table_structurer
(
copy
.
deepcopy
(
img
))
return
structure_res
,
elapse
def
_ocr
(
self
,
img
):
if
self
.
benchmark
:
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
self
.
autolog
.
times
.
stamp
()
dt_boxes
,
elapse
=
self
.
text_detector
(
copy
.
deepcopy
(
img
))
dt_boxes
,
det_
elapse
=
self
.
text_detector
(
copy
.
deepcopy
(
img
))
dt_boxes
=
sorted_boxes
(
dt_boxes
)
dt_boxes
=
sorted_boxes
(
dt_boxes
)
if
return_ocr_result_in_table
:
result
[
'boxes'
]
=
[
x
.
tolist
()
for
x
in
dt_boxes
]
r_boxes
=
[]
r_boxes
=
[]
for
box
in
dt_boxes
:
for
box
in
dt_boxes
:
x_min
=
box
[:,
0
].
min
()
-
1
x_min
=
box
[:,
0
].
min
()
-
1
...
@@ -105,125 +148,20 @@ class TableSystem(object):
...
@@ -105,125 +148,20 @@ class TableSystem(object):
r_boxes
.
append
(
box
)
r_boxes
.
append
(
box
)
dt_boxes
=
np
.
array
(
r_boxes
)
dt_boxes
=
np
.
array
(
r_boxes
)
logger
.
debug
(
"dt_boxes num : {}, elapse : {}"
.
format
(
logger
.
debug
(
"dt_boxes num : {}, elapse : {}"
.
format
(
len
(
dt_boxes
),
elapse
))
len
(
dt_boxes
),
det_
elapse
))
if
dt_boxes
is
None
:
if
dt_boxes
is
None
:
return
None
,
None
return
None
,
None
img_crop_list
=
[]
img_crop_list
=
[]
for
i
in
range
(
len
(
dt_boxes
)):
for
i
in
range
(
len
(
dt_boxes
)):
det_box
=
dt_boxes
[
i
]
det_box
=
dt_boxes
[
i
]
x0
,
y0
,
x1
,
y1
=
expand
(
2
,
det_box
,
ori_im
.
shape
)
x0
,
y0
,
x1
,
y1
=
expand
(
2
,
det_box
,
img
.
shape
)
text_rect
=
ori_im
[
int
(
y0
):
int
(
y1
),
int
(
x0
):
int
(
x1
),
:]
text_rect
=
img
[
int
(
y0
):
int
(
y1
),
int
(
x0
):
int
(
x1
),
:]
img_crop_list
.
append
(
text_rect
)
img_crop_list
.
append
(
text_rect
)
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
rec_res
,
rec_
elapse
=
self
.
text_recognizer
(
img_crop_list
)
logger
.
debug
(
"rec_res num : {}, elapse : {}"
.
format
(
logger
.
debug
(
"rec_res num : {}, elapse : {}"
.
format
(
len
(
rec_res
),
elapse
))
len
(
rec_res
),
rec_elapse
))
if
self
.
benchmark
:
return
dt_boxes
,
rec_res
,
det_elapse
,
rec_elapse
self
.
autolog
.
times
.
stamp
()
if
return_ocr_result_in_table
:
result
[
'rec_res'
]
=
rec_res
pred_html
,
pred
=
self
.
rebuild_table
(
structure_res
,
dt_boxes
,
rec_res
)
result
[
'html'
]
=
pred_html
if
self
.
benchmark
:
self
.
autolog
.
times
.
end
(
stamp
=
True
)
return
result
def
rebuild_table
(
self
,
structure_res
,
dt_boxes
,
rec_res
):
pred_structures
,
pred_bboxes
=
structure_res
dt_boxes
,
rec_res
=
self
.
filter_ocr_result
(
pred_bboxes
,
dt_boxes
,
rec_res
)
matched_index
=
self
.
match_result
(
dt_boxes
,
pred_bboxes
)
pred_html
,
pred
=
self
.
get_pred_html
(
pred_structures
,
matched_index
,
rec_res
)
return
pred_html
,
pred
def
filter_ocr_result
(
self
,
pred_bboxes
,
dt_boxes
,
rec_res
):
y1
=
pred_bboxes
[:,
1
::
2
].
min
()
new_dt_boxes
=
[]
new_rec_res
=
[]
for
box
,
rec
in
zip
(
dt_boxes
,
rec_res
):
if
np
.
max
(
box
[
1
::
2
])
<
y1
:
continue
new_dt_boxes
.
append
(
box
)
new_rec_res
.
append
(
rec
)
return
new_dt_boxes
,
new_rec_res
def
match_result
(
self
,
dt_boxes
,
pred_bboxes
):
matched
=
{}
for
i
,
gt_box
in
enumerate
(
dt_boxes
):
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
distances
=
[]
for
j
,
pred_box
in
enumerate
(
pred_bboxes
):
distances
.
append
((
distance
(
gt_box
,
pred_box
),
1.
-
compute_iou
(
gt_box
,
pred_box
)
))
# 获取两两cell之间的L1距离和 1- IOU
sorted_distances
=
distances
.
copy
()
# 根据距离和IOU挑选最"近"的cell
sorted_distances
=
sorted
(
sorted_distances
,
key
=
lambda
item
:
(
item
[
1
],
item
[
0
]))
if
distances
.
index
(
sorted_distances
[
0
])
not
in
matched
.
keys
():
matched
[
distances
.
index
(
sorted_distances
[
0
])]
=
[
i
]
else
:
matched
[
distances
.
index
(
sorted_distances
[
0
])].
append
(
i
)
return
matched
def
get_pred_html
(
self
,
pred_structures
,
matched_index
,
ocr_contents
):
end_html
=
[]
td_index
=
0
for
tag
in
pred_structures
:
if
'</td>'
in
tag
:
if
td_index
in
matched_index
.
keys
():
b_with
=
False
if
'<b>'
in
ocr_contents
[
matched_index
[
td_index
][
0
]]
and
len
(
matched_index
[
td_index
])
>
1
:
b_with
=
True
end_html
.
extend
(
'<b>'
)
for
i
,
td_index_index
in
enumerate
(
matched_index
[
td_index
]):
content
=
ocr_contents
[
td_index_index
][
0
]
if
len
(
matched_index
[
td_index
])
>
1
:
if
len
(
content
)
==
0
:
continue
if
content
[
0
]
==
' '
:
content
=
content
[
1
:]
if
'<b>'
in
content
:
content
=
content
[
3
:]
if
'</b>'
in
content
:
content
=
content
[:
-
4
]
if
len
(
content
)
==
0
:
continue
if
i
!=
len
(
matched_index
[
td_index
])
-
1
and
' '
!=
content
[
-
1
]:
content
+=
' '
end_html
.
extend
(
content
)
if
b_with
:
end_html
.
extend
(
'</b>'
)
end_html
.
append
(
tag
)
td_index
+=
1
else
:
end_html
.
append
(
tag
)
return
''
.
join
(
end_html
),
end_html
def
sorted_boxes
(
dt_boxes
):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes
=
dt_boxes
.
shape
[
0
]
sorted_boxes
=
sorted
(
dt_boxes
,
key
=
lambda
x
:
(
x
[
0
][
1
],
x
[
0
][
0
]))
_boxes
=
list
(
sorted_boxes
)
for
i
in
range
(
num_boxes
-
1
):
if
abs
(
_boxes
[
i
+
1
][
0
][
1
]
-
_boxes
[
i
][
0
][
1
])
<
10
and
\
(
_boxes
[
i
+
1
][
0
][
0
]
<
_boxes
[
i
][
0
][
0
]):
tmp
=
_boxes
[
i
]
_boxes
[
i
]
=
_boxes
[
i
+
1
]
_boxes
[
i
+
1
]
=
tmp
return
_boxes
def
to_excel
(
html_table
,
excel_path
):
def
to_excel
(
html_table
,
excel_path
):
...
@@ -249,7 +187,7 @@ def main(args):
...
@@ -249,7 +187,7 @@ def main(args):
logger
.
error
(
"error in loading image:{}"
.
format
(
image_file
))
logger
.
error
(
"error in loading image:{}"
.
format
(
image_file
))
continue
continue
starttime
=
time
.
time
()
starttime
=
time
.
time
()
pred_res
=
text_sys
(
img
)
pred_res
,
_
=
text_sys
(
img
)
pred_html
=
pred_res
[
'html'
]
pred_html
=
pred_res
[
'html'
]
logger
.
info
(
pred_html
)
logger
.
info
(
pred_html
)
to_excel
(
pred_html
,
excel_path
)
to_excel
(
pred_html
,
excel_path
)
...
...
ppstructure/table/table_master_match.py
0 → 100644
浏览文件 @
ddaa2c25
import
os
import
re
import
cv2
import
glob
import
copy
import
math
import
pickle
import
numpy
as
np
from
shapely.geometry
import
Polygon
,
MultiPoint
"""
Useful function in matching.
"""
def
remove_empty_bboxes
(
bboxes
):
"""
remove [0., 0., 0., 0.] in structure master bboxes.
len(bboxes.shape) must be 2.
:param bboxes:
:return:
"""
new_bboxes
=
[]
for
bbox
in
bboxes
:
if
sum
(
bbox
)
==
0.
:
continue
new_bboxes
.
append
(
bbox
)
return
np
.
array
(
new_bboxes
)
def
xywh2xyxy
(
bboxes
):
if
len
(
bboxes
.
shape
)
==
1
:
new_bboxes
=
np
.
empty_like
(
bboxes
)
new_bboxes
[
0
]
=
bboxes
[
0
]
-
bboxes
[
2
]
/
2
new_bboxes
[
1
]
=
bboxes
[
1
]
-
bboxes
[
3
]
/
2
new_bboxes
[
2
]
=
bboxes
[
0
]
+
bboxes
[
2
]
/
2
new_bboxes
[
3
]
=
bboxes
[
1
]
+
bboxes
[
3
]
/
2
return
new_bboxes
elif
len
(
bboxes
.
shape
)
==
2
:
new_bboxes
=
np
.
empty_like
(
bboxes
)
new_bboxes
[:,
0
]
=
bboxes
[:,
0
]
-
bboxes
[:,
2
]
/
2
new_bboxes
[:,
1
]
=
bboxes
[:,
1
]
-
bboxes
[:,
3
]
/
2
new_bboxes
[:,
2
]
=
bboxes
[:,
0
]
+
bboxes
[:,
2
]
/
2
new_bboxes
[:,
3
]
=
bboxes
[:,
1
]
+
bboxes
[:,
3
]
/
2
return
new_bboxes
else
:
raise
ValueError
def
xyxy2xywh
(
bboxes
):
if
len
(
bboxes
.
shape
)
==
1
:
new_bboxes
=
np
.
empty_like
(
bboxes
)
new_bboxes
[
0
]
=
bboxes
[
0
]
+
(
bboxes
[
2
]
-
bboxes
[
0
])
/
2
new_bboxes
[
1
]
=
bboxes
[
1
]
+
(
bboxes
[
3
]
-
bboxes
[
1
])
/
2
new_bboxes
[
2
]
=
bboxes
[
2
]
-
bboxes
[
0
]
new_bboxes
[
3
]
=
bboxes
[
3
]
-
bboxes
[
1
]
return
new_bboxes
elif
len
(
bboxes
.
shape
)
==
2
:
new_bboxes
=
np
.
empty_like
(
bboxes
)
new_bboxes
[:,
0
]
=
bboxes
[:,
0
]
+
(
bboxes
[:,
2
]
-
bboxes
[:,
0
])
/
2
new_bboxes
[:,
1
]
=
bboxes
[:,
1
]
+
(
bboxes
[:,
3
]
-
bboxes
[:,
1
])
/
2
new_bboxes
[:,
2
]
=
bboxes
[:,
2
]
-
bboxes
[:,
0
]
new_bboxes
[:,
3
]
=
bboxes
[:,
3
]
-
bboxes
[:,
1
]
return
new_bboxes
else
:
raise
ValueError
def
pickle_load
(
path
,
prefix
=
'end2end'
):
if
os
.
path
.
isfile
(
path
):
data
=
pickle
.
load
(
open
(
path
,
'rb'
))
elif
os
.
path
.
isdir
(
path
):
data
=
dict
()
search_path
=
os
.
path
.
join
(
path
,
'{}_*.pkl'
.
format
(
prefix
))
pkls
=
glob
.
glob
(
search_path
)
for
pkl
in
pkls
:
this_data
=
pickle
.
load
(
open
(
pkl
,
'rb'
))
data
.
update
(
this_data
)
else
:
raise
ValueError
return
data
def
convert_coord
(
xyxy
):
"""
Convert two points format to four points format.
:param xyxy:
:return:
"""
new_bbox
=
np
.
zeros
([
4
,
2
],
dtype
=
np
.
float32
)
new_bbox
[
0
,
0
],
new_bbox
[
0
,
1
]
=
xyxy
[
0
],
xyxy
[
1
]
new_bbox
[
1
,
0
],
new_bbox
[
1
,
1
]
=
xyxy
[
2
],
xyxy
[
1
]
new_bbox
[
2
,
0
],
new_bbox
[
2
,
1
]
=
xyxy
[
2
],
xyxy
[
3
]
new_bbox
[
3
,
0
],
new_bbox
[
3
,
1
]
=
xyxy
[
0
],
xyxy
[
3
]
return
new_bbox
def
cal_iou
(
bbox1
,
bbox2
):
bbox1_poly
=
Polygon
(
bbox1
).
convex_hull
bbox2_poly
=
Polygon
(
bbox2
).
convex_hull
union_poly
=
np
.
concatenate
((
bbox1
,
bbox2
))
if
not
bbox1_poly
.
intersects
(
bbox2_poly
):
iou
=
0
else
:
inter_area
=
bbox1_poly
.
intersection
(
bbox2_poly
).
area
union_area
=
MultiPoint
(
union_poly
).
convex_hull
.
area
if
union_area
==
0
:
iou
=
0
else
:
iou
=
float
(
inter_area
)
/
union_area
return
iou
def
cal_distance
(
p1
,
p2
):
delta_x
=
p1
[
0
]
-
p2
[
0
]
delta_y
=
p1
[
1
]
-
p2
[
1
]
d
=
math
.
sqrt
((
delta_x
**
2
)
+
(
delta_y
**
2
))
return
d
def
is_inside
(
center_point
,
corner_point
):
"""
Find if center_point inside the bbox(corner_point) or not.
:param center_point: center point (x, y)
:param corner_point: corner point ((x1,y1),(x2,y2))
:return:
"""
x_flag
=
False
y_flag
=
False
if
(
center_point
[
0
]
>=
corner_point
[
0
][
0
])
and
(
center_point
[
0
]
<=
corner_point
[
1
][
0
]):
x_flag
=
True
if
(
center_point
[
1
]
>=
corner_point
[
0
][
1
])
and
(
center_point
[
1
]
<=
corner_point
[
1
][
1
]):
y_flag
=
True
if
x_flag
and
y_flag
:
return
True
else
:
return
False
def
find_no_match
(
match_list
,
all_end2end_nums
,
type
=
'end2end'
):
"""
Find out no match end2end bbox in previous match list.
:param match_list: matching pairs.
:param all_end2end_nums: numbers of end2end_xywh
:param type: 'end2end' corresponding to idx 0, 'master' corresponding to idx 1.
:return: no match pse bbox index list
"""
if
type
==
'end2end'
:
idx
=
0
elif
type
==
'master'
:
idx
=
1
else
:
raise
ValueError
no_match_indexs
=
[]
# m[0] is end2end index m[1] is master index
matched_bbox_indexs
=
[
m
[
idx
]
for
m
in
match_list
]
for
n
in
range
(
all_end2end_nums
):
if
n
not
in
matched_bbox_indexs
:
no_match_indexs
.
append
(
n
)
return
no_match_indexs
def
is_abs_lower_than_threshold
(
this_bbox
,
target_bbox
,
threshold
=
3
):
# only consider y axis, for grouping in row.
delta
=
abs
(
this_bbox
[
1
]
-
target_bbox
[
1
])
if
delta
<
threshold
:
return
True
else
:
return
False
def
sort_line_bbox
(
g
,
bg
):
"""
Sorted the bbox in the same line(group)
compare coord 'x' value, where 'y' value is closed in the same group.
:param g: index in the same group
:param bg: bbox in the same group
:return:
"""
xs
=
[
bg_item
[
0
]
for
bg_item
in
bg
]
xs_sorted
=
sorted
(
xs
)
g_sorted
=
[
None
]
*
len
(
xs_sorted
)
bg_sorted
=
[
None
]
*
len
(
xs_sorted
)
for
g_item
,
bg_item
in
zip
(
g
,
bg
):
idx
=
xs_sorted
.
index
(
bg_item
[
0
])
bg_sorted
[
idx
]
=
bg_item
g_sorted
[
idx
]
=
g_item
return
g_sorted
,
bg_sorted
def
flatten
(
sorted_groups
,
sorted_bbox_groups
):
idxs
=
[]
bboxes
=
[]
for
group
,
bbox_group
in
zip
(
sorted_groups
,
sorted_bbox_groups
):
for
g
,
bg
in
zip
(
group
,
bbox_group
):
idxs
.
append
(
g
)
bboxes
.
append
(
bg
)
return
idxs
,
bboxes
def
sort_bbox
(
end2end_xywh_bboxes
,
no_match_end2end_indexes
):
"""
This function will group the render end2end bboxes in row.
:param end2end_xywh_bboxes:
:param no_match_end2end_indexes:
:return:
"""
groups
=
[]
bbox_groups
=
[]
for
index
,
end2end_xywh_bbox
in
zip
(
no_match_end2end_indexes
,
end2end_xywh_bboxes
):
this_bbox
=
end2end_xywh_bbox
if
len
(
groups
)
==
0
:
groups
.
append
([
index
])
bbox_groups
.
append
([
this_bbox
])
else
:
flag
=
False
for
g
,
bg
in
zip
(
groups
,
bbox_groups
):
# this_bbox is belong to bg's row or not
if
is_abs_lower_than_threshold
(
this_bbox
,
bg
[
0
]):
g
.
append
(
index
)
bg
.
append
(
this_bbox
)
flag
=
True
break
if
not
flag
:
# this_bbox is not belong to bg's row, create a row.
groups
.
append
([
index
])
bbox_groups
.
append
([
this_bbox
])
# sorted bboxes in a group
tmp_groups
,
tmp_bbox_groups
=
[],
[]
for
g
,
bg
in
zip
(
groups
,
bbox_groups
):
g_sorted
,
bg_sorted
=
sort_line_bbox
(
g
,
bg
)
tmp_groups
.
append
(
g_sorted
)
tmp_bbox_groups
.
append
(
bg_sorted
)
# sorted groups, sort by coord y's value.
sorted_groups
=
[
None
]
*
len
(
tmp_groups
)
sorted_bbox_groups
=
[
None
]
*
len
(
tmp_bbox_groups
)
ys
=
[
bg
[
0
][
1
]
for
bg
in
tmp_bbox_groups
]
sorted_ys
=
sorted
(
ys
)
for
g
,
bg
in
zip
(
tmp_groups
,
tmp_bbox_groups
):
idx
=
sorted_ys
.
index
(
bg
[
0
][
1
])
sorted_groups
[
idx
]
=
g
sorted_bbox_groups
[
idx
]
=
bg
# flatten, get final result
end2end_sorted_idx_list
,
end2end_sorted_bbox_list
\
=
flatten
(
sorted_groups
,
sorted_bbox_groups
)
# check sorted
#img = cv2.imread('/data_0/yejiaquan/data/TableRecognization/singleVal/PMC3286376_004_00.png')
#img = drawBboxAfterSorted(img, sorted_groups, sorted_bbox_groups)
return
end2end_sorted_idx_list
,
end2end_sorted_bbox_list
,
sorted_groups
,
sorted_bbox_groups
def
get_bboxes_list
(
end2end_result
,
structure_master_result
):
"""
This function is use to convert end2end results and structure master results to
List of xyxy bbox format and List of xywh bbox format
:param end2end_result: bbox's format is xyxy
:param structure_master_result: bbox's format is xywh
:return: 4 kind list of bbox ()
"""
# end2end
end2end_xyxy_list
=
[]
end2end_xywh_list
=
[]
for
end2end_item
in
end2end_result
:
src_bbox
=
end2end_item
[
'bbox'
]
end2end_xyxy_list
.
append
(
src_bbox
)
xywh_bbox
=
xyxy2xywh
(
src_bbox
)
end2end_xywh_list
.
append
(
xywh_bbox
)
end2end_xyxy_bboxes
=
np
.
array
(
end2end_xyxy_list
)
end2end_xywh_bboxes
=
np
.
array
(
end2end_xywh_list
)
# structure master
src_bboxes
=
structure_master_result
[
'bbox'
]
src_bboxes
=
remove_empty_bboxes
(
src_bboxes
)
# structure_master_xywh_bboxes = src_bboxes
# xyxy_bboxes = xywh2xyxy(src_bboxes)
# structure_master_xyxy_bboxes = xyxy_bboxes
structure_master_xyxy_bboxes
=
src_bboxes
xywh_bbox
=
xyxy2xywh
(
src_bboxes
)
structure_master_xywh_bboxes
=
xywh_bbox
return
end2end_xyxy_bboxes
,
end2end_xywh_bboxes
,
structure_master_xywh_bboxes
,
structure_master_xyxy_bboxes
def
center_rule_match
(
end2end_xywh_bboxes
,
structure_master_xyxy_bboxes
):
"""
Judge end2end Bbox's center point is inside structure master Bbox or not,
if end2end Bbox's center is in structure master Bbox, get matching pair.
:param end2end_xywh_bboxes:
:param structure_master_xyxy_bboxes:
:return: match pairs list, e.g. [[0,1], [1,2], ...]
"""
match_pairs_list
=
[]
for
i
,
end2end_xywh
in
enumerate
(
end2end_xywh_bboxes
):
for
j
,
master_xyxy
in
enumerate
(
structure_master_xyxy_bboxes
):
x_end2end
,
y_end2end
=
end2end_xywh
[
0
],
end2end_xywh
[
1
]
x_master1
,
y_master1
,
x_master2
,
y_master2
\
=
master_xyxy
[
0
],
master_xyxy
[
1
],
master_xyxy
[
2
],
master_xyxy
[
3
]
center_point_end2end
=
(
x_end2end
,
y_end2end
)
corner_point_master
=
((
x_master1
,
y_master1
),
(
x_master2
,
y_master2
))
if
is_inside
(
center_point_end2end
,
corner_point_master
):
match_pairs_list
.
append
([
i
,
j
])
return
match_pairs_list
def
iou_rule_match
(
end2end_xyxy_bboxes
,
end2end_xyxy_indexes
,
structure_master_xyxy_bboxes
):
"""
Use iou to find matching list.
choose max iou value bbox as match pair.
:param end2end_xyxy_bboxes:
:param end2end_xyxy_indexes: original end2end indexes.
:param structure_master_xyxy_bboxes:
:return: match pairs list, e.g. [[0,1], [1,2], ...]
"""
match_pair_list
=
[]
for
end2end_xyxy_index
,
end2end_xyxy
in
zip
(
end2end_xyxy_indexes
,
end2end_xyxy_bboxes
):
max_iou
=
0
max_match
=
[
None
,
None
]
for
j
,
master_xyxy
in
enumerate
(
structure_master_xyxy_bboxes
):
end2end_4xy
=
convert_coord
(
end2end_xyxy
)
master_4xy
=
convert_coord
(
master_xyxy
)
iou
=
cal_iou
(
end2end_4xy
,
master_4xy
)
if
iou
>
max_iou
:
max_match
[
0
],
max_match
[
1
]
=
end2end_xyxy_index
,
j
max_iou
=
iou
if
max_match
[
0
]
is
None
:
# no match
continue
match_pair_list
.
append
(
max_match
)
return
match_pair_list
def
distance_rule_match
(
end2end_indexes
,
end2end_bboxes
,
master_indexes
,
master_bboxes
):
"""
Get matching between no-match end2end bboxes and no-match master bboxes.
Use min distance to match.
This rule will only run (no-match end2end nums > 0) and (no-match master nums > 0)
It will Return master_bboxes_nums match-pairs.
:param end2end_indexes:
:param end2end_bboxes:
:param master_indexes:
:param master_bboxes:
:return: match_pairs list, e.g. [[0,1], [1,2], ...]
"""
min_match_list
=
[]
for
j
,
master_bbox
in
zip
(
master_indexes
,
master_bboxes
):
min_distance
=
np
.
inf
min_match
=
[
0
,
0
]
# i, j
for
i
,
end2end_bbox
in
zip
(
end2end_indexes
,
end2end_bboxes
):
x_end2end
,
y_end2end
=
end2end_bbox
[
0
],
end2end_bbox
[
1
]
x_master
,
y_master
=
master_bbox
[
0
],
master_bbox
[
1
]
end2end_point
=
(
x_end2end
,
y_end2end
)
master_point
=
(
x_master
,
y_master
)
dist
=
cal_distance
(
master_point
,
end2end_point
)
if
dist
<
min_distance
:
min_match
[
0
],
min_match
[
1
]
=
i
,
j
min_distance
=
dist
min_match_list
.
append
(
min_match
)
return
min_match_list
def
extra_match
(
no_match_end2end_indexes
,
master_bbox_nums
):
"""
This function will create some virtual master bboxes,
and get match with the no match end2end indexes.
:param no_match_end2end_indexes:
:param master_bbox_nums:
:return:
"""
end_nums
=
len
(
no_match_end2end_indexes
)
+
master_bbox_nums
extra_match_list
=
[]
for
i
in
range
(
master_bbox_nums
,
end_nums
):
end2end_index
=
no_match_end2end_indexes
[
i
-
master_bbox_nums
]
extra_match_list
.
append
([
end2end_index
,
i
])
return
extra_match_list
def
match_visual
(
file_name
,
match_list
,
end2end_xyxy
,
master_xyxy
,
prex
=
'ordinary_match'
):
"""
Show the match result by xyxy coord style.
:param file_name:
:param match_list:
:param end2end_xyxy:
:param master_xyxy:
:param prex:
:return:
"""
folder
=
''
save_folder
=
'/data_0/cache'
file_path
=
os
.
path
.
join
(
folder
,
file_name
)
img_end2end
=
cv2
.
imread
(
file_path
)
img_master
=
copy
.
deepcopy
(
img_end2end
)
text_color
=
(
0
,
0
,
255
)
bbox_color
=
(
255
,
0
,
0
)
master_nums
=
len
(
master_xyxy
)
for
idx
,
match_group
in
enumerate
(
match_list
):
end2end_idx
,
master_index
=
match_group
[
0
],
match_group
[
1
]
# master_index larger than master_nums, did not draw master bbox.
if
master_index
<
master_nums
:
# draw master
master_bbox
=
master_xyxy
[
master_index
]
img_master
=
cv2
.
rectangle
(
img_master
,
(
int
(
master_bbox
[
0
]),
int
(
master_bbox
[
1
])),
(
int
(
master_bbox
[
2
]),
int
(
master_bbox
[
3
])),
bbox_color
,
thickness
=
1
)
master_text_coord
=
(
int
(
master_bbox
[
0
])
-
4
,
int
(
master_bbox
[
1
]))
img_master
=
cv2
.
putText
(
img_master
,
str
(
master_index
),
master_text_coord
,
1
,
1
,
text_color
,
2
)
# draw end2end
end2end_bbox
=
end2end_xyxy
[
end2end_idx
]
img_end2end
=
cv2
.
rectangle
(
img_end2end
,
(
int
(
end2end_bbox
[
0
]),
int
(
end2end_bbox
[
1
])),
(
int
(
end2end_bbox
[
2
]),
int
(
end2end_bbox
[
3
])),
bbox_color
,
thickness
=
1
)
end2end_text_coord
=
(
int
(
end2end_bbox
[
0
])
-
4
,
int
(
end2end_bbox
[
1
]))
# write end2end bbox matching master bbox's index
img_end2end
=
cv2
.
putText
(
img_end2end
,
str
(
master_index
),
end2end_text_coord
,
1
,
1
,
text_color
,
2
)
img
=
np
.
hstack
([
img_end2end
,
img_master
])
save_path
=
os
.
path
.
join
(
save_folder
,
'{}_matchShow.png'
.
format
(
prex
))
cv2
.
imwrite
(
save_path
,
img
)
def
get_match_dict
(
match_list
):
"""
Convert match_list to a dict, where key is master bbox's index, value is end2end bbox index.
:param match_list:
:return:
"""
match_dict
=
dict
()
for
match_pair
in
match_list
:
end2end_index
,
master_index
=
match_pair
[
0
],
match_pair
[
1
]
if
master_index
not
in
match_dict
.
keys
():
match_dict
[
master_index
]
=
[
end2end_index
]
else
:
match_dict
[
master_index
].
append
(
end2end_index
)
return
match_dict
def
deal_successive_space
(
text
):
"""
deal successive space character for text
1. Replace ' '*3 with '<space>' which is real space is text
2. Remove ' ', which is split token, not true space
3. Replace '<space>' with ' ', to get real text
:param text:
:return:
"""
text
=
text
.
replace
(
' '
*
3
,
'<space>'
)
text
=
text
.
replace
(
' '
,
''
)
text
=
text
.
replace
(
'<space>'
,
' '
)
return
text
def
reduce_repeat_bb
(
text_list
,
break_token
):
"""
convert ['<b>Local</b>', '<b>government</b>', '<b>unit</b>'] to ['<b>Local government unit</b>']
PS: maybe style <i>Local</i> is also exist, too. it can be processed like this.
:param text_list:
:param break_token:
:return:
"""
count
=
0
for
text
in
text_list
:
if
text
.
startswith
(
'<b>'
):
count
+=
1
if
count
==
len
(
text_list
):
new_text_list
=
[]
for
text
in
text_list
:
text
=
text
.
replace
(
'<b>'
,
''
).
replace
(
'</b>'
,
''
)
new_text_list
.
append
(
text
)
return
[
'<b>'
+
break_token
.
join
(
new_text_list
)
+
'</b>'
]
else
:
return
text_list
def
get_match_text_dict
(
match_dict
,
end2end_info
,
break_token
=
' '
):
match_text_dict
=
dict
()
for
master_index
,
end2end_index_list
in
match_dict
.
items
():
text_list
=
[
end2end_info
[
end2end_index
][
'text'
]
for
end2end_index
in
end2end_index_list
]
text_list
=
reduce_repeat_bb
(
text_list
,
break_token
)
text
=
break_token
.
join
(
text_list
)
match_text_dict
[
master_index
]
=
text
return
match_text_dict
def
merge_span_token
(
master_token_list
):
"""
Merge the span style token (row span or col span).
:param master_token_list:
:return:
"""
new_master_token_list
=
[]
pointer
=
0
if
master_token_list
[
-
1
]
!=
'</tbody>'
:
master_token_list
.
append
(
'</tbody>'
)
while
master_token_list
[
pointer
]
!=
'</tbody>'
:
try
:
if
master_token_list
[
pointer
]
==
'<td'
:
if
master_token_list
[
pointer
+
1
].
startswith
(
' colspan='
)
or
master_token_list
[
pointer
+
1
].
startswith
(
' rowspan='
):
"""
example:
pattern <td colspan="3">
'<td' + 'colspan=" "' + '>' + '</td>'
"""
# tmp = master_token_list[pointer] + master_token_list[pointer+1] + master_token_list[pointer+2] + \
# master_token_list[pointer+3]
tmp
=
''
.
join
(
master_token_list
[
pointer
:
pointer
+
3
+
1
])
pointer
+=
4
new_master_token_list
.
append
(
tmp
)
elif
master_token_list
[
pointer
+
2
].
startswith
(
' colspan='
)
or
master_token_list
[
pointer
+
2
].
startswith
(
' rowspan='
):
"""
example:
pattern <td rowspan="2" colspan="3">
'<td' + 'rowspan=" "' + 'colspan=" "' + '>' + '</td>'
"""
# tmp = master_token_list[pointer] + master_token_list[pointer+1] + \
# master_token_list[pointer+2] + master_token_list[pointer+3] + master_token_list[pointer+4]
tmp
=
''
.
join
(
master_token_list
[
pointer
:
pointer
+
4
+
1
])
pointer
+=
5
new_master_token_list
.
append
(
tmp
)
else
:
new_master_token_list
.
append
(
master_token_list
[
pointer
])
pointer
+=
1
else
:
new_master_token_list
.
append
(
master_token_list
[
pointer
])
pointer
+=
1
except
:
print
(
"Break in merge..."
)
break
new_master_token_list
.
append
(
'</tbody>'
)
return
new_master_token_list
def
deal_eb_token
(
master_token
):
"""
post process with <eb></eb>, <eb1></eb1>, ...
emptyBboxTokenDict = {
"[]": '<eb></eb>',
"[' ']": '<eb1></eb1>',
"['<b>', ' ', '</b>']": '<eb2></eb2>',
"['
\\
u2028', '
\\
u2028']": '<eb3></eb3>',
"['<sup>', ' ', '</sup>']": '<eb4></eb4>',
"['<b>', '</b>']": '<eb5></eb5>',
"['<i>', ' ', '</i>']": '<eb6></eb6>',
"['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
"['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
"['<i>', '</i>']": '<eb9></eb9>',
"['<b>', ' ', '
\\
u2028', ' ', '
\\
u2028', ' ', '</b>']": '<eb10></eb10>',
}
:param master_token:
:return:
"""
master_token
=
master_token
.
replace
(
'<eb></eb>'
,
'<td></td>'
)
master_token
=
master_token
.
replace
(
'<eb1></eb1>'
,
'<td> </td>'
)
master_token
=
master_token
.
replace
(
'<eb2></eb2>'
,
'<td><b> </b></td>'
)
master_token
=
master_token
.
replace
(
'<eb3></eb3>'
,
'<td>
\u2028\u2028
</td>'
)
master_token
=
master_token
.
replace
(
'<eb4></eb4>'
,
'<td><sup> </sup></td>'
)
master_token
=
master_token
.
replace
(
'<eb5></eb5>'
,
'<td><b></b></td>'
)
master_token
=
master_token
.
replace
(
'<eb6></eb6>'
,
'<td><i> </i></td>'
)
master_token
=
master_token
.
replace
(
'<eb7></eb7>'
,
'<td><b><i></i></b></td>'
)
master_token
=
master_token
.
replace
(
'<eb8></eb8>'
,
'<td><b><i> </i></b></td>'
)
master_token
=
master_token
.
replace
(
'<eb9></eb9>'
,
'<td><i></i></td>'
)
master_token
=
master_token
.
replace
(
'<eb10></eb10>'
,
'<td><b>
\u2028
\u2028
</b></td>'
)
return
master_token
def
insert_text_to_token
(
master_token_list
,
match_text_dict
):
"""
Insert OCR text result to structure token.
:param master_token_list:
:param match_text_dict:
:return:
"""
master_token_list
=
merge_span_token
(
master_token_list
)
merged_result_list
=
[]
text_count
=
0
for
master_token
in
master_token_list
:
if
master_token
.
startswith
(
'<td'
):
if
text_count
>
len
(
match_text_dict
)
-
1
:
text_count
+=
1
continue
elif
text_count
not
in
match_text_dict
.
keys
():
text_count
+=
1
continue
else
:
master_token
=
master_token
.
replace
(
'><'
,
'>{}<'
.
format
(
match_text_dict
[
text_count
]))
text_count
+=
1
master_token
=
deal_eb_token
(
master_token
)
merged_result_list
.
append
(
master_token
)
return
''
.
join
(
merged_result_list
)
def
deal_isolate_span
(
thead_part
):
"""
Deal with isolate span cases in this function.
It causes by wrong prediction in structure recognition model.
eg. predict <td rowspan="2"></td> to <td></td> rowspan="2"></b></td>.
:param thead_part:
:return:
"""
# 1. find out isolate span tokens.
isolate_pattern
=
"<td></td> rowspan=
\"
(\d)+
\"
colspan=
\"
(\d)+
\"
></b></td>|"
\
"<td></td> colspan=
\"
(\d)+
\"
rowspan=
\"
(\d)+
\"
></b></td>|"
\
"<td></td> rowspan=
\"
(\d)+
\"
></b></td>|"
\
"<td></td> colspan=
\"
(\d)+
\"
></b></td>"
isolate_iter
=
re
.
finditer
(
isolate_pattern
,
thead_part
)
isolate_list
=
[
i
.
group
()
for
i
in
isolate_iter
]
# 2. find out span number, by step 1 results.
span_pattern
=
" rowspan=
\"
(\d)+
\"
colspan=
\"
(\d)+
\"
|"
\
" colspan=
\"
(\d)+
\"
rowspan=
\"
(\d)+
\"
|"
\
" rowspan=
\"
(\d)+
\"
|"
\
" colspan=
\"
(\d)+
\"
"
corrected_list
=
[]
for
isolate_item
in
isolate_list
:
span_part
=
re
.
search
(
span_pattern
,
isolate_item
)
spanStr_in_isolateItem
=
span_part
.
group
()
# 3. merge the span number into the span token format string.
if
spanStr_in_isolateItem
is
not
None
:
corrected_item
=
'<td{}></td>'
.
format
(
spanStr_in_isolateItem
)
corrected_list
.
append
(
corrected_item
)
else
:
corrected_list
.
append
(
None
)
# 4. replace original isolated token.
for
corrected_item
,
isolate_item
in
zip
(
corrected_list
,
isolate_list
):
if
corrected_item
is
not
None
:
thead_part
=
thead_part
.
replace
(
isolate_item
,
corrected_item
)
else
:
pass
return
thead_part
def
deal_duplicate_bb
(
thead_part
):
"""
Deal duplicate <b> or </b> after replace.
Keep one <b></b> in a <td></td> token.
:param thead_part:
:return:
"""
# 1. find out <td></td> in <thead></thead>.
td_pattern
=
"<td rowspan=
\"
(\d)+
\"
colspan=
\"
(\d)+
\"
>(.+?)</td>|"
\
"<td colspan=
\"
(\d)+
\"
rowspan=
\"
(\d)+
\"
>(.+?)</td>|"
\
"<td rowspan=
\"
(\d)+
\"
>(.+?)</td>|"
\
"<td colspan=
\"
(\d)+
\"
>(.+?)</td>|"
\
"<td>(.*?)</td>"
td_iter
=
re
.
finditer
(
td_pattern
,
thead_part
)
td_list
=
[
t
.
group
()
for
t
in
td_iter
]
# 2. is multiply <b></b> in <td></td> or not?
new_td_list
=
[]
for
td_item
in
td_list
:
if
td_item
.
count
(
'<b>'
)
>
1
or
td_item
.
count
(
'</b>'
)
>
1
:
# multiply <b></b> in <td></td> case.
# 1. remove all <b></b>
td_item
=
td_item
.
replace
(
'<b>'
,
''
).
replace
(
'</b>'
,
''
)
# 2. replace <tb> -> <tb><b>, </tb> -> </b></tb>.
td_item
=
td_item
.
replace
(
'<td>'
,
'<td><b>'
).
replace
(
'</td>'
,
'</b></td>'
)
new_td_list
.
append
(
td_item
)
else
:
new_td_list
.
append
(
td_item
)
# 3. replace original thead part.
for
td_item
,
new_td_item
in
zip
(
td_list
,
new_td_list
):
thead_part
=
thead_part
.
replace
(
td_item
,
new_td_item
)
return
thead_part
def
deal_bb
(
result_token
):
"""
In our opinion, <b></b> always occurs in <thead></thead> text's context.
This function will find out all tokens in <thead></thead> and insert <b></b> by manual.
:param result_token:
:return:
"""
# find out <thead></thead> parts.
thead_pattern
=
'<thead>(.*?)</thead>'
if
re
.
search
(
thead_pattern
,
result_token
)
is
None
:
return
result_token
thead_part
=
re
.
search
(
thead_pattern
,
result_token
).
group
()
origin_thead_part
=
copy
.
deepcopy
(
thead_part
)
# check "rowspan" or "colspan" occur in <thead></thead> parts or not .
span_pattern
=
"<td rowspan=
\"
(\d)+
\"
colspan=
\"
(\d)+
\"
>|<td colspan=
\"
(\d)+
\"
rowspan=
\"
(\d)+
\"
>|<td rowspan=
\"
(\d)+
\"
>|<td colspan=
\"
(\d)+
\"
>"
span_iter
=
re
.
finditer
(
span_pattern
,
thead_part
)
span_list
=
[
s
.
group
()
for
s
in
span_iter
]
has_span_in_head
=
True
if
len
(
span_list
)
>
0
else
False
if
not
has_span_in_head
:
# <thead></thead> not include "rowspan" or "colspan" branch 1.
# 1. replace <td> to <td><b>, and </td> to </b></td>
# 2. it is possible to predict text include <b> or </b> by Text-line recognition,
# so we replace <b><b> to <b>, and </b></b> to </b>
thead_part
=
thead_part
.
replace
(
'<td>'
,
'<td><b>'
)
\
.
replace
(
'</td>'
,
'</b></td>'
)
\
.
replace
(
'<b><b>'
,
'<b>'
)
\
.
replace
(
'</b></b>'
,
'</b>'
)
else
:
# <thead></thead> include "rowspan" or "colspan" branch 2.
# Firstly, we deal rowspan or colspan cases.
# 1. replace > to ><b>
# 2. replace </td> to </b></td>
# 3. it is possible to predict text include <b> or </b> by Text-line recognition,
# so we replace <b><b> to <b>, and </b><b> to </b>
# Secondly, deal ordinary cases like branch 1
# replace ">" to "<b>"
replaced_span_list
=
[]
for
sp
in
span_list
:
replaced_span_list
.
append
(
sp
.
replace
(
'>'
,
'><b>'
))
for
sp
,
rsp
in
zip
(
span_list
,
replaced_span_list
):
thead_part
=
thead_part
.
replace
(
sp
,
rsp
)
# replace "</td>" to "</b></td>"
thead_part
=
thead_part
.
replace
(
'</td>'
,
'</b></td>'
)
# remove duplicated <b> by re.sub
mb_pattern
=
"(<b>)+"
single_b_string
=
"<b>"
thead_part
=
re
.
sub
(
mb_pattern
,
single_b_string
,
thead_part
)
mgb_pattern
=
"(</b>)+"
single_gb_string
=
"</b>"
thead_part
=
re
.
sub
(
mgb_pattern
,
single_gb_string
,
thead_part
)
# ordinary cases like branch 1
thead_part
=
thead_part
.
replace
(
'<td>'
,
'<td><b>'
).
replace
(
'<b><b>'
,
'<b>'
)
# convert <tb><b></b></tb> back to <tb></tb>, empty cell has no <b></b>.
# but space cell(<tb> </tb>) is suitable for <td><b> </b></td>
thead_part
=
thead_part
.
replace
(
'<td><b></b></td>'
,
'<td></td>'
)
# deal with duplicated <b></b>
thead_part
=
deal_duplicate_bb
(
thead_part
)
# deal with isolate span tokens, which causes by wrong predict by structure prediction.
# eg.PMC5994107_011_00.png
thead_part
=
deal_isolate_span
(
thead_part
)
# replace original result with new thead part.
result_token
=
result_token
.
replace
(
origin_thead_part
,
thead_part
)
return
result_token
class
Matcher
:
def
__init__
(
self
,
end2end_file
,
structure_master_file
):
"""
This class process the end2end results and structure recognition results.
:param end2end_file: end2end results predict by end2end inference.
:param structure_master_file: structure recognition results predict by structure master inference.
"""
self
.
end2end_file
=
end2end_file
self
.
structure_master_file
=
structure_master_file
self
.
end2end_results
=
pickle_load
(
end2end_file
,
prefix
=
'end2end'
)
self
.
structure_master_results
=
pickle_load
(
structure_master_file
,
prefix
=
'structure'
)
def
match
(
self
):
"""
Match process:
pre-process : convert end2end and structure master results to xyxy, xywh ndnarray format.
1. Use pseBbox is inside masterBbox judge rule
2. Use iou between pseBbox and masterBbox rule
3. Use min distance of center point rule
:return:
"""
match_results
=
dict
()
for
idx
,
(
file_name
,
end2end_result
)
in
enumerate
(
self
.
end2end_results
.
items
()):
match_list
=
[]
if
file_name
not
in
self
.
structure_master_results
:
continue
structure_master_result
=
self
.
structure_master_results
[
file_name
]
end2end_xyxy_bboxes
,
end2end_xywh_bboxes
,
structure_master_xywh_bboxes
,
structure_master_xyxy_bboxes
=
\
get_bboxes_list
(
end2end_result
,
structure_master_result
)
# rule 1: center rule
center_rule_match_list
=
\
center_rule_match
(
end2end_xywh_bboxes
,
structure_master_xyxy_bboxes
)
match_list
.
extend
(
center_rule_match_list
)
# rule 2: iou rule
# firstly, find not match index in previous step.
center_no_match_end2end_indexs
=
\
find_no_match
(
match_list
,
len
(
end2end_xywh_bboxes
),
type
=
'end2end'
)
if
len
(
center_no_match_end2end_indexs
)
>
0
:
center_no_match_end2end_xyxy
=
end2end_xyxy_bboxes
[
center_no_match_end2end_indexs
]
# secondly, iou rule match
iou_rule_match_list
=
\
iou_rule_match
(
center_no_match_end2end_xyxy
,
center_no_match_end2end_indexs
,
structure_master_xyxy_bboxes
)
match_list
.
extend
(
iou_rule_match_list
)
# rule 3: distance rule
# match between no-match end2end bboxes and no-match master bboxes.
# it will return master_bboxes_nums match-pairs.
# firstly, find not match index in previous step.
centerIou_no_match_end2end_indexs
=
\
find_no_match
(
match_list
,
len
(
end2end_xywh_bboxes
),
type
=
'end2end'
)
centerIou_no_match_master_indexs
=
\
find_no_match
(
match_list
,
len
(
structure_master_xywh_bboxes
),
type
=
'master'
)
if
len
(
centerIou_no_match_master_indexs
)
>
0
and
len
(
centerIou_no_match_end2end_indexs
)
>
0
:
centerIou_no_match_end2end_xywh
=
end2end_xywh_bboxes
[
centerIou_no_match_end2end_indexs
]
centerIou_no_match_master_xywh
=
structure_master_xywh_bboxes
[
centerIou_no_match_master_indexs
]
distance_match_list
=
distance_rule_match
(
centerIou_no_match_end2end_indexs
,
centerIou_no_match_end2end_xywh
,
centerIou_no_match_master_indexs
,
centerIou_no_match_master_xywh
)
match_list
.
extend
(
distance_match_list
)
# TODO:
# The render no-match pseBbox, insert the last
# After step3 distance rule, a master bbox at least match one end2end bbox.
# But end2end bbox maybe overmuch, because numbers of master bbox will cut by max length.
# For these render end2end bboxes, we will make some virtual master bboxes, and get matching.
# The above extra insert bboxes will be further processed in "formatOutput" function.
# After this operation, it will increase TEDS score.
no_match_end2end_indexes
=
\
find_no_match
(
match_list
,
len
(
end2end_xywh_bboxes
),
type
=
'end2end'
)
if
len
(
no_match_end2end_indexes
)
>
0
:
no_match_end2end_xywh
=
end2end_xywh_bboxes
[
no_match_end2end_indexes
]
# sort the render no-match end2end bbox in row
end2end_sorted_indexes_list
,
end2end_sorted_bboxes_list
,
sorted_groups
,
sorted_bboxes_groups
=
\
sort_bbox
(
no_match_end2end_xywh
,
no_match_end2end_indexes
)
# make virtual master bboxes, and get matching with the no-match end2end bboxes.
extra_match_list
=
extra_match
(
end2end_sorted_indexes_list
,
len
(
structure_master_xywh_bboxes
))
match_list_add_extra_match
=
copy
.
deepcopy
(
match_list
)
match_list_add_extra_match
.
extend
(
extra_match_list
)
else
:
# no no-match end2end bboxes
match_list_add_extra_match
=
copy
.
deepcopy
(
match_list
)
sorted_groups
=
[]
sorted_bboxes_groups
=
[]
match_result_dict
=
{
'match_list'
:
match_list
,
'match_list_add_extra_match'
:
match_list_add_extra_match
,
'sorted_groups'
:
sorted_groups
,
'sorted_bboxes_groups'
:
sorted_bboxes_groups
}
# ordinary match show
# match_visual(file_name, match_list, end2end_xyxy_bboxes, structure_master_xyxy_bboxes, prex='ordinary_match')
# extra match show
# match_visual(file_name, match_list_add_extra_match, end2end_xyxy_bboxes, structure_master_xyxy_bboxes, prex='extra_match')
# format output
match_result_dict
=
self
.
_format
(
match_result_dict
,
file_name
)
match_results
[
file_name
]
=
match_result_dict
return
match_results
def
_format
(
self
,
match_result
,
file_name
):
"""
Extend the master token(insert virtual master token), and format matching result.
:param match_result:
:param file_name:
:return:
"""
end2end_info
=
self
.
end2end_results
[
file_name
]
master_info
=
self
.
structure_master_results
[
file_name
]
master_token
=
master_info
[
'text'
]
sorted_groups
=
match_result
[
'sorted_groups'
]
# creat virtual master token
virtual_master_token_list
=
[]
for
line_group
in
sorted_groups
:
tmp_list
=
[
'<tr>'
]
item_nums
=
len
(
line_group
)
for
_
in
range
(
item_nums
):
tmp_list
.
append
(
'<td></td>'
)
tmp_list
.
append
(
'</tr>'
)
virtual_master_token_list
.
extend
(
tmp_list
)
# insert virtual master token
master_token_list
=
master_token
.
split
(
','
)
if
master_token_list
[
-
1
]
==
'</tbody>'
:
# complete predict(no cut by max length)
# This situation insert virtual master token will drop TEDs score in val set.
# So we will not extend virtual token in this situation.
# fake extend virtual
master_token_list
[:
-
1
].
extend
(
virtual_master_token_list
)
# real extend virtual
# master_token_list = master_token_list[:-1]
# master_token_list.extend(virtual_master_token_list)
# master_token_list.append('</tbody>')
elif
master_token_list
[
-
1
]
==
'<td></td>'
:
master_token_list
.
append
(
'</tr>'
)
master_token_list
.
extend
(
virtual_master_token_list
)
master_token_list
.
append
(
'</tbody>'
)
else
:
master_token_list
.
extend
(
virtual_master_token_list
)
master_token_list
.
append
(
'</tbody>'
)
# format output
match_result
.
setdefault
(
'matched_master_token_list'
,
master_token_list
)
return
match_result
def
get_merge_result
(
self
,
match_results
):
"""
Merge the OCR result into structure token to get final results.
:param match_results:
:return:
"""
merged_results
=
dict
()
# break_token is linefeed token, when one master bbox has multiply end2end bboxes.
break_token
=
' '
for
idx
,
(
file_name
,
match_info
)
in
enumerate
(
match_results
.
items
()):
end2end_info
=
self
.
end2end_results
[
file_name
]
master_token_list
=
match_info
[
'matched_master_token_list'
]
match_list
=
match_info
[
'match_list_add_extra_match'
]
match_dict
=
get_match_dict
(
match_list
)
match_text_dict
=
get_match_text_dict
(
match_dict
,
end2end_info
,
break_token
)
merged_result
=
insert_text_to_token
(
master_token_list
,
match_text_dict
)
merged_result
=
deal_bb
(
merged_result
)
merged_results
[
file_name
]
=
merged_result
return
merged_results
class
TableMasterMatcher
(
Matcher
):
def
__init__
(
self
):
pass
def
__call__
(
self
,
structure_res
,
dt_boxes
,
rec_res
,
img_name
=
1
):
end2end_results
=
{
img_name
:
[]}
for
dt_box
,
res
in
zip
(
dt_boxes
,
rec_res
):
d
=
dict
(
bbox
=
np
.
array
(
dt_box
),
text
=
res
[
0
],
)
end2end_results
[
img_name
].
append
(
d
)
self
.
end2end_results
=
end2end_results
structure_master_result_dict
=
{
img_name
:
{}}
pred_structures
,
pred_bboxes
=
structure_res
pred_structures
=
','
.
join
(
pred_structures
[
3
:
-
3
])
structure_master_result_dict
[
img_name
][
'text'
]
=
pred_structures
structure_master_result_dict
[
img_name
][
'bbox'
]
=
pred_bboxes
self
.
structure_master_results
=
structure_master_result_dict
# match
match_results
=
self
.
match
()
merged_results
=
self
.
get_merge_result
(
match_results
)
pred_html
=
merged_results
[
img_name
]
# pred_html = '<html><body><table>' + pred_html + '</table></body></html>'
return
pred_html
ppstructure/utility.py
浏览文件 @
ddaa2c25
...
@@ -32,6 +32,7 @@ def init_args():
...
@@ -32,6 +32,7 @@ def init_args():
type
=
str
,
type
=
str
,
default
=
"../ppocr/utils/dict/table_structure_dict.txt"
)
default
=
"../ppocr/utils/dict/table_structure_dict.txt"
)
# params for layout
# params for layout
parser
.
add_argument
(
"--layout_model_dir"
,
type
=
str
)
parser
.
add_argument
(
parser
.
add_argument
(
"--layout_path_model"
,
"--layout_path_model"
,
type
=
str
,
type
=
str
,
...
@@ -87,7 +88,7 @@ def draw_structure_result(image, result, font_path):
...
@@ -87,7 +88,7 @@ def draw_structure_result(image, result, font_path):
image
=
Image
.
fromarray
(
image
)
image
=
Image
.
fromarray
(
image
)
boxes
,
txts
,
scores
=
[],
[],
[]
boxes
,
txts
,
scores
=
[],
[],
[]
for
region
in
result
:
for
region
in
result
:
if
region
[
'type'
]
==
'
T
able'
:
if
region
[
'type'
]
==
'
t
able'
:
pass
pass
else
:
else
:
for
text_result
in
region
[
'res'
]:
for
text_result
in
region
[
'res'
]:
...
...
test_tipc/configs/en_table_structure/table_mv3.yml
浏览文件 @
ddaa2c25
...
@@ -19,8 +19,6 @@ Global:
...
@@ -19,8 +19,6 @@ Global:
character_type
:
en
character_type
:
en
max_text_length
:
800
max_text_length
:
800
infer_mode
:
False
infer_mode
:
False
process_total_num
:
0
process_cut_num
:
0
Optimizer
:
Optimizer
:
name
:
Adam
name
:
Adam
...
...
test_tipc/configs/table_master/table_master.yml
浏览文件 @
ddaa2c25
...
@@ -16,8 +16,6 @@ Global:
...
@@ -16,8 +16,6 @@ Global:
character_dict_path
:
ppocr/utils/dict/table_master_structure_dict.txt
character_dict_path
:
ppocr/utils/dict/table_master_structure_dict.txt
infer_mode
:
false
infer_mode
:
false
max_text_length
:
500
max_text_length
:
500
process_total_num
:
0
process_cut_num
:
0
Optimizer
:
Optimizer
:
...
@@ -86,7 +84,7 @@ Train:
...
@@ -86,7 +84,7 @@ Train:
-
PaddingTableImage
:
-
PaddingTableImage
:
size
:
[
480
,
480
]
size
:
[
480
,
480
]
-
TableBoxEncode
:
-
TableBoxEncode
:
use_xywh
:
True
box_format
:
'
xywh'
-
NormalizeImage
:
-
NormalizeImage
:
scale
:
1./255.
scale
:
1./255.
mean
:
[
0.5
,
0.5
,
0.5
]
mean
:
[
0.5
,
0.5
,
0.5
]
...
@@ -120,7 +118,7 @@ Eval:
...
@@ -120,7 +118,7 @@ Eval:
-
PaddingTableImage
:
-
PaddingTableImage
:
size
:
[
480
,
480
]
size
:
[
480
,
480
]
-
TableBoxEncode
:
-
TableBoxEncode
:
use_xywh
:
True
box_format
:
'
xywh'
-
NormalizeImage
:
-
NormalizeImage
:
scale
:
1./255.
scale
:
1./255.
mean
:
[
0.5
,
0.5
,
0.5
]
mean
:
[
0.5
,
0.5
,
0.5
]
...
...
tools/infer/predict_system.py
浏览文件 @
ddaa2c25
...
@@ -65,9 +65,11 @@ class TextSystem(object):
...
@@ -65,9 +65,11 @@ class TextSystem(object):
self
.
crop_image_res_index
+=
bbox_num
self
.
crop_image_res_index
+=
bbox_num
def
__call__
(
self
,
img
,
cls
=
True
):
def
__call__
(
self
,
img
,
cls
=
True
):
time_dict
=
{
'det'
:
0
,
'rec'
:
0
,
'csl'
:
0
,
'all'
:
0
}
start
=
time
.
time
()
ori_im
=
img
.
copy
()
ori_im
=
img
.
copy
()
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
time_dict
[
'det'
]
=
elapse
logger
.
debug
(
"dt_boxes num : {}, elapse : {}"
.
format
(
logger
.
debug
(
"dt_boxes num : {}, elapse : {}"
.
format
(
len
(
dt_boxes
),
elapse
))
len
(
dt_boxes
),
elapse
))
if
dt_boxes
is
None
:
if
dt_boxes
is
None
:
...
@@ -83,10 +85,12 @@ class TextSystem(object):
...
@@ -83,10 +85,12 @@ class TextSystem(object):
if
self
.
use_angle_cls
and
cls
:
if
self
.
use_angle_cls
and
cls
:
img_crop_list
,
angle_list
,
elapse
=
self
.
text_classifier
(
img_crop_list
,
angle_list
,
elapse
=
self
.
text_classifier
(
img_crop_list
)
img_crop_list
)
time_dict
[
'cls'
]
=
elapse
logger
.
debug
(
"cls num : {}, elapse : {}"
.
format
(
logger
.
debug
(
"cls num : {}, elapse : {}"
.
format
(
len
(
img_crop_list
),
elapse
))
len
(
img_crop_list
),
elapse
))
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
time_dict
[
'rec'
]
=
elapse
logger
.
debug
(
"rec_res num : {}, elapse : {}"
.
format
(
logger
.
debug
(
"rec_res num : {}, elapse : {}"
.
format
(
len
(
rec_res
),
elapse
))
len
(
rec_res
),
elapse
))
if
self
.
args
.
save_crop_res
:
if
self
.
args
.
save_crop_res
:
...
@@ -98,7 +102,9 @@ class TextSystem(object):
...
@@ -98,7 +102,9 @@ class TextSystem(object):
if
score
>=
self
.
drop_score
:
if
score
>=
self
.
drop_score
:
filter_boxes
.
append
(
box
)
filter_boxes
.
append
(
box
)
filter_rec_res
.
append
(
rec_result
)
filter_rec_res
.
append
(
rec_result
)
return
filter_boxes
,
filter_rec_res
end
=
time
.
time
()
time_dict
[
'all'
]
=
end
-
start
return
filter_boxes
,
filter_rec_res
,
time_dict
def
sorted_boxes
(
dt_boxes
):
def
sorted_boxes
(
dt_boxes
):
...
@@ -133,8 +139,10 @@ def main(args):
...
@@ -133,8 +139,10 @@ def main(args):
os
.
makedirs
(
draw_img_save_dir
,
exist_ok
=
True
)
os
.
makedirs
(
draw_img_save_dir
,
exist_ok
=
True
)
save_results
=
[]
save_results
=
[]
logger
.
info
(
"In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
logger
.
info
(
"if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320"
)
"In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', "
"if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320"
)
# warm up 10 times
# warm up 10 times
if
args
.
warmup
:
if
args
.
warmup
:
...
@@ -155,7 +163,7 @@ def main(args):
...
@@ -155,7 +163,7 @@ def main(args):
logger
.
debug
(
"error in loading image:{}"
.
format
(
image_file
))
logger
.
debug
(
"error in loading image:{}"
.
format
(
image_file
))
continue
continue
starttime
=
time
.
time
()
starttime
=
time
.
time
()
dt_boxes
,
rec_res
=
text_sys
(
img
)
dt_boxes
,
rec_res
,
time_dict
=
text_sys
(
img
)
elapse
=
time
.
time
()
-
starttime
elapse
=
time
.
time
()
-
starttime
total_time
+=
elapse
total_time
+=
elapse
...
@@ -198,7 +206,10 @@ def main(args):
...
@@ -198,7 +206,10 @@ def main(args):
text_sys
.
text_detector
.
autolog
.
report
()
text_sys
.
text_detector
.
autolog
.
report
()
text_sys
.
text_recognizer
.
autolog
.
report
()
text_sys
.
text_recognizer
.
autolog
.
report
()
with
open
(
os
.
path
.
join
(
draw_img_save_dir
,
"system_results.txt"
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
os
.
path
.
join
(
draw_img_save_dir
,
"system_results.txt"
),
'w'
,
encoding
=
'utf-8'
)
as
f
:
f
.
writelines
(
save_results
)
f
.
writelines
(
save_results
)
...
...
tools/infer/utility.py
浏览文件 @
ddaa2c25
...
@@ -155,6 +155,8 @@ def create_predictor(args, mode, logger):
...
@@ -155,6 +155,8 @@ def create_predictor(args, mode, logger):
model_dir
=
args
.
table_model_dir
model_dir
=
args
.
table_model_dir
elif
mode
==
'ser'
:
elif
mode
==
'ser'
:
model_dir
=
args
.
ser_model_dir
model_dir
=
args
.
ser_model_dir
elif
mode
==
'layout'
:
model_dir
=
args
.
layout_model_dir
else
:
else
:
model_dir
=
args
.
e2e_model_dir
model_dir
=
args
.
e2e_model_dir
...
...
tools/infer_table.py
浏览文件 @
ddaa2c25
...
@@ -56,7 +56,6 @@ def main(config, device, logger, vdl_writer):
...
@@ -56,7 +56,6 @@ def main(config, device, logger, vdl_writer):
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
algorithm
=
config
[
'Architecture'
][
'algorithm'
]
algorithm
=
config
[
'Architecture'
][
'algorithm'
]
use_xywh
=
algorithm
in
[
'TableMaster'
]
load_model
(
config
,
model
)
load_model
(
config
,
model
)
...
@@ -106,7 +105,7 @@ def main(config, device, logger, vdl_writer):
...
@@ -106,7 +105,7 @@ def main(config, device, logger, vdl_writer):
f_w
.
write
(
"result: {}, {}
\n
"
.
format
(
structure_str_list
,
f_w
.
write
(
"result: {}, {}
\n
"
.
format
(
structure_str_list
,
bbox_list_str
))
bbox_list_str
))
img
=
draw_rectangle
(
file
,
bbox_list
,
use_xywh
)
img
=
draw_rectangle
(
file
,
bbox_list
)
cv2
.
imwrite
(
cv2
.
imwrite
(
os
.
path
.
join
(
save_res_path
,
os
.
path
.
basename
(
file
)),
img
)
os
.
path
.
join
(
save_res_path
,
os
.
path
.
basename
(
file
)),
img
)
logger
.
info
(
"success!"
)
logger
.
info
(
"success!"
)
...
...
tools/program.py
浏览文件 @
ddaa2c25
...
@@ -154,6 +154,7 @@ def check_xpu(use_xpu):
...
@@ -154,6 +154,7 @@ def check_xpu(use_xpu):
except
Exception
as
e
:
except
Exception
as
e
:
pass
pass
def
to_float32
(
preds
):
def
to_float32
(
preds
):
if
isinstance
(
preds
,
dict
):
if
isinstance
(
preds
,
dict
):
for
k
in
preds
:
for
k
in
preds
:
...
@@ -173,6 +174,7 @@ def to_float32(preds):
...
@@ -173,6 +174,7 @@ def to_float32(preds):
preds
=
preds
.
astype
(
paddle
.
float32
)
preds
=
preds
.
astype
(
paddle
.
float32
)
return
preds
return
preds
def
train
(
config
,
def
train
(
config
,
train_dataloader
,
train_dataloader
,
valid_dataloader
,
valid_dataloader
,
...
@@ -596,7 +598,7 @@ def preprocess(is_train=False):
...
@@ -596,7 +598,7 @@ def preprocess(is_train=False):
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
,
'SEED'
,
'SDMGR'
,
'LayoutXLM'
,
'LayoutLM'
,
'LayoutLMv2'
,
'PREN'
,
'FCE'
,
'SEED'
,
'SDMGR'
,
'LayoutXLM'
,
'LayoutLM'
,
'LayoutLMv2'
,
'PREN'
,
'FCE'
,
'SVTR'
,
'ViTSTR'
,
'ABINet'
,
'DB++'
,
'TableMaster'
,
'SPIN'
'SVTR'
,
'ViTSTR'
,
'ABINet'
,
'DB++'
,
'TableMaster'
,
'SPIN'
,
'SLANet'
]
]
if
use_xpu
:
if
use_xpu
:
...
...
tools/train.py
浏览文件 @
ddaa2c25
...
@@ -119,6 +119,10 @@ def main(config, device, logger, vdl_writer):
...
@@ -119,6 +119,10 @@ def main(config, device, logger, vdl_writer):
config
[
'Loss'
][
'ignore_index'
]
=
char_num
-
1
config
[
'Loss'
][
'ignore_index'
]
=
char_num
-
1
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
use_sync_bn
=
config
[
"Global"
].
get
(
"use_sync_bn"
,
False
)
if
use_sync_bn
:
model
=
paddle
.
nn
.
SyncBatchNorm
.
convert_sync_batchnorm
(
model
)
logger
.
info
(
'convert_sync_batchnorm'
)
if
config
[
'Global'
][
'distributed'
]:
if
config
[
'Global'
][
'distributed'
]:
model
=
paddle
.
DataParallel
(
model
)
model
=
paddle
.
DataParallel
(
model
)
...
@@ -157,7 +161,8 @@ def main(config, device, logger, vdl_writer):
...
@@ -157,7 +161,8 @@ def main(config, device, logger, vdl_writer):
scaler
=
paddle
.
amp
.
GradScaler
(
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
scale_loss
,
init_loss_scaling
=
scale_loss
,
use_dynamic_loss_scaling
=
use_dynamic_loss_scaling
)
use_dynamic_loss_scaling
=
use_dynamic_loss_scaling
)
model
,
optimizer
=
paddle
.
amp
.
decorate
(
models
=
model
,
optimizers
=
optimizer
,
level
=
'O2'
,
master_weight
=
True
)
model
,
optimizer
=
paddle
.
amp
.
decorate
(
models
=
model
,
optimizers
=
optimizer
,
level
=
'O2'
,
master_weight
=
True
)
else
:
else
:
scaler
=
None
scaler
=
None
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录