Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
a0c33908
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a0c33908
编写于
6月 16, 2022
作者:
文幕地方
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add TableMaster
上级
2d89b2ce
变更
24
展开全部
显示空白变更内容
内联
并排
Showing
24 changed file
with
1699 addition
and
3391 deletion
+1699
-3391
configs/table/table_master.yml
configs/table/table_master.yml
+138
-0
configs/table/table_mv3.yml
configs/table/table_mv3.yml
+23
-16
ppocr/data/imaug/gen_table_mask.py
ppocr/data/imaug/gen_table_mask.py
+39
-54
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+173
-146
ppocr/data/pubtab_dataset.py
ppocr/data/pubtab_dataset.py
+74
-56
ppocr/losses/__init__.py
ppocr/losses/__init__.py
+3
-2
ppocr/losses/table_master_loss.py
ppocr/losses/table_master_loss.py
+65
-0
ppocr/metrics/table_metric.py
ppocr/metrics/table_metric.py
+102
-13
ppocr/modeling/backbones/__init__.py
ppocr/modeling/backbones/__init__.py
+4
-1
ppocr/modeling/backbones/table_master_resnet.py
ppocr/modeling/backbones/table_master_resnet.py
+369
-0
ppocr/modeling/heads/__init__.py
ppocr/modeling/heads/__init__.py
+2
-1
ppocr/modeling/heads/table_att_head.py
ppocr/modeling/heads/table_att_head.py
+9
-127
ppocr/modeling/heads/table_master_head.py
ppocr/modeling/heads/table_master_head.py
+276
-0
ppocr/optimizer/learning_rate.py
ppocr/optimizer/learning_rate.py
+43
-0
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+3
-2
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+0
-140
ppocr/postprocess/table_postprocess.py
ppocr/postprocess/table_postprocess.py
+160
-0
ppocr/utils/dict/table_master_structure_dict.txt
ppocr/utils/dict/table_master_structure_dict.txt
+39
-0
ppocr/utils/dict/table_structure_dict.txt
ppocr/utils/dict/table_structure_dict.txt
+1
-2732
ppstructure/table/predict_structure.py
ppstructure/table/predict_structure.py
+90
-54
ppstructure/utility.py
ppstructure/utility.py
+2
-1
tools/export_model.py
tools/export_model.py
+2
-0
tools/infer_table.py
tools/infer_table.py
+54
-32
tools/program.py
tools/program.py
+28
-14
未找到文件。
configs/table/table_master.yml
0 → 100755
浏览文件 @
a0c33908
Global
:
use_gpu
:
true
epoch_num
:
17
log_smooth_window
:
20
print_batch_step
:
5
save_model_dir
:
./output/table_master/
save_epoch_step
:
17
# evaluation is run every 400 iterations after the 0th iteration
eval_batch_step
:
[
0
,
400
]
cal_metric_during_train
:
True
pretrained_model
:
checkpoints
:
save_inference_dir
:
use_visualdl
:
False
infer_img
:
ppstructure/docs/table/table.jpg
save_res_path
:
output/table_master
# for data or label process
character_dict_path
:
ppocr/utils/dict/table_master_structure_dict.txt
infer_mode
:
False
max_text_length
:
500
process_total_num
:
0
process_cut_num
:
0
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
lr
:
name
:
MultiStepDecay
learning_rate
:
0.001
milestones
:
[
12
,
15
]
gamma
:
0.1
warmup_epoch
:
0.02
regularizer
:
name
:
'
L2'
factor
:
0.00000
Architecture
:
model_type
:
table
algorithm
:
TableMaster
Backbone
:
name
:
TableResNetExtra
gcb_config
:
ratio
:
0.0625
headers
:
1
att_scale
:
False
fusion_type
:
channel_add
layers
:
[
False
,
True
,
True
,
True
]
layers
:
[
1
,
2
,
5
,
3
]
Head
:
name
:
TableMasterHead
hidden_size
:
512
headers
:
8
dropout
:
0
d_ff
:
2024
max_text_length
:
500
Loss
:
name
:
TableMasterLoss
ignore_index
:
42
# set to len of dict + 3
PostProcess
:
name
:
TableMasterLabelDecode
box_shape
:
pad
Metric
:
name
:
TableMetric
main_indicator
:
acc
compute_bbox_metric
:
true
# cost many time, set False for training
Train
:
dataset
:
name
:
PubTabDataSet
data_dir
:
/home/zhoujun20/table/PubTabNe/pubtabnet/train/
label_file_list
:
[
/home/zhoujun20/table/PubTabNe/pubtabnet/PubTabNet_2.0.0_train.jsonl
]
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
TableMasterLabelEncode
:
learn_empty_box
:
False
merge_no_span_structure
:
True
replace_empty_cell_token
:
True
-
ResizeTableImage
:
max_len
:
480
resize_bboxes
:
True
-
PaddingTableImage
:
size
:
[
480
,
480
]
-
TableBoxEncode
:
use_xywh
:
true
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
hwc'
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
structure'
,
'
bboxes'
,
'
bbox_masks'
,
'
shape'
]
loader
:
shuffle
:
True
batch_size_per_card
:
8
drop_last
:
True
num_workers
:
1
Eval
:
dataset
:
name
:
PubTabDataSet
data_dir
:
/home/zhoujun20/table/PubTabNe/pubtabnet/val/
label_file_list
:
[
/home/zhoujun20/table/PubTabNe/pubtabnet/val_500.jsonl
]
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
TableMasterLabelEncode
:
learn_empty_box
:
False
merge_no_span_structure
:
True
replace_empty_cell_token
:
True
-
ResizeTableImage
:
max_len
:
480
resize_bboxes
:
True
-
PaddingTableImage
:
size
:
[
480
,
480
]
-
TableBoxEncode
:
use_xywh
:
true
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.5
,
0.5
,
0.5
]
std
:
[
0.5
,
0.5
,
0.5
]
order
:
'
hwc'
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
structure'
,
'
bboxes'
,
'
bbox_masks'
,
'
shape'
]
loader
:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
2
num_workers
:
8
configs/table/table_mv3.yml
浏览文件 @
a0c33908
...
@@ -4,7 +4,7 @@ Global:
...
@@ -4,7 +4,7 @@ Global:
log_smooth_window
:
20
log_smooth_window
:
20
print_batch_step
:
5
print_batch_step
:
5
save_model_dir
:
./output/table_mv3/
save_model_dir
:
./output/table_mv3/
save_epoch_step
:
3
save_epoch_step
:
400
# evaluation is run every 400 iterations after the 0th iteration
# evaluation is run every 400 iterations after the 0th iteration
eval_batch_step
:
[
0
,
400
]
eval_batch_step
:
[
0
,
400
]
cal_metric_during_train
:
True
cal_metric_during_train
:
True
...
@@ -12,13 +12,12 @@ Global:
...
@@ -12,13 +12,12 @@ Global:
checkpoints
:
checkpoints
:
save_inference_dir
:
save_inference_dir
:
use_visualdl
:
False
use_visualdl
:
False
infer_img
:
doc/table/table.jpg
infer_img
:
ppstructure/docs/table/table.jpg
save_res_path
:
output/table_mv3
# for data or label process
# for data or label process
character_dict_path
:
ppocr/utils/dict/table_structure_dict.txt
character_dict_path
:
ppocr/utils/dict/table_structure_dict.txt
character_type
:
en
character_type
:
en
max_text_length
:
100
max_text_length
:
500
max_elem_length
:
800
max_cell_num
:
500
infer_mode
:
False
infer_mode
:
False
process_total_num
:
0
process_total_num
:
0
process_cut_num
:
0
process_cut_num
:
0
...
@@ -44,11 +43,8 @@ Architecture:
...
@@ -44,11 +43,8 @@ Architecture:
Head
:
Head
:
name
:
TableAttentionHead
name
:
TableAttentionHead
hidden_size
:
256
hidden_size
:
256
l2_decay
:
0.00001
loc_type
:
2
loc_type
:
2
max_text_length
:
100
max_text_length
:
500
max_elem_length
:
800
max_cell_num
:
500
Loss
:
Loss
:
name
:
TableAttentionLoss
name
:
TableAttentionLoss
...
@@ -61,6 +57,7 @@ PostProcess:
...
@@ -61,6 +57,7 @@ PostProcess:
Metric
:
Metric
:
name
:
TableMetric
name
:
TableMetric
main_indicator
:
acc
main_indicator
:
acc
compute_bbox_metric
:
False
# cost many time, set False for training
Train
:
Train
:
dataset
:
dataset
:
...
@@ -71,18 +68,23 @@ Train:
...
@@ -71,18 +68,23 @@ Train:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
img_mode
:
BGR
img_mode
:
BGR
channel_first
:
False
channel_first
:
False
-
TableLabelEncode
:
learn_empty_box
:
False
merge_no_span_structure
:
False
replace_empty_cell_token
:
False
-
TableBoxEncode
:
-
ResizeTableImage
:
-
ResizeTableImage
:
max_len
:
488
max_len
:
488
-
TableLabelEncode
:
-
NormalizeImage
:
-
NormalizeImage
:
scale
:
1./255.
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
hwc'
order
:
'
hwc'
-
PaddingTableImage
:
-
PaddingTableImage
:
size
:
[
488
,
488
]
-
ToCHWImage
:
-
ToCHWImage
:
-
KeepKeys
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
structure'
,
'
bbox_list'
,
'
sp_tokens'
,
'
bbox_list_mask'
]
keep_keys
:
[
'
image'
,
'
structure'
,
'
bboxes'
,
'
bbox_masks'
,
'
shape'
]
loader
:
loader
:
shuffle
:
True
shuffle
:
True
batch_size_per_card
:
32
batch_size_per_card
:
32
...
@@ -92,24 +94,29 @@ Train:
...
@@ -92,24 +94,29 @@ Train:
Eval
:
Eval
:
dataset
:
dataset
:
name
:
PubTabDataSet
name
:
PubTabDataSet
data_dir
:
train_data/tabl
e/pubtabnet/val/
data_dir
:
/home/zhoujun20/table/PubTabN
e/pubtabnet/val/
label_file_
path
:
train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl
label_file_
list
:
[
/home/zhoujun20/table/PubTabNe/pubtabnet/val_500.jsonl
]
transforms
:
transforms
:
-
DecodeImage
:
# load image
-
DecodeImage
:
# load image
img_mode
:
BGR
img_mode
:
BGR
channel_first
:
False
channel_first
:
False
-
TableLabelEncode
:
learn_empty_box
:
False
merge_no_span_structure
:
False
replace_empty_cell_token
:
False
-
TableBoxEncode
:
-
ResizeTableImage
:
-
ResizeTableImage
:
max_len
:
488
max_len
:
488
-
TableLabelEncode
:
-
NormalizeImage
:
-
NormalizeImage
:
scale
:
1./255.
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
hwc'
order
:
'
hwc'
-
PaddingTableImage
:
-
PaddingTableImage
:
size
:
[
488
,
488
]
-
ToCHWImage
:
-
ToCHWImage
:
-
KeepKeys
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
structure'
,
'
bbox_list'
,
'
sp_tokens'
,
'
bbox_list_mask'
]
keep_keys
:
[
'
image'
,
'
structure'
,
'
bboxes'
,
'
bbox_masks'
,
'
shape'
]
loader
:
loader
:
shuffle
:
False
shuffle
:
False
drop_last
:
False
drop_last
:
False
...
...
ppocr/data/imaug/gen_table_mask.py
浏览文件 @
a0c33908
...
@@ -48,10 +48,12 @@ class GenTableMask(object):
...
@@ -48,10 +48,12 @@ class GenTableMask(object):
in_text
=
False
# 是否遍历到了字符区内
in_text
=
False
# 是否遍历到了字符区内
box_list
=
[]
box_list
=
[]
for
i
in
range
(
len
(
project_val_array
)):
for
i
in
range
(
len
(
project_val_array
)):
if
in_text
==
False
and
project_val_array
[
i
]
>
spilt_threshold
:
# 进入字符区了
if
in_text
==
False
and
project_val_array
[
i
]
>
spilt_threshold
:
# 进入字符区了
in_text
=
True
in_text
=
True
start_idx
=
i
start_idx
=
i
elif
project_val_array
[
i
]
<=
spilt_threshold
and
in_text
==
True
:
# 进入空白区了
elif
project_val_array
[
i
]
<=
spilt_threshold
and
in_text
==
True
:
# 进入空白区了
end_idx
=
i
end_idx
=
i
in_text
=
False
in_text
=
False
if
end_idx
-
start_idx
<=
2
:
if
end_idx
-
start_idx
<=
2
:
...
@@ -70,7 +72,8 @@ class GenTableMask(object):
...
@@ -70,7 +72,8 @@ class GenTableMask(object):
box_gray_img
=
cv2
.
cvtColor
(
box_img
,
cv2
.
COLOR_BGR2GRAY
)
box_gray_img
=
cv2
.
cvtColor
(
box_img
,
cv2
.
COLOR_BGR2GRAY
)
h
,
w
=
box_gray_img
.
shape
h
,
w
=
box_gray_img
.
shape
# 灰度图片进行二值化处理
# 灰度图片进行二值化处理
ret
,
thresh1
=
cv2
.
threshold
(
box_gray_img
,
200
,
255
,
cv2
.
THRESH_BINARY_INV
)
ret
,
thresh1
=
cv2
.
threshold
(
box_gray_img
,
200
,
255
,
cv2
.
THRESH_BINARY_INV
)
# 纵向腐蚀
# 纵向腐蚀
if
h
<
w
:
if
h
<
w
:
kernel
=
np
.
ones
((
2
,
1
),
np
.
uint8
)
kernel
=
np
.
ones
((
2
,
1
),
np
.
uint8
)
...
@@ -95,10 +98,12 @@ class GenTableMask(object):
...
@@ -95,10 +98,12 @@ class GenTableMask(object):
box_list
=
[]
box_list
=
[]
spilt_threshold
=
0
spilt_threshold
=
0
for
i
in
range
(
len
(
project_val_array
)):
for
i
in
range
(
len
(
project_val_array
)):
if
in_text
==
False
and
project_val_array
[
i
]
>
spilt_threshold
:
# 进入字符区了
if
in_text
==
False
and
project_val_array
[
i
]
>
spilt_threshold
:
# 进入字符区了
in_text
=
True
in_text
=
True
start_idx
=
i
start_idx
=
i
elif
project_val_array
[
i
]
<=
spilt_threshold
and
in_text
==
True
:
# 进入空白区了
elif
project_val_array
[
i
]
<=
spilt_threshold
and
in_text
==
True
:
# 进入空白区了
end_idx
=
i
end_idx
=
i
in_text
=
False
in_text
=
False
if
end_idx
-
start_idx
<=
2
:
if
end_idx
-
start_idx
<=
2
:
...
@@ -120,7 +125,8 @@ class GenTableMask(object):
...
@@ -120,7 +125,8 @@ class GenTableMask(object):
h_end
=
h
h_end
=
h
word_img
=
erosion
[
h_start
:
h_end
+
1
,
:]
word_img
=
erosion
[
h_start
:
h_end
+
1
,
:]
word_h
,
word_w
=
word_img
.
shape
word_h
,
word_w
=
word_img
.
shape
w_split_list
,
w_projection_map
=
self
.
projection
(
word_img
.
T
,
word_w
,
word_h
)
w_split_list
,
w_projection_map
=
self
.
projection
(
word_img
.
T
,
word_w
,
word_h
)
w_start
,
w_end
=
w_split_list
[
0
][
0
],
w_split_list
[
-
1
][
1
]
w_start
,
w_end
=
w_split_list
[
0
][
0
],
w_split_list
[
-
1
][
1
]
if
h_start
>
0
:
if
h_start
>
0
:
h_start
-=
1
h_start
-=
1
...
@@ -170,7 +176,8 @@ class GenTableMask(object):
...
@@ -170,7 +176,8 @@ class GenTableMask(object):
for
sno
in
range
(
len
(
split_bbox_list
)):
for
sno
in
range
(
len
(
split_bbox_list
)):
left
,
top
,
right
,
bottom
=
split_bbox_list
[
sno
]
left
,
top
,
right
,
bottom
=
split_bbox_list
[
sno
]
left
,
top
,
right
,
bottom
=
self
.
shrink_bbox
([
left
,
top
,
right
,
bottom
])
left
,
top
,
right
,
bottom
=
self
.
shrink_bbox
(
[
left
,
top
,
right
,
bottom
])
if
self
.
mask_type
==
1
:
if
self
.
mask_type
==
1
:
mask_img
[
top
:
bottom
,
left
:
right
]
=
1.0
mask_img
[
top
:
bottom
,
left
:
right
]
=
1.0
data
[
'mask_img'
]
=
mask_img
data
[
'mask_img'
]
=
mask_img
...
@@ -179,66 +186,44 @@ class GenTableMask(object):
...
@@ -179,66 +186,44 @@ class GenTableMask(object):
data
[
'image'
]
=
mask_img
data
[
'image'
]
=
mask_img
return
data
return
data
class
ResizeTableImage
(
object
):
class
ResizeTableImage
(
object
):
def
__init__
(
self
,
max_len
,
**
kwargs
):
def
__init__
(
self
,
max_len
,
resize_bboxes
=
False
,
infer_mode
=
False
,
**
kwargs
):
super
(
ResizeTableImage
,
self
).
__init__
()
super
(
ResizeTableImage
,
self
).
__init__
()
self
.
max_len
=
max_len
self
.
max_len
=
max_len
self
.
resize_bboxes
=
resize_bboxes
self
.
infer_mode
=
infer_mode
def
get_img_bbox
(
self
,
cells
):
def
__call__
(
self
,
data
):
bbox_list
=
[]
img
=
data
[
'image'
]
if
len
(
cells
)
==
0
:
return
bbox_list
cell_num
=
len
(
cells
)
for
cno
in
range
(
cell_num
):
if
"bbox"
in
cells
[
cno
]:
bbox
=
cells
[
cno
][
'bbox'
]
bbox_list
.
append
(
bbox
)
return
bbox_list
def
resize_img_table
(
self
,
img
,
bbox_list
,
max_len
):
height
,
width
=
img
.
shape
[
0
:
2
]
height
,
width
=
img
.
shape
[
0
:
2
]
ratio
=
max_len
/
(
max
(
height
,
width
)
*
1.0
)
ratio
=
self
.
max_len
/
(
max
(
height
,
width
)
*
1.0
)
resize_h
=
int
(
height
*
ratio
)
resize_h
=
int
(
height
*
ratio
)
resize_w
=
int
(
width
*
ratio
)
resize_w
=
int
(
width
*
ratio
)
img_new
=
cv2
.
resize
(
img
,
(
resize_w
,
resize_h
))
resize_img
=
cv2
.
resize
(
img
,
(
resize_w
,
resize_h
))
bbox_list_new
=
[]
if
self
.
resize_bboxes
and
not
self
.
infer_mode
:
for
bno
in
range
(
len
(
bbox_list
)):
data
[
'bboxes'
]
=
data
[
'bboxes'
]
*
ratio
left
,
top
,
right
,
bottom
=
bbox_list
[
bno
].
copy
()
data
[
'image'
]
=
resize_img
left
=
int
(
left
*
ratio
)
data
[
'src_img'
]
=
img
top
=
int
(
top
*
ratio
)
data
[
'shape'
]
=
np
.
array
([
resize_h
,
resize_w
,
ratio
,
ratio
])
right
=
int
(
right
*
ratio
)
bottom
=
int
(
bottom
*
ratio
)
bbox_list_new
.
append
([
left
,
top
,
right
,
bottom
])
return
img_new
,
bbox_list_new
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
if
'cells'
not
in
data
:
cells
=
[]
else
:
cells
=
data
[
'cells'
]
bbox_list
=
self
.
get_img_bbox
(
cells
)
img_new
,
bbox_list_new
=
self
.
resize_img_table
(
img
,
bbox_list
,
self
.
max_len
)
data
[
'image'
]
=
img_new
cell_num
=
len
(
cells
)
bno
=
0
for
cno
in
range
(
cell_num
):
if
"bbox"
in
data
[
'cells'
][
cno
]:
data
[
'cells'
][
cno
][
'bbox'
]
=
bbox_list_new
[
bno
]
bno
+=
1
data
[
'max_len'
]
=
self
.
max_len
data
[
'max_len'
]
=
self
.
max_len
return
data
return
data
class
PaddingTableImage
(
object
):
class
PaddingTableImage
(
object
):
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
size
,
**
kwargs
):
super
(
PaddingTableImage
,
self
).
__init__
()
super
(
PaddingTableImage
,
self
).
__init__
()
self
.
size
=
size
def
__call__
(
self
,
data
):
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
img
=
data
[
'image'
]
max_len
=
data
[
'max_len'
]
pad_h
,
pad_w
=
self
.
size
padding_img
=
np
.
zeros
((
max_len
,
max_len
,
3
),
dtype
=
np
.
float32
)
padding_img
=
np
.
zeros
((
pad_h
,
pad_w
,
3
),
dtype
=
np
.
float32
)
height
,
width
=
img
.
shape
[
0
:
2
]
height
,
width
=
img
.
shape
[
0
:
2
]
padding_img
[
0
:
height
,
0
:
width
,
:]
=
img
.
copy
()
padding_img
[
0
:
height
,
0
:
width
,
:]
=
img
.
copy
()
data
[
'image'
]
=
padding_img
data
[
'image'
]
=
padding_img
shape
=
data
[
'shape'
].
tolist
()
shape
.
extend
([
pad_h
,
pad_w
])
data
[
'shape'
]
=
np
.
array
(
shape
)
return
data
return
data
\ No newline at end of file
ppocr/data/imaug/label_ops.py
浏览文件 @
a0c33908
...
@@ -443,7 +443,9 @@ class KieLabelEncode(object):
...
@@ -443,7 +443,9 @@ class KieLabelEncode(object):
elif
'key_cls'
in
anno
.
keys
():
elif
'key_cls'
in
anno
.
keys
():
labels
.
append
(
anno
[
'key_cls'
])
labels
.
append
(
anno
[
'key_cls'
])
else
:
else
:
raise
ValueError
(
"Cannot found 'key_cls' in ann.keys(), please check your training annotation."
)
raise
ValueError
(
"Cannot found 'key_cls' in ann.keys(), please check your training annotation."
)
edges
.
append
(
ann
.
get
(
'edge'
,
0
))
edges
.
append
(
ann
.
get
(
'edge'
,
0
))
ann_infos
=
dict
(
ann_infos
=
dict
(
image
=
data
[
'image'
],
image
=
data
[
'image'
],
...
@@ -580,171 +582,197 @@ class SRNLabelEncode(BaseRecLabelEncode):
...
@@ -580,171 +582,197 @@ class SRNLabelEncode(BaseRecLabelEncode):
return
idx
return
idx
class
TableLabelEncode
(
object
):
class
TableLabelEncode
(
AttnLabelEncode
):
""" Convert between text-label and text-index """
""" Convert between text-label and text-index """
def
__init__
(
self
,
def
__init__
(
self
,
max_text_length
,
max_text_length
,
max_elem_length
,
max_cell_num
,
character_dict_path
,
character_dict_path
,
span_weight
=
1.0
,
replace_empty_cell_token
=
False
,
merge_no_span_structure
=
False
,
learn_empty_box
=
False
,
point_num
=
4
,
**
kwargs
):
**
kwargs
):
self
.
max_text_length
=
max_text_length
self
.
max_text_len
=
max_text_length
self
.
max_elem_length
=
max_elem_length
self
.
lower
=
False
self
.
max_cell_num
=
max_cell_num
self
.
learn_empty_box
=
learn_empty_box
list_character
,
list_elem
=
self
.
load_char_elem_dict
(
self
.
merge_no_span_structure
=
merge_no_span_structure
character_dict_path
)
self
.
replace_empty_cell_token
=
replace_empty_cell_token
list_character
=
self
.
add_special_char
(
list_character
)
list_elem
=
self
.
add_special_char
(
list_elem
)
dict_character
=
[]
self
.
dict_character
=
{}
for
i
,
char
in
enumerate
(
list_character
):
self
.
dict_character
[
char
]
=
i
self
.
dict_elem
=
{}
for
i
,
elem
in
enumerate
(
list_elem
):
self
.
dict_elem
[
elem
]
=
i
self
.
span_weight
=
span_weight
def
load_char_elem_dict
(
self
,
character_dict_path
):
list_character
=
[]
list_elem
=
[]
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
lines
=
fin
.
readlines
()
substr
=
lines
[
0
].
decode
(
'utf-8'
).
strip
(
"
\r\n
"
).
split
(
"
\t
"
)
for
line
in
lines
:
character_num
=
int
(
substr
[
0
])
line
=
line
.
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
elem_num
=
int
(
substr
[
1
])
dict_character
.
append
(
line
)
for
cno
in
range
(
1
,
1
+
character_num
):
character
=
lines
[
cno
].
decode
(
'utf-8'
).
strip
(
"
\r\n
"
)
list_character
.
append
(
character
)
for
eno
in
range
(
1
+
character_num
,
1
+
character_num
+
elem_num
):
elem
=
lines
[
eno
].
decode
(
'utf-8'
).
strip
(
"
\r\n
"
)
list_elem
.
append
(
elem
)
return
list_character
,
list_elem
def
add_special_char
(
self
,
list_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
list_character
=
[
self
.
beg_str
]
+
list_character
+
[
self
.
end_str
]
return
list_character
def
get_span_idx_list
(
self
):
dict_character
=
self
.
add_special_char
(
dict_character
)
span_idx_list
=
[]
self
.
dict
=
{}
for
elem
in
self
.
dict_elem
:
for
i
,
char
in
enumerate
(
dict_character
):
if
'span'
in
elem
:
self
.
dict
[
char
]
=
i
span_idx_list
.
append
(
self
.
dict_elem
[
elem
])
self
.
idx2char
=
{
v
:
k
for
k
,
v
in
self
.
dict
.
items
()}
return
span_idx_list
self
.
character
=
dict_character
self
.
point_num
=
point_num
self
.
pad_idx
=
self
.
dict
[
self
.
beg_str
]
self
.
start_idx
=
self
.
dict
[
self
.
beg_str
]
self
.
end_idx
=
self
.
dict
[
self
.
end_str
]
self
.
td_token
=
[
'<td>'
,
'<td'
,
'<eb></eb>'
,
'<td></td>'
]
self
.
empty_bbox_token_dict
=
{
"[]"
:
'<eb></eb>'
,
"[' ']"
:
'<eb1></eb1>'
,
"['<b>', ' ', '</b>']"
:
'<eb2></eb2>'
,
"['
\\
u2028', '
\\
u2028']"
:
'<eb3></eb3>'
,
"['<sup>', ' ', '</sup>']"
:
'<eb4></eb4>'
,
"['<b>', '</b>']"
:
'<eb5></eb5>'
,
"['<i>', ' ', '</i>']"
:
'<eb6></eb6>'
,
"['<b>', '<i>', '</i>', '</b>']"
:
'<eb7></eb7>'
,
"['<b>', '<i>', ' ', '</i>', '</b>']"
:
'<eb8></eb8>'
,
"['<i>', '</i>']"
:
'<eb9></eb9>'
,
"['<b>', ' ', '
\\
u2028', ' ', '
\\
u2028', ' ', '</b>']"
:
'<eb10></eb10>'
,
}
@
property
def
_max_text_len
(
self
):
return
self
.
max_text_len
+
2
def
__call__
(
self
,
data
):
def
__call__
(
self
,
data
):
cells
=
data
[
'cells'
]
cells
=
data
[
'cells'
]
structure
=
data
[
'structure'
][
'tokens'
]
structure
=
data
[
'structure'
]
structure
=
self
.
encode
(
structure
,
'elem'
)
if
self
.
merge_no_span_structure
:
structure
=
self
.
_merge_no_span_structure
(
structure
)
if
self
.
replace_empty_cell_token
:
structure
=
self
.
_replace_empty_cell_token
(
structure
,
cells
)
# remove empty token and add " " to span token
new_structure
=
[]
for
token
in
structure
:
if
token
!=
''
:
if
'span'
in
token
and
token
[
0
]
!=
' '
:
token
=
' '
+
token
new_structure
.
append
(
token
)
# encode structure
structure
=
self
.
encode
(
new_structure
)
if
structure
is
None
:
if
structure
is
None
:
return
None
return
None
elem_num
=
len
(
structure
)
structure
=
[
0
]
+
structure
+
[
len
(
self
.
dict_elem
)
-
1
]
structure
=
[
self
.
start_idx
]
+
structure
+
[
self
.
end_idx
structure
=
structure
+
[
0
]
*
(
self
.
max_elem_length
+
2
-
len
(
structure
)
]
# add sos abd eos
)
structure
=
structure
+
[
self
.
pad_idx
]
*
(
self
.
_max_text_len
-
len
(
structure
))
# pad
structure
=
np
.
array
(
structure
)
structure
=
np
.
array
(
structure
)
data
[
'structure'
]
=
structure
data
[
'structure'
]
=
structure
elem_char_idx1
=
self
.
dict_elem
[
'<td>'
]
elem_char_idx2
=
self
.
dict_elem
[
'<td'
]
span_idx_list
=
self
.
get_span_idx_list
()
td_idx_list
=
np
.
logical_or
(
structure
==
elem_char_idx1
,
structure
==
elem_char_idx2
)
td_idx_list
=
np
.
where
(
td_idx_list
)[
0
]
structure_mask
=
np
.
ones
(
(
self
.
max_elem_length
+
2
,
1
),
dtype
=
np
.
float32
)
bbox_list
=
np
.
zeros
((
self
.
max_elem_length
+
2
,
4
),
dtype
=
np
.
float32
)
bbox_list_mask
=
np
.
zeros
(
(
self
.
max_elem_length
+
2
,
1
),
dtype
=
np
.
float32
)
img_height
,
img_width
,
img_ch
=
data
[
'image'
].
shape
if
len
(
span_idx_list
)
>
0
:
span_weight
=
len
(
td_idx_list
)
*
1.0
/
len
(
span_idx_list
)
span_weight
=
min
(
max
(
span_weight
,
1.0
),
self
.
span_weight
)
for
cno
in
range
(
len
(
cells
)):
if
'bbox'
in
cells
[
cno
]:
bbox
=
cells
[
cno
][
'bbox'
].
copy
()
bbox
[
0
]
=
bbox
[
0
]
*
1.0
/
img_width
bbox
[
1
]
=
bbox
[
1
]
*
1.0
/
img_height
bbox
[
2
]
=
bbox
[
2
]
*
1.0
/
img_width
bbox
[
3
]
=
bbox
[
3
]
*
1.0
/
img_height
td_idx
=
td_idx_list
[
cno
]
bbox_list
[
td_idx
]
=
bbox
bbox_list_mask
[
td_idx
]
=
1.0
cand_span_idx
=
td_idx
+
1
if
cand_span_idx
<
(
self
.
max_elem_length
+
2
):
if
structure
[
cand_span_idx
]
in
span_idx_list
:
structure_mask
[
cand_span_idx
]
=
span_weight
data
[
'bbox_list'
]
=
bbox_list
data
[
'bbox_list_mask'
]
=
bbox_list_mask
data
[
'structure_mask'
]
=
structure_mask
char_beg_idx
=
self
.
get_beg_end_flag_idx
(
'beg'
,
'char'
)
char_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'char'
)
elem_beg_idx
=
self
.
get_beg_end_flag_idx
(
'beg'
,
'elem'
)
elem_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'elem'
)
data
[
'sp_tokens'
]
=
np
.
array
([
char_beg_idx
,
char_end_idx
,
elem_beg_idx
,
elem_end_idx
,
elem_char_idx1
,
elem_char_idx2
,
self
.
max_text_length
,
self
.
max_elem_length
,
self
.
max_cell_num
,
elem_num
])
return
data
def
encode
(
self
,
text
,
char_or_elem
):
if
len
(
structure
)
>
self
.
_max_text_len
:
"""convert text-label into text-index.
"""
if
char_or_elem
==
"char"
:
max_len
=
self
.
max_text_length
current_dict
=
self
.
dict_character
else
:
max_len
=
self
.
max_elem_length
current_dict
=
self
.
dict_elem
if
len
(
text
)
>
max_len
:
return
None
if
len
(
text
)
==
0
:
if
char_or_elem
==
"char"
:
return
[
self
.
dict_character
[
'space'
]]
else
:
return
None
text_list
=
[]
for
char
in
text
:
if
char
not
in
current_dict
:
return
None
text_list
.
append
(
current_dict
[
char
])
if
len
(
text_list
)
==
0
:
if
char_or_elem
==
"char"
:
return
[
self
.
dict_character
[
'space'
]]
else
:
return
None
return
None
return
text_list
def
get_ignored_tokens
(
self
,
char_or_elem
):
# encode box
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
,
char_or_elem
)
bboxes
=
np
.
zeros
(
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
,
char_or_elem
)
(
self
.
_max_text_len
,
self
.
point_num
),
dtype
=
np
.
float32
)
return
[
beg_idx
,
end_idx
]
bbox_masks
=
np
.
zeros
((
self
.
_max_text_len
,
1
),
dtype
=
np
.
float32
)
bbox_idx
=
0
for
i
,
token
in
enumerate
(
structure
):
if
self
.
idx2char
[
token
]
in
self
.
td_token
:
if
'bbox'
in
cells
[
bbox_idx
]:
bbox
=
cells
[
bbox_idx
][
'bbox'
].
copy
()
bbox
=
np
.
array
(
bbox
,
dtype
=
np
.
float32
).
reshape
(
-
1
)
bboxes
[
i
]
=
bbox
bbox_masks
[
i
]
=
1.0
if
self
.
learn_empty_box
:
bbox_masks
[
i
]
=
1.0
bbox_idx
+=
1
data
[
'bboxes'
]
=
bboxes
data
[
'bbox_masks'
]
=
bbox_masks
return
data
def
get_beg_end_flag_idx
(
self
,
beg_or_end
,
char_or_elem
):
def
_merge_no_span_structure
(
self
,
structure
):
if
char_or_elem
==
"char"
:
new_structure
=
[]
if
beg_or_end
==
"beg"
:
i
=
0
idx
=
np
.
array
(
self
.
dict_character
[
self
.
beg_str
])
while
i
<
len
(
structure
):
elif
beg_or_end
==
"end"
:
token
=
structure
[
i
]
idx
=
np
.
array
(
self
.
dict_character
[
self
.
end_str
])
if
token
==
'<td>'
:
else
:
token
=
'<td></td>'
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of char"
\
i
+=
1
%
beg_or_end
new_structure
.
append
(
token
)
elif
char_or_elem
==
"elem"
:
i
+=
1
if
beg_or_end
==
"beg"
:
return
new_structure
idx
=
np
.
array
(
self
.
dict_elem
[
self
.
beg_str
])
elif
beg_or_end
==
"end"
:
def
_replace_empty_cell_token
(
self
,
token_list
,
cells
):
idx
=
np
.
array
(
self
.
dict_elem
[
self
.
end_str
])
bbox_idx
=
0
else
:
add_empty_bbox_token_list
=
[]
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of elem"
\
for
token
in
token_list
:
%
beg_or_end
if
token
in
[
'<td></td>'
,
'<td'
,
'<td>'
]:
if
'bbox'
not
in
cells
[
bbox_idx
].
keys
():
content
=
str
(
cells
[
bbox_idx
][
'tokens'
])
token
=
self
.
empty_bbox_token_dict
[
content
]
add_empty_bbox_token_list
.
append
(
token
)
bbox_idx
+=
1
else
:
else
:
assert
False
,
"Unsupport type %s in char_or_elem"
\
add_empty_bbox_token_list
.
append
(
token
)
%
char_or_elem
return
add_empty_bbox_token_list
return
idx
class
TableMasterLabelEncode
(
TableLabelEncode
):
""" Convert between text-label and text-index """
def
__init__
(
self
,
max_text_length
,
character_dict_path
,
replace_empty_cell_token
=
False
,
merge_no_span_structure
=
False
,
learn_empty_box
=
False
,
point_num
=
4
,
**
kwargs
):
super
(
TableMasterLabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
replace_empty_cell_token
,
merge_no_span_structure
,
learn_empty_box
,
point_num
,
**
kwargs
)
@
property
def
_max_text_len
(
self
):
return
self
.
max_text_len
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
'<SOS>'
self
.
end_str
=
'<EOS>'
self
.
unknown_str
=
'<UKN>'
self
.
pad_str
=
'<PAD>'
dict_character
=
dict_character
dict_character
=
dict_character
+
[
self
.
unknown_str
,
self
.
beg_str
,
self
.
end_str
,
self
.
pad_str
]
return
dict_character
class
TableBoxEncode
(
object
):
def
__init__
(
self
,
use_xywh
=
False
,
**
kwargs
):
self
.
use_xywh
=
use_xywh
def
__call__
(
self
,
data
):
img_height
,
img_width
=
data
[
'image'
].
shape
[:
2
]
bboxes
=
data
[
'bboxes'
]
if
self
.
use_xywh
and
bboxes
.
shape
[
1
]
==
4
:
bboxes
=
self
.
xyxy2xywh
(
bboxes
)
bboxes
[:,
0
::
2
]
/=
img_width
bboxes
[:,
1
::
2
]
/=
img_height
data
[
'bboxes'
]
=
bboxes
return
data
def
xyxy2xywh
(
self
,
bboxes
):
"""
Convert coord (x1,y1,x2,y2) to (x,y,w,h).
where (x1,y1) is top-left, (x2,y2) is bottom-right.
(x,y) is bbox center and (w,h) is width and height.
:param bboxes: (x1, y1, x2, y2)
:return:
"""
new_bboxes
=
np
.
empty_like
(
bboxes
)
new_bboxes
[:,
0
]
=
(
bboxes
[:,
0
]
+
bboxes
[:,
2
])
/
2
# x center
new_bboxes
[:,
1
]
=
(
bboxes
[:,
1
]
+
bboxes
[:,
3
])
/
2
# y center
new_bboxes
[:,
2
]
=
bboxes
[:,
2
]
-
bboxes
[:,
0
]
# width
new_bboxes
[:,
3
]
=
bboxes
[:,
3
]
-
bboxes
[:,
1
]
# height
return
new_bboxes
class
SARLabelEncode
(
BaseRecLabelEncode
):
class
SARLabelEncode
(
BaseRecLabelEncode
):
...
@@ -1030,7 +1058,6 @@ class MultiLabelEncode(BaseRecLabelEncode):
...
@@ -1030,7 +1058,6 @@ class MultiLabelEncode(BaseRecLabelEncode):
use_space_char
,
**
kwargs
)
use_space_char
,
**
kwargs
)
def
__call__
(
self
,
data
):
def
__call__
(
self
,
data
):
data_ctc
=
copy
.
deepcopy
(
data
)
data_ctc
=
copy
.
deepcopy
(
data
)
data_sar
=
copy
.
deepcopy
(
data
)
data_sar
=
copy
.
deepcopy
(
data
)
data_out
=
dict
()
data_out
=
dict
()
...
...
ppocr/data/pubtab_dataset.py
浏览文件 @
a0c33908
...
@@ -16,6 +16,7 @@ import os
...
@@ -16,6 +16,7 @@ import os
import
random
import
random
from
paddle.io
import
Dataset
from
paddle.io
import
Dataset
import
json
import
json
from
copy
import
deepcopy
from
.imaug
import
transform
,
create_operators
from
.imaug
import
transform
,
create_operators
...
@@ -29,33 +30,63 @@ class PubTabDataSet(Dataset):
...
@@ -29,33 +30,63 @@ class PubTabDataSet(Dataset):
dataset_config
=
config
[
mode
][
'dataset'
]
dataset_config
=
config
[
mode
][
'dataset'
]
loader_config
=
config
[
mode
][
'loader'
]
loader_config
=
config
[
mode
][
'loader'
]
label_file_path
=
dataset_config
.
pop
(
'label_file_path'
)
label_file_list
=
dataset_config
.
pop
(
'label_file_list'
)
data_source_num
=
len
(
label_file_list
)
ratio_list
=
dataset_config
.
get
(
"ratio_list"
,
[
1.0
])
if
isinstance
(
ratio_list
,
(
float
,
int
)):
ratio_list
=
[
float
(
ratio_list
)]
*
int
(
data_source_num
)
assert
len
(
ratio_list
)
==
data_source_num
,
"The length of ratio_list should be the same as the file_list."
self
.
data_dir
=
dataset_config
[
'data_dir'
]
self
.
data_dir
=
dataset_config
[
'data_dir'
]
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
self
.
do_shuffle
=
loader_config
[
'shuffle'
]
self
.
do_hard_select
=
False
if
'hard_select'
in
loader_config
:
self
.
do_hard_select
=
loader_config
[
'hard_select'
]
self
.
hard_prob
=
loader_config
[
'hard_prob'
]
if
self
.
do_hard_select
:
self
.
img_select_prob
=
self
.
load_hard_select_prob
()
self
.
table_select_type
=
None
if
'table_select_type'
in
loader_config
:
self
.
table_select_type
=
loader_config
[
'table_select_type'
]
self
.
table_select_prob
=
loader_config
[
'table_select_prob'
]
self
.
seed
=
seed
self
.
seed
=
seed
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_path
)
self
.
mode
=
mode
.
lower
()
with
open
(
label_file_path
,
"rb"
)
as
f
:
logger
.
info
(
"Initialize indexs of datasets:%s"
%
label_file_list
)
self
.
data_lines
=
f
.
readlines
()
self
.
data_lines
=
self
.
get_image_info_list
(
label_file_list
,
ratio_list
)
self
.
data_idx_order_list
=
list
(
range
(
len
(
self
.
data_lines
)))
# self.check(config['Global']['max_text_length'])
if
mode
.
lower
()
==
"train"
:
if
mode
.
lower
()
==
"train"
and
self
.
do_shuffle
:
self
.
shuffle_data_random
()
self
.
shuffle_data_random
()
self
.
ops
=
create_operators
(
dataset_config
[
'transforms'
],
global_config
)
self
.
ops
=
create_operators
(
dataset_config
[
'transforms'
],
global_config
)
ratio_list
=
dataset_config
.
get
(
"ratio_list"
,
[
1.0
])
self
.
need_reset
=
True
in
[
x
<
1
for
x
in
ratio_list
]
self
.
need_reset
=
True
in
[
x
<
1
for
x
in
ratio_list
]
def
get_image_info_list
(
self
,
file_list
,
ratio_list
):
if
isinstance
(
file_list
,
str
):
file_list
=
[
file_list
]
data_lines
=
[]
for
idx
,
file
in
enumerate
(
file_list
):
with
open
(
file
,
"rb"
)
as
f
:
lines
=
f
.
readlines
()
if
self
.
mode
==
"train"
or
ratio_list
[
idx
]
<
1.0
:
random
.
seed
(
self
.
seed
)
lines
=
random
.
sample
(
lines
,
round
(
len
(
lines
)
*
ratio_list
[
idx
]))
data_lines
.
extend
(
lines
)
return
data_lines
def
check
(
self
,
max_text_length
):
data_lines
=
[]
for
line
in
self
.
data_lines
:
data_line
=
line
.
decode
(
'utf-8'
).
strip
(
"
\n
"
)
info
=
json
.
loads
(
data_line
)
file_name
=
info
[
'filename'
]
cells
=
info
[
'html'
][
'cells'
].
copy
()
structure
=
info
[
'html'
][
'structure'
][
'tokens'
].
copy
()
img_path
=
os
.
path
.
join
(
self
.
data_dir
,
file_name
)
if
not
os
.
path
.
exists
(
img_path
):
self
.
logger
.
warning
(
"{} does not exist!"
.
format
(
img_path
))
continue
if
len
(
structure
)
==
0
or
len
(
structure
)
>
max_text_length
:
continue
# data = {'img_path': img_path, 'cells': cells, 'structure':structure,'file_name':file_name}
data_lines
.
append
(
line
)
self
.
data_lines
=
data_lines
def
shuffle_data_random
(
self
):
def
shuffle_data_random
(
self
):
if
self
.
do_shuffle
:
if
self
.
do_shuffle
:
random
.
seed
(
self
.
seed
)
random
.
seed
(
self
.
seed
)
...
@@ -68,47 +99,34 @@ class PubTabDataSet(Dataset):
...
@@ -68,47 +99,34 @@ class PubTabDataSet(Dataset):
data_line
=
data_line
.
decode
(
'utf-8'
).
strip
(
"
\n
"
)
data_line
=
data_line
.
decode
(
'utf-8'
).
strip
(
"
\n
"
)
info
=
json
.
loads
(
data_line
)
info
=
json
.
loads
(
data_line
)
file_name
=
info
[
'filename'
]
file_name
=
info
[
'filename'
]
select_flag
=
True
if
self
.
do_hard_select
:
prob
=
self
.
img_select_prob
[
file_name
]
if
prob
<
random
.
uniform
(
0
,
1
):
select_flag
=
False
if
self
.
table_select_type
:
structure
=
info
[
'html'
][
'structure'
][
'tokens'
].
copy
()
structure_str
=
''
.
join
(
structure
)
table_type
=
"simple"
if
'colspan'
in
structure_str
or
'rowspan'
in
structure_str
:
table_type
=
"complex"
if
table_type
==
"complex"
:
if
self
.
table_select_prob
<
random
.
uniform
(
0
,
1
):
select_flag
=
False
if
select_flag
:
cells
=
info
[
'html'
][
'cells'
].
copy
()
cells
=
info
[
'html'
][
'cells'
].
copy
()
structure
=
info
[
'html'
][
'structure'
].
copy
()
structure
=
info
[
'html'
][
'structure'
][
'tokens'
].
copy
()
img_path
=
os
.
path
.
join
(
self
.
data_dir
,
file_name
)
img_path
=
os
.
path
.
join
(
self
.
data_dir
,
file_name
)
if
not
os
.
path
.
exists
(
img_path
):
raise
Exception
(
"{} does not exist!"
.
format
(
img_path
))
data
=
{
data
=
{
'img_path'
:
img_path
,
'img_path'
:
img_path
,
'cells'
:
cells
,
'cells'
:
cells
,
'structure'
:
structure
'structure'
:
structure
,
'file_name'
:
file_name
}
}
if
not
os
.
path
.
exists
(
img_path
):
raise
Exception
(
"{} does not exist!"
.
format
(
img_path
))
with
open
(
data
[
'img_path'
],
'rb'
)
as
f
:
with
open
(
data
[
'img_path'
],
'rb'
)
as
f
:
img
=
f
.
read
()
img
=
f
.
read
()
data
[
'image'
]
=
img
data
[
'image'
]
=
img
outs
=
transform
(
data
,
self
.
ops
)
outs
=
transform
(
data
,
self
.
ops
)
else
:
except
:
outs
=
None
import
traceback
except
Exception
as
e
:
err
=
traceback
.
format_exc
()
self
.
logger
.
error
(
self
.
logger
.
error
(
"When parsing line {}, error happened with msg: {}"
.
format
(
"When parsing line {}, error happened with msg: {}"
.
format
(
err
))
data_line
,
e
))
outs
=
None
outs
=
None
if
outs
is
None
:
if
outs
is
None
:
return
self
.
__getitem__
(
np
.
random
.
randint
(
self
.
__len__
()))
rnd_idx
=
np
.
random
.
randint
(
self
.
__len__
(
))
if
self
.
mode
==
"train"
else
(
idx
+
1
)
%
self
.
__len__
()
return
self
.
__getitem__
(
rnd_idx
)
return
outs
return
outs
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
data_
idx_order_list
)
return
len
(
self
.
data_
lines
)
ppocr/losses/__init__.py
浏览文件 @
a0c33908
...
@@ -51,7 +51,7 @@ from .combined_loss import CombinedLoss
...
@@ -51,7 +51,7 @@ from .combined_loss import CombinedLoss
# table loss
# table loss
from
.table_att_loss
import
TableAttentionLoss
from
.table_att_loss
import
TableAttentionLoss
from
.table_master_loss
import
TableMasterLoss
# vqa token loss
# vqa token loss
from
.vqa_token_layoutlm_loss
import
VQASerTokenLayoutLMLoss
from
.vqa_token_layoutlm_loss
import
VQASerTokenLayoutLMLoss
...
@@ -61,7 +61,8 @@ def build_loss(config):
...
@@ -61,7 +61,8 @@ def build_loss(config):
'DBLoss'
,
'PSELoss'
,
'EASTLoss'
,
'SASTLoss'
,
'FCELoss'
,
'CTCLoss'
,
'DBLoss'
,
'PSELoss'
,
'EASTLoss'
,
'SASTLoss'
,
'FCELoss'
,
'CTCLoss'
,
'ClsLoss'
,
'AttentionLoss'
,
'SRNLoss'
,
'PGLoss'
,
'CombinedLoss'
,
'ClsLoss'
,
'AttentionLoss'
,
'SRNLoss'
,
'PGLoss'
,
'CombinedLoss'
,
'NRTRLoss'
,
'TableAttentionLoss'
,
'SARLoss'
,
'AsterLoss'
,
'SDMGRLoss'
,
'NRTRLoss'
,
'TableAttentionLoss'
,
'SARLoss'
,
'AsterLoss'
,
'SDMGRLoss'
,
'VQASerTokenLayoutLMLoss'
,
'LossFromOutput'
,
'PRENLoss'
,
'MultiLoss'
'VQASerTokenLayoutLMLoss'
,
'LossFromOutput'
,
'PRENLoss'
,
'MultiLoss'
,
'TableMasterLoss'
]
]
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'name'
)
module_name
=
config
.
pop
(
'name'
)
...
...
ppocr/losses/table_master_loss.py
0 → 100644
浏览文件 @
a0c33908
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
from
paddle
import
nn
class
TableMasterLoss
(
nn
.
Layer
):
def
__init__
(
self
,
ignore_index
=-
1
):
super
(
TableMasterLoss
,
self
).
__init__
()
self
.
structure_loss
=
nn
.
CrossEntropyLoss
(
ignore_index
=
ignore_index
,
reduction
=
'mean'
)
self
.
box_loss
=
nn
.
L1Loss
(
reduction
=
'sum'
)
self
.
eps
=
1e-12
def
forward
(
self
,
predicts
,
batch
):
# structure_loss
structure_probs
=
predicts
[
'structure_probs'
]
structure_targets
=
batch
[
1
]
structure_targets
=
structure_targets
[:,
1
:]
structure_probs
=
structure_probs
.
reshape
(
[
-
1
,
structure_probs
.
shape
[
-
1
]])
structure_targets
=
structure_targets
.
reshape
([
-
1
])
structure_loss
=
self
.
structure_loss
(
structure_probs
,
structure_targets
)
structure_loss
=
structure_loss
.
mean
()
losses
=
dict
(
structure_loss
=
structure_loss
)
# box loss
bboxes_preds
=
predicts
[
'loc_preds'
]
bboxes_targets
=
batch
[
2
][:,
1
:,
:]
bbox_masks
=
batch
[
3
][:,
1
:]
# mask empty-bbox or non-bbox structure token's bbox.
masked_bboxes_preds
=
bboxes_preds
*
bbox_masks
masked_bboxes_targets
=
bboxes_targets
*
bbox_masks
# horizon loss (x and width)
horizon_sum_loss
=
self
.
box_loss
(
masked_bboxes_preds
[:,
:,
0
::
2
],
masked_bboxes_targets
[:,
:,
0
::
2
])
horizon_loss
=
horizon_sum_loss
/
(
bbox_masks
.
sum
()
+
self
.
eps
)
# vertical loss (y and height)
vertical_sum_loss
=
self
.
box_loss
(
masked_bboxes_preds
[:,
:,
1
::
2
],
masked_bboxes_targets
[:,
:,
1
::
2
])
vertical_loss
=
vertical_sum_loss
/
(
bbox_masks
.
sum
()
+
self
.
eps
)
horizon_loss
=
horizon_loss
.
mean
()
vertical_loss
=
vertical_loss
.
mean
()
all_loss
=
structure_loss
+
horizon_loss
+
vertical_loss
losses
.
update
({
'loss'
:
all_loss
,
'horizon_bbox_loss'
:
horizon_loss
,
'vertical_bbox_loss'
:
vertical_loss
})
return
losses
ppocr/metrics/table_metric.py
浏览文件 @
a0c33908
...
@@ -12,29 +12,30 @@
...
@@ -12,29 +12,30 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
numpy
as
np
import
numpy
as
np
from
ppocr.metrics.det_metric
import
DetMetric
class
TableMetric
(
object
):
class
Table
Structure
Metric
(
object
):
def
__init__
(
self
,
main_indicator
=
'acc'
,
**
kwargs
):
def
__init__
(
self
,
main_indicator
=
'acc'
,
eps
=
1e-6
,
**
kwargs
):
self
.
main_indicator
=
main_indicator
self
.
main_indicator
=
main_indicator
self
.
eps
=
1e-5
self
.
eps
=
eps
self
.
reset
()
self
.
reset
()
def
__call__
(
self
,
pred
,
batch
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
pred_label
,
batch
=
None
,
*
args
,
**
kwargs
):
structure_probs
=
pred
[
'structure_probs'
].
numpy
()
preds
,
labels
=
pred_label
structure_labels
=
batch
[
1
]
pred_structure_batch_list
=
preds
[
'structure_batch_list'
]
gt_structure_batch_list
=
labels
[
'structure_batch_list'
]
correct_num
=
0
correct_num
=
0
all_num
=
0
all_num
=
0
structure_probs
=
np
.
argmax
(
structure_probs
,
axis
=
2
)
for
(
pred
,
pred_conf
),
target
in
zip
(
pred_structure_batch_list
,
structure_labels
=
structure_labels
[:,
1
:]
gt_structure_batch_list
):
batch_size
=
structure_probs
.
shape
[
0
]
pred_str
=
''
.
join
(
pred
)
for
bno
in
range
(
batch_size
):
target_str
=
''
.
join
(
target
)
all_num
+=
1
if
pred_str
==
target_str
:
if
(
structure_probs
[
bno
]
==
structure_labels
[
bno
]).
all
():
correct_num
+=
1
correct_num
+=
1
all_num
+=
1
self
.
correct_num
+=
correct_num
self
.
correct_num
+=
correct_num
self
.
all_num
+=
all_num
self
.
all_num
+=
all_num
return
{
'acc'
:
correct_num
*
1.0
/
(
all_num
+
self
.
eps
),
}
def
get_metric
(
self
):
def
get_metric
(
self
):
"""
"""
...
@@ -49,3 +50,91 @@ class TableMetric(object):
...
@@ -49,3 +50,91 @@ class TableMetric(object):
def
reset
(
self
):
def
reset
(
self
):
self
.
correct_num
=
0
self
.
correct_num
=
0
self
.
all_num
=
0
self
.
all_num
=
0
self
.
len_acc_num
=
0
self
.
token_nums
=
0
self
.
anys_dict
=
dict
()
from
collections
import
defaultdict
self
.
error_num_dict
=
defaultdict
(
int
)
class
TableMetric
(
object
):
def
__init__
(
self
,
main_indicator
=
'acc'
,
compute_bbox_metric
=
False
,
point_num
=
4
,
**
kwargs
):
"""
@param sub_metrics: configs of sub_metric
@param main_matric: main_matric for save best_model
@param kwargs:
"""
self
.
structure_metric
=
TableStructureMetric
()
self
.
bbox_metric
=
DetMetric
()
if
compute_bbox_metric
else
None
self
.
main_indicator
=
main_indicator
self
.
point_num
=
point_num
self
.
reset
()
def
__call__
(
self
,
pred_label
,
batch
=
None
,
*
args
,
**
kwargs
):
self
.
structure_metric
(
pred_label
)
if
self
.
bbox_metric
is
not
None
:
self
.
bbox_metric
(
*
self
.
prepare_bbox_metric_input
(
pred_label
))
def
prepare_bbox_metric_input
(
self
,
pred_label
):
pred_bbox_batch_list
=
[]
gt_ignore_tags_batch_list
=
[]
gt_bbox_batch_list
=
[]
preds
,
labels
=
pred_label
batch_num
=
len
(
preds
[
'bbox_batch_list'
])
for
batch_idx
in
range
(
batch_num
):
# pred
pred_bbox_list
=
[
self
.
format_box
(
pred_box
)
for
pred_box
in
preds
[
'bbox_batch_list'
][
batch_idx
]
]
pred_bbox_batch_list
.
append
({
'points'
:
pred_bbox_list
})
# gt
gt_bbox_list
=
[]
gt_ignore_tags_list
=
[]
for
gt_box
in
labels
[
'bbox_batch_list'
][
batch_idx
]:
gt_bbox_list
.
append
(
self
.
format_box
(
gt_box
))
gt_ignore_tags_list
.
append
(
0
)
gt_bbox_batch_list
.
append
(
gt_bbox_list
)
gt_ignore_tags_batch_list
.
append
(
gt_ignore_tags_list
)
return
[
pred_bbox_batch_list
,
[
0
,
0
,
gt_bbox_batch_list
,
gt_ignore_tags_batch_list
]
]
def
get_metric
(
self
):
structure_metric
=
self
.
structure_metric
.
get_metric
()
if
self
.
bbox_metric
is
None
:
return
structure_metric
bbox_metric
=
self
.
bbox_metric
.
get_metric
()
if
self
.
main_indicator
==
self
.
bbox_metric
.
main_indicator
:
output
=
bbox_metric
for
sub_key
in
structure_metric
:
output
[
"structure_metric_{}"
.
format
(
sub_key
)]
=
structure_metric
[
sub_key
]
else
:
output
=
structure_metric
for
sub_key
in
bbox_metric
:
output
[
"bbox_metric_{}"
.
format
(
sub_key
)]
=
bbox_metric
[
sub_key
]
return
output
def
reset
(
self
):
self
.
structure_metric
.
reset
()
if
self
.
bbox_metric
is
not
None
:
self
.
bbox_metric
.
reset
()
def
format_box
(
self
,
box
):
if
self
.
point_num
==
4
:
x1
,
y1
,
x2
,
y2
=
box
box
=
[[
x1
,
y1
],
[
x2
,
y1
],
[
x2
,
y2
],
[
x1
,
y2
]]
elif
self
.
point_num
==
8
:
x1
,
y1
,
x2
,
y2
,
x3
,
y3
,
x4
,
y4
=
box
box
=
[[
x1
,
y1
],
[
x2
,
y2
],
[
x3
,
y3
],
[
x4
,
y4
]]
return
box
ppocr/modeling/backbones/__init__.py
浏览文件 @
a0c33908
...
@@ -20,7 +20,10 @@ def build_backbone(config, model_type):
...
@@ -20,7 +20,10 @@ def build_backbone(config, model_type):
from
.det_mobilenet_v3
import
MobileNetV3
from
.det_mobilenet_v3
import
MobileNetV3
from
.det_resnet_vd
import
ResNet
from
.det_resnet_vd
import
ResNet
from
.det_resnet_vd_sast
import
ResNet_SAST
from
.det_resnet_vd_sast
import
ResNet_SAST
support_dict
=
[
"MobileNetV3"
,
"ResNet"
,
"ResNet_SAST"
]
from
.table_master_resnet
import
TableResNetExtra
support_dict
=
[
"MobileNetV3"
,
"ResNet"
,
"ResNet_SAST"
,
"TableResNetExtra"
]
elif
model_type
==
"rec"
or
model_type
==
"cls"
:
elif
model_type
==
"rec"
or
model_type
==
"cls"
:
from
.rec_mobilenet_v3
import
MobileNetV3
from
.rec_mobilenet_v3
import
MobileNetV3
from
.rec_resnet_vd
import
ResNet
from
.rec_resnet_vd
import
ResNet
...
...
ppocr/modeling/backbones/table_master_resnet.py
0 → 100644
浏览文件 @
a0c33908
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
class
BasicBlock
(
nn
.
Layer
):
expansion
=
1
def
__init__
(
self
,
inplanes
,
planes
,
stride
=
1
,
downsample
=
None
,
gcb_config
=
None
):
super
(
BasicBlock
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2D
(
inplanes
,
planes
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
1
,
bias_attr
=
False
)
self
.
bn1
=
nn
.
BatchNorm2D
(
planes
,
momentum
=
0.9
)
self
.
relu
=
nn
.
ReLU
()
self
.
conv2
=
nn
.
Conv2D
(
planes
,
planes
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias_attr
=
False
)
self
.
bn2
=
nn
.
BatchNorm2D
(
planes
,
momentum
=
0.9
)
self
.
downsample
=
downsample
self
.
stride
=
stride
self
.
gcb_config
=
gcb_config
if
self
.
gcb_config
is
not
None
:
gcb_ratio
=
gcb_config
[
'ratio'
]
gcb_headers
=
gcb_config
[
'headers'
]
att_scale
=
gcb_config
[
'att_scale'
]
fusion_type
=
gcb_config
[
'fusion_type'
]
self
.
context_block
=
MultiAspectGCAttention
(
inplanes
=
planes
,
ratio
=
gcb_ratio
,
headers
=
gcb_headers
,
att_scale
=
att_scale
,
fusion_type
=
fusion_type
)
def
forward
(
self
,
x
):
residual
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
if
self
.
gcb_config
is
not
None
:
out
=
self
.
context_block
(
out
)
if
self
.
downsample
is
not
None
:
residual
=
self
.
downsample
(
x
)
out
+=
residual
out
=
self
.
relu
(
out
)
return
out
def
get_gcb_config
(
gcb_config
,
layer
):
if
gcb_config
is
None
or
not
gcb_config
[
'layers'
][
layer
]:
return
None
else
:
return
gcb_config
class
TableResNetExtra
(
nn
.
Layer
):
def
__init__
(
self
,
layers
,
in_channels
=
3
,
gcb_config
=
None
):
assert
len
(
layers
)
>=
4
super
(
TableResNetExtra
,
self
).
__init__
()
self
.
inplanes
=
128
self
.
conv1
=
nn
.
Conv2D
(
in_channels
,
64
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias_attr
=
False
)
self
.
bn1
=
nn
.
BatchNorm2D
(
64
)
self
.
relu1
=
nn
.
ReLU
()
self
.
conv2
=
nn
.
Conv2D
(
64
,
128
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias_attr
=
False
)
self
.
bn2
=
nn
.
BatchNorm2D
(
128
)
self
.
relu2
=
nn
.
ReLU
()
self
.
maxpool1
=
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
)
self
.
layer1
=
self
.
_make_layer
(
BasicBlock
,
256
,
layers
[
0
],
stride
=
1
,
gcb_config
=
get_gcb_config
(
gcb_config
,
0
))
self
.
conv3
=
nn
.
Conv2D
(
256
,
256
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias_attr
=
False
)
self
.
bn3
=
nn
.
BatchNorm2D
(
256
)
self
.
relu3
=
nn
.
ReLU
()
self
.
maxpool2
=
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
)
self
.
layer2
=
self
.
_make_layer
(
BasicBlock
,
256
,
layers
[
1
],
stride
=
1
,
gcb_config
=
get_gcb_config
(
gcb_config
,
1
))
self
.
conv4
=
nn
.
Conv2D
(
256
,
256
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias_attr
=
False
)
self
.
bn4
=
nn
.
BatchNorm2D
(
256
)
self
.
relu4
=
nn
.
ReLU
()
self
.
maxpool3
=
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
)
self
.
layer3
=
self
.
_make_layer
(
BasicBlock
,
512
,
layers
[
2
],
stride
=
1
,
gcb_config
=
get_gcb_config
(
gcb_config
,
2
))
self
.
conv5
=
nn
.
Conv2D
(
512
,
512
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias_attr
=
False
)
self
.
bn5
=
nn
.
BatchNorm2D
(
512
)
self
.
relu5
=
nn
.
ReLU
()
self
.
layer4
=
self
.
_make_layer
(
BasicBlock
,
512
,
layers
[
3
],
stride
=
1
,
gcb_config
=
get_gcb_config
(
gcb_config
,
3
))
self
.
conv6
=
nn
.
Conv2D
(
512
,
512
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias_attr
=
False
)
self
.
bn6
=
nn
.
BatchNorm2D
(
512
)
self
.
relu6
=
nn
.
ReLU
()
self
.
out_channels
=
[
256
,
256
,
512
]
def
_make_layer
(
self
,
block
,
planes
,
blocks
,
stride
=
1
,
gcb_config
=
None
):
downsample
=
None
if
stride
!=
1
or
self
.
inplanes
!=
planes
*
block
.
expansion
:
downsample
=
nn
.
Sequential
(
nn
.
Conv2D
(
self
.
inplanes
,
planes
*
block
.
expansion
,
kernel_size
=
1
,
stride
=
stride
,
bias_attr
=
False
),
nn
.
BatchNorm2D
(
planes
*
block
.
expansion
),
)
layers
=
[]
layers
.
append
(
block
(
self
.
inplanes
,
planes
,
stride
,
downsample
,
gcb_config
=
gcb_config
))
self
.
inplanes
=
planes
*
block
.
expansion
for
_
in
range
(
1
,
blocks
):
layers
.
append
(
block
(
self
.
inplanes
,
planes
))
return
nn
.
Sequential
(
*
layers
)
def
forward
(
self
,
x
):
f
=
[]
x
=
self
.
conv1
(
x
)
# 1,64,480,480
x
=
self
.
bn1
(
x
)
x
=
self
.
relu1
(
x
)
x
=
self
.
conv2
(
x
)
# 1,128,480,480
x
=
self
.
bn2
(
x
)
x
=
self
.
relu2
(
x
)
# (48, 160)
x
=
self
.
maxpool1
(
x
)
# 1,64,240,240
x
=
self
.
layer1
(
x
)
x
=
self
.
conv3
(
x
)
# 1,256,240,240
x
=
self
.
bn3
(
x
)
x
=
self
.
relu3
(
x
)
f
.
append
(
x
)
# (24, 80)
x
=
self
.
maxpool2
(
x
)
# 1,256,120,120
x
=
self
.
layer2
(
x
)
x
=
self
.
conv4
(
x
)
# 1,256,120,120
x
=
self
.
bn4
(
x
)
x
=
self
.
relu4
(
x
)
f
.
append
(
x
)
# (12, 40)
x
=
self
.
maxpool3
(
x
)
# 1,256,60,60
x
=
self
.
layer3
(
x
)
# 1,512,60,60
x
=
self
.
conv5
(
x
)
# 1,512,60,60
x
=
self
.
bn5
(
x
)
x
=
self
.
relu5
(
x
)
x
=
self
.
layer4
(
x
)
# 1,512,60,60
x
=
self
.
conv6
(
x
)
# 1,512,60,60
x
=
self
.
bn6
(
x
)
x
=
self
.
relu6
(
x
)
f
.
append
(
x
)
# (6, 40)
return
f
class
MultiAspectGCAttention
(
nn
.
Layer
):
def
__init__
(
self
,
inplanes
,
ratio
,
headers
,
pooling_type
=
'att'
,
att_scale
=
False
,
fusion_type
=
'channel_add'
):
super
(
MultiAspectGCAttention
,
self
).
__init__
()
assert
pooling_type
in
[
'avg'
,
'att'
]
assert
fusion_type
in
[
'channel_add'
,
'channel_mul'
,
'channel_concat'
]
assert
inplanes
%
headers
==
0
and
inplanes
>=
8
# inplanes must be divided by headers evenly
self
.
headers
=
headers
self
.
inplanes
=
inplanes
self
.
ratio
=
ratio
self
.
planes
=
int
(
inplanes
*
ratio
)
self
.
pooling_type
=
pooling_type
self
.
fusion_type
=
fusion_type
self
.
att_scale
=
False
self
.
single_header_inplanes
=
int
(
inplanes
/
headers
)
if
pooling_type
==
'att'
:
self
.
conv_mask
=
nn
.
Conv2D
(
self
.
single_header_inplanes
,
1
,
kernel_size
=
1
)
self
.
softmax
=
nn
.
Softmax
(
axis
=
2
)
else
:
self
.
avg_pool
=
nn
.
AdaptiveAvgPool2D
(
1
)
if
fusion_type
==
'channel_add'
:
self
.
channel_add_conv
=
nn
.
Sequential
(
nn
.
Conv2D
(
self
.
inplanes
,
self
.
planes
,
kernel_size
=
1
),
nn
.
LayerNorm
([
self
.
planes
,
1
,
1
]),
nn
.
ReLU
(),
nn
.
Conv2D
(
self
.
planes
,
self
.
inplanes
,
kernel_size
=
1
))
elif
fusion_type
==
'channel_concat'
:
self
.
channel_concat_conv
=
nn
.
Sequential
(
nn
.
Conv2D
(
self
.
inplanes
,
self
.
planes
,
kernel_size
=
1
),
nn
.
LayerNorm
([
self
.
planes
,
1
,
1
]),
nn
.
ReLU
(),
nn
.
Conv2D
(
self
.
planes
,
self
.
inplanes
,
kernel_size
=
1
))
# for concat
self
.
cat_conv
=
nn
.
Conv2D
(
2
*
self
.
inplanes
,
self
.
inplanes
,
kernel_size
=
1
)
elif
fusion_type
==
'channel_mul'
:
self
.
channel_mul_conv
=
nn
.
Sequential
(
nn
.
Conv2D
(
self
.
inplanes
,
self
.
planes
,
kernel_size
=
1
),
nn
.
LayerNorm
([
self
.
planes
,
1
,
1
]),
nn
.
ReLU
(),
nn
.
Conv2D
(
self
.
planes
,
self
.
inplanes
,
kernel_size
=
1
))
def
spatial_pool
(
self
,
x
):
batch
,
channel
,
height
,
width
=
x
.
shape
if
self
.
pooling_type
==
'att'
:
# [N*headers, C', H , W] C = headers * C'
x
=
x
.
reshape
([
batch
*
self
.
headers
,
self
.
single_header_inplanes
,
height
,
width
])
input_x
=
x
# [N*headers, C', H * W] C = headers * C'
# input_x = input_x.view(batch, channel, height * width)
input_x
=
input_x
.
reshape
([
batch
*
self
.
headers
,
self
.
single_header_inplanes
,
height
*
width
])
# [N*headers, 1, C', H * W]
input_x
=
input_x
.
unsqueeze
(
1
)
# [N*headers, 1, H, W]
context_mask
=
self
.
conv_mask
(
x
)
# [N*headers, 1, H * W]
context_mask
=
context_mask
.
reshape
(
[
batch
*
self
.
headers
,
1
,
height
*
width
])
# scale variance
if
self
.
att_scale
and
self
.
headers
>
1
:
context_mask
=
context_mask
/
paddle
.
sqrt
(
self
.
single_header_inplanes
)
# [N*headers, 1, H * W]
context_mask
=
self
.
softmax
(
context_mask
)
# [N*headers, 1, H * W, 1]
context_mask
=
context_mask
.
unsqueeze
(
-
1
)
# [N*headers, 1, C', 1] = [N*headers, 1, C', H * W] * [N*headers, 1, H * W, 1]
context
=
paddle
.
matmul
(
input_x
,
context_mask
)
# [N, headers * C', 1, 1]
context
=
context
.
reshape
(
[
batch
,
self
.
headers
*
self
.
single_header_inplanes
,
1
,
1
])
else
:
# [N, C, 1, 1]
context
=
self
.
avg_pool
(
x
)
return
context
def
forward
(
self
,
x
):
# [N, C, 1, 1]
context
=
self
.
spatial_pool
(
x
)
out
=
x
if
self
.
fusion_type
==
'channel_mul'
:
# [N, C, 1, 1]
channel_mul_term
=
F
.
sigmoid
(
self
.
channel_mul_conv
(
context
))
out
=
out
*
channel_mul_term
elif
self
.
fusion_type
==
'channel_add'
:
# [N, C, 1, 1]
channel_add_term
=
self
.
channel_add_conv
(
context
)
out
=
out
+
channel_add_term
else
:
# [N, C, 1, 1]
channel_concat_term
=
self
.
channel_concat_conv
(
context
)
# use concat
_
,
C1
,
_
,
_
=
channel_concat_term
.
shape
N
,
C2
,
H
,
W
=
out
.
shape
out
=
paddle
.
concat
(
[
out
,
channel_concat_term
.
expand
([
-
1
,
-
1
,
H
,
W
])],
axis
=
1
)
out
=
self
.
cat_conv
(
out
)
out
=
F
.
layer_norm
(
out
,
[
self
.
inplanes
,
H
,
W
])
out
=
F
.
relu
(
out
)
return
out
ppocr/modeling/heads/__init__.py
浏览文件 @
a0c33908
...
@@ -41,12 +41,13 @@ def build_head(config):
...
@@ -41,12 +41,13 @@ def build_head(config):
from
.kie_sdmgr_head
import
SDMGRHead
from
.kie_sdmgr_head
import
SDMGRHead
from
.table_att_head
import
TableAttentionHead
from
.table_att_head
import
TableAttentionHead
from
.table_master_head
import
TableMasterHead
support_dict
=
[
support_dict
=
[
'DBHead'
,
'PSEHead'
,
'FCEHead'
,
'EASTHead'
,
'SASTHead'
,
'CTCHead'
,
'DBHead'
,
'PSEHead'
,
'FCEHead'
,
'EASTHead'
,
'SASTHead'
,
'CTCHead'
,
'ClsHead'
,
'AttentionHead'
,
'SRNHead'
,
'PGHead'
,
'Transformer'
,
'ClsHead'
,
'AttentionHead'
,
'SRNHead'
,
'PGHead'
,
'Transformer'
,
'TableAttentionHead'
,
'SARHead'
,
'AsterHead'
,
'SDMGRHead'
,
'PRENHead'
,
'TableAttentionHead'
,
'SARHead'
,
'AsterHead'
,
'SDMGRHead'
,
'PRENHead'
,
'MultiHead'
'MultiHead'
,
'TableMasterHead'
]
]
#table head
#table head
...
...
ppocr/modeling/heads/table_att_head.py
浏览文件 @
a0c33908
...
@@ -21,6 +21,8 @@ import paddle.nn as nn
...
@@ -21,6 +21,8 @@ import paddle.nn as nn
import
paddle.nn.functional
as
F
import
paddle.nn.functional
as
F
import
numpy
as
np
import
numpy
as
np
from
.rec_att_head
import
AttentionGRUCell
class
TableAttentionHead
(
nn
.
Layer
):
class
TableAttentionHead
(
nn
.
Layer
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -28,17 +30,13 @@ class TableAttentionHead(nn.Layer):
...
@@ -28,17 +30,13 @@ class TableAttentionHead(nn.Layer):
hidden_size
,
hidden_size
,
loc_type
,
loc_type
,
in_max_len
=
488
,
in_max_len
=
488
,
max_text_length
=
100
,
max_text_length
=
800
,
max_elem_length
=
800
,
max_cell_num
=
500
,
**
kwargs
):
**
kwargs
):
super
(
TableAttentionHead
,
self
).
__init__
()
super
(
TableAttentionHead
,
self
).
__init__
()
self
.
input_size
=
in_channels
[
-
1
]
self
.
input_size
=
in_channels
[
-
1
]
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
elem_num
=
30
self
.
elem_num
=
30
self
.
max_text_length
=
max_text_length
self
.
max_text_length
=
max_text_length
self
.
max_elem_length
=
max_elem_length
self
.
max_cell_num
=
max_cell_num
self
.
structure_attention_cell
=
AttentionGRUCell
(
self
.
structure_attention_cell
=
AttentionGRUCell
(
self
.
input_size
,
hidden_size
,
self
.
elem_num
,
use_gru
=
False
)
self
.
input_size
,
hidden_size
,
self
.
elem_num
,
use_gru
=
False
)
...
@@ -50,11 +48,11 @@ class TableAttentionHead(nn.Layer):
...
@@ -50,11 +48,11 @@ class TableAttentionHead(nn.Layer):
self
.
loc_generator
=
nn
.
Linear
(
hidden_size
,
4
)
self
.
loc_generator
=
nn
.
Linear
(
hidden_size
,
4
)
else
:
else
:
if
self
.
in_max_len
==
640
:
if
self
.
in_max_len
==
640
:
self
.
loc_fea_trans
=
nn
.
Linear
(
400
,
self
.
max_
elem
_length
+
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
400
,
self
.
max_
text
_length
+
1
)
elif
self
.
in_max_len
==
800
:
elif
self
.
in_max_len
==
800
:
self
.
loc_fea_trans
=
nn
.
Linear
(
625
,
self
.
max_
elem
_length
+
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
625
,
self
.
max_
text
_length
+
1
)
else
:
else
:
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
self
.
max_
elem
_length
+
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
self
.
max_
text
_length
+
1
)
self
.
loc_generator
=
nn
.
Linear
(
self
.
input_size
+
hidden_size
,
4
)
self
.
loc_generator
=
nn
.
Linear
(
self
.
input_size
+
hidden_size
,
4
)
def
_char_to_onehot
(
self
,
input_char
,
onehot_dim
):
def
_char_to_onehot
(
self
,
input_char
,
onehot_dim
):
...
@@ -77,7 +75,7 @@ class TableAttentionHead(nn.Layer):
...
@@ -77,7 +75,7 @@ class TableAttentionHead(nn.Layer):
output_hiddens
=
[]
output_hiddens
=
[]
if
self
.
training
and
targets
is
not
None
:
if
self
.
training
and
targets
is
not
None
:
structure
=
targets
[
0
]
structure
=
targets
[
0
]
for
i
in
range
(
self
.
max_
elem
_length
+
1
):
for
i
in
range
(
self
.
max_
text
_length
+
1
):
elem_onehots
=
self
.
_char_to_onehot
(
elem_onehots
=
self
.
_char_to_onehot
(
structure
[:,
i
],
onehot_dim
=
self
.
elem_num
)
structure
[:,
i
],
onehot_dim
=
self
.
elem_num
)
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
...
@@ -102,9 +100,9 @@ class TableAttentionHead(nn.Layer):
...
@@ -102,9 +100,9 @@ class TableAttentionHead(nn.Layer):
elem_onehots
=
None
elem_onehots
=
None
outputs
=
None
outputs
=
None
alpha
=
None
alpha
=
None
max_
elem_length
=
paddle
.
to_tensor
(
self
.
max_elem
_length
)
max_
text_length
=
paddle
.
to_tensor
(
self
.
max_text
_length
)
i
=
0
i
=
0
while
i
<
max_
elem
_length
+
1
:
while
i
<
max_
text
_length
+
1
:
elem_onehots
=
self
.
_char_to_onehot
(
elem_onehots
=
self
.
_char_to_onehot
(
temp_elem
,
onehot_dim
=
self
.
elem_num
)
temp_elem
,
onehot_dim
=
self
.
elem_num
)
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
...
@@ -128,119 +126,3 @@ class TableAttentionHead(nn.Layer):
...
@@ -128,119 +126,3 @@ class TableAttentionHead(nn.Layer):
loc_preds
=
self
.
loc_generator
(
loc_concat
)
loc_preds
=
self
.
loc_generator
(
loc_concat
)
loc_preds
=
F
.
sigmoid
(
loc_preds
)
loc_preds
=
F
.
sigmoid
(
loc_preds
)
return
{
'structure_probs'
:
structure_probs
,
'loc_preds'
:
loc_preds
}
return
{
'structure_probs'
:
structure_probs
,
'loc_preds'
:
loc_preds
}
class
AttentionGRUCell
(
nn
.
Layer
):
def
__init__
(
self
,
input_size
,
hidden_size
,
num_embeddings
,
use_gru
=
False
):
super
(
AttentionGRUCell
,
self
).
__init__
()
self
.
i2h
=
nn
.
Linear
(
input_size
,
hidden_size
,
bias_attr
=
False
)
self
.
h2h
=
nn
.
Linear
(
hidden_size
,
hidden_size
)
self
.
score
=
nn
.
Linear
(
hidden_size
,
1
,
bias_attr
=
False
)
self
.
rnn
=
nn
.
GRUCell
(
input_size
=
input_size
+
num_embeddings
,
hidden_size
=
hidden_size
)
self
.
hidden_size
=
hidden_size
def
forward
(
self
,
prev_hidden
,
batch_H
,
char_onehots
):
batch_H_proj
=
self
.
i2h
(
batch_H
)
prev_hidden_proj
=
paddle
.
unsqueeze
(
self
.
h2h
(
prev_hidden
),
axis
=
1
)
res
=
paddle
.
add
(
batch_H_proj
,
prev_hidden_proj
)
res
=
paddle
.
tanh
(
res
)
e
=
self
.
score
(
res
)
alpha
=
F
.
softmax
(
e
,
axis
=
1
)
alpha
=
paddle
.
transpose
(
alpha
,
[
0
,
2
,
1
])
context
=
paddle
.
squeeze
(
paddle
.
mm
(
alpha
,
batch_H
),
axis
=
1
)
concat_context
=
paddle
.
concat
([
context
,
char_onehots
],
1
)
cur_hidden
=
self
.
rnn
(
concat_context
,
prev_hidden
)
return
cur_hidden
,
alpha
class
AttentionLSTM
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
hidden_size
,
**
kwargs
):
super
(
AttentionLSTM
,
self
).
__init__
()
self
.
input_size
=
in_channels
self
.
hidden_size
=
hidden_size
self
.
num_classes
=
out_channels
self
.
attention_cell
=
AttentionLSTMCell
(
in_channels
,
hidden_size
,
out_channels
,
use_gru
=
False
)
self
.
generator
=
nn
.
Linear
(
hidden_size
,
out_channels
)
def
_char_to_onehot
(
self
,
input_char
,
onehot_dim
):
input_ont_hot
=
F
.
one_hot
(
input_char
,
onehot_dim
)
return
input_ont_hot
def
forward
(
self
,
inputs
,
targets
=
None
,
batch_max_length
=
25
):
batch_size
=
inputs
.
shape
[
0
]
num_steps
=
batch_max_length
hidden
=
(
paddle
.
zeros
((
batch_size
,
self
.
hidden_size
)),
paddle
.
zeros
(
(
batch_size
,
self
.
hidden_size
)))
output_hiddens
=
[]
if
targets
is
not
None
:
for
i
in
range
(
num_steps
):
# one-hot vectors for a i-th char
char_onehots
=
self
.
_char_to_onehot
(
targets
[:,
i
],
onehot_dim
=
self
.
num_classes
)
hidden
,
alpha
=
self
.
attention_cell
(
hidden
,
inputs
,
char_onehots
)
hidden
=
(
hidden
[
1
][
0
],
hidden
[
1
][
1
])
output_hiddens
.
append
(
paddle
.
unsqueeze
(
hidden
[
0
],
axis
=
1
))
output
=
paddle
.
concat
(
output_hiddens
,
axis
=
1
)
probs
=
self
.
generator
(
output
)
else
:
targets
=
paddle
.
zeros
(
shape
=
[
batch_size
],
dtype
=
"int32"
)
probs
=
None
for
i
in
range
(
num_steps
):
char_onehots
=
self
.
_char_to_onehot
(
targets
,
onehot_dim
=
self
.
num_classes
)
hidden
,
alpha
=
self
.
attention_cell
(
hidden
,
inputs
,
char_onehots
)
probs_step
=
self
.
generator
(
hidden
[
0
])
hidden
=
(
hidden
[
1
][
0
],
hidden
[
1
][
1
])
if
probs
is
None
:
probs
=
paddle
.
unsqueeze
(
probs_step
,
axis
=
1
)
else
:
probs
=
paddle
.
concat
(
[
probs
,
paddle
.
unsqueeze
(
probs_step
,
axis
=
1
)],
axis
=
1
)
next_input
=
probs_step
.
argmax
(
axis
=
1
)
targets
=
next_input
return
probs
class
AttentionLSTMCell
(
nn
.
Layer
):
def
__init__
(
self
,
input_size
,
hidden_size
,
num_embeddings
,
use_gru
=
False
):
super
(
AttentionLSTMCell
,
self
).
__init__
()
self
.
i2h
=
nn
.
Linear
(
input_size
,
hidden_size
,
bias_attr
=
False
)
self
.
h2h
=
nn
.
Linear
(
hidden_size
,
hidden_size
)
self
.
score
=
nn
.
Linear
(
hidden_size
,
1
,
bias_attr
=
False
)
if
not
use_gru
:
self
.
rnn
=
nn
.
LSTMCell
(
input_size
=
input_size
+
num_embeddings
,
hidden_size
=
hidden_size
)
else
:
self
.
rnn
=
nn
.
GRUCell
(
input_size
=
input_size
+
num_embeddings
,
hidden_size
=
hidden_size
)
self
.
hidden_size
=
hidden_size
def
forward
(
self
,
prev_hidden
,
batch_H
,
char_onehots
):
batch_H_proj
=
self
.
i2h
(
batch_H
)
prev_hidden_proj
=
paddle
.
unsqueeze
(
self
.
h2h
(
prev_hidden
[
0
]),
axis
=
1
)
res
=
paddle
.
add
(
batch_H_proj
,
prev_hidden_proj
)
res
=
paddle
.
tanh
(
res
)
e
=
self
.
score
(
res
)
alpha
=
F
.
softmax
(
e
,
axis
=
1
)
alpha
=
paddle
.
transpose
(
alpha
,
[
0
,
2
,
1
])
context
=
paddle
.
squeeze
(
paddle
.
mm
(
alpha
,
batch_H
),
axis
=
1
)
concat_context
=
paddle
.
concat
([
context
,
char_onehots
],
1
)
cur_hidden
=
self
.
rnn
(
concat_context
,
prev_hidden
)
return
cur_hidden
,
alpha
ppocr/modeling/heads/table_master_head.py
0 → 100644
浏览文件 @
a0c33908
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
copy
import
math
import
paddle
from
paddle
import
nn
from
paddle.nn
import
functional
as
F
class
TableMasterHead
(
nn
.
Layer
):
"""
Split to two transformer header at the last layer.
Cls_layer is used to structure token classification.
Bbox_layer is used to regress bbox coord.
"""
def
__init__
(
self
,
in_channels
,
out_channels
=
30
,
headers
=
8
,
d_ff
=
2048
,
dropout
=
0
,
max_text_length
=
500
,
point_num
=
4
,
**
kwargs
):
super
(
TableMasterHead
,
self
).
__init__
()
hidden_size
=
in_channels
[
-
1
]
self
.
layers
=
clones
(
DecoderLayer
(
headers
,
hidden_size
,
dropout
,
d_ff
),
2
)
self
.
cls_layer
=
clones
(
DecoderLayer
(
headers
,
hidden_size
,
dropout
,
d_ff
),
1
)
self
.
bbox_layer
=
clones
(
DecoderLayer
(
headers
,
hidden_size
,
dropout
,
d_ff
),
1
)
self
.
cls_fc
=
nn
.
Linear
(
hidden_size
,
out_channels
)
self
.
bbox_fc
=
nn
.
Sequential
(
# nn.Linear(hidden_size, hidden_size),
nn
.
Linear
(
hidden_size
,
point_num
),
nn
.
Sigmoid
())
self
.
norm
=
nn
.
LayerNorm
(
hidden_size
)
self
.
embedding
=
Embeddings
(
d_model
=
hidden_size
,
vocab
=
out_channels
)
self
.
positional_encoding
=
PositionalEncoding
(
d_model
=
hidden_size
)
self
.
SOS
=
out_channels
-
3
self
.
PAD
=
out_channels
-
1
self
.
out_channels
=
out_channels
self
.
point_num
=
point_num
self
.
max_text_length
=
max_text_length
def
make_mask
(
self
,
tgt
):
"""
Make mask for self attention.
:param src: [b, c, h, l_src]
:param tgt: [b, l_tgt]
:return:
"""
trg_pad_mask
=
(
tgt
!=
self
.
PAD
).
unsqueeze
(
1
).
unsqueeze
(
3
)
tgt_len
=
paddle
.
shape
(
tgt
)[
1
]
trg_sub_mask
=
paddle
.
tril
(
paddle
.
ones
(
([
tgt_len
,
tgt_len
]),
dtype
=
paddle
.
float32
))
tgt_mask
=
paddle
.
logical_and
(
trg_pad_mask
.
astype
(
paddle
.
float32
),
trg_sub_mask
)
return
tgt_mask
.
astype
(
paddle
.
float32
)
def
decode
(
self
,
input
,
feature
,
src_mask
,
tgt_mask
):
# main process of transformer decoder.
x
=
self
.
embedding
(
input
)
# x: 1*x*512, feature: 1*3600,512
x
=
self
.
positional_encoding
(
x
)
# origin transformer layers
for
i
,
layer
in
enumerate
(
self
.
layers
):
x
=
layer
(
x
,
feature
,
src_mask
,
tgt_mask
)
# cls head
for
layer
in
self
.
cls_layer
:
cls_x
=
layer
(
x
,
feature
,
src_mask
,
tgt_mask
)
cls_x
=
self
.
norm
(
cls_x
)
# bbox head
for
layer
in
self
.
bbox_layer
:
bbox_x
=
layer
(
x
,
feature
,
src_mask
,
tgt_mask
)
bbox_x
=
self
.
norm
(
bbox_x
)
return
self
.
cls_fc
(
cls_x
),
self
.
bbox_fc
(
bbox_x
)
def
greedy_forward
(
self
,
SOS
,
feature
):
input
=
SOS
output
=
paddle
.
zeros
(
[
input
.
shape
[
0
],
self
.
max_text_length
+
1
,
self
.
out_channels
])
bbox_output
=
paddle
.
zeros
(
[
input
.
shape
[
0
],
self
.
max_text_length
+
1
,
self
.
point_num
])
max_text_length
=
paddle
.
to_tensor
(
self
.
max_text_length
)
for
i
in
range
(
max_text_length
+
1
):
target_mask
=
self
.
make_mask
(
input
)
out_step
,
bbox_output_step
=
self
.
decode
(
input
,
feature
,
None
,
target_mask
)
prob
=
F
.
softmax
(
out_step
,
axis
=-
1
)
next_word
=
prob
.
argmax
(
axis
=
2
,
dtype
=
"int64"
)
input
=
paddle
.
concat
(
[
input
,
next_word
[:,
-
1
].
unsqueeze
(
-
1
)],
axis
=
1
)
if
i
==
self
.
max_text_length
:
output
=
out_step
bbox_output
=
bbox_output_step
return
output
,
bbox_output
def
forward_train
(
self
,
out_enc
,
targets
):
# x is token of label
# feat is feature after backbone before pe.
# out_enc is feature after pe.
padded_targets
=
targets
[
0
]
src_mask
=
None
tgt_mask
=
self
.
make_mask
(
padded_targets
[:,
:
-
1
])
output
,
bbox_output
=
self
.
decode
(
padded_targets
[:,
:
-
1
],
out_enc
,
src_mask
,
tgt_mask
)
return
{
'structure_probs'
:
output
,
'loc_preds'
:
bbox_output
}
def
forward_test
(
self
,
out_enc
):
batch_size
=
out_enc
.
shape
[
0
]
SOS
=
paddle
.
zeros
([
batch_size
,
1
],
dtype
=
'int64'
)
+
self
.
SOS
output
,
bbox_output
=
self
.
greedy_forward
(
SOS
,
out_enc
)
# output = F.softmax(output)
return
{
'structure_probs'
:
output
,
'loc_preds'
:
bbox_output
}
def
forward
(
self
,
feat
,
targets
=
None
):
feat
=
feat
[
-
1
]
b
,
c
,
h
,
w
=
feat
.
shape
feat
=
feat
.
reshape
([
b
,
c
,
h
*
w
])
# flatten 2D feature map
feat
=
feat
.
transpose
((
0
,
2
,
1
))
out_enc
=
self
.
positional_encoding
(
feat
)
if
self
.
training
:
return
self
.
forward_train
(
out_enc
,
targets
)
return
self
.
forward_test
(
out_enc
)
class
DecoderLayer
(
nn
.
Layer
):
"""
Decoder is made of self attention, srouce attention and feed forward.
"""
def
__init__
(
self
,
headers
,
d_model
,
dropout
,
d_ff
):
super
(
DecoderLayer
,
self
).
__init__
()
self
.
self_attn
=
MultiHeadAttention
(
headers
,
d_model
,
dropout
)
self
.
src_attn
=
MultiHeadAttention
(
headers
,
d_model
,
dropout
)
self
.
feed_forward
=
FeedForward
(
d_model
,
d_ff
,
dropout
)
self
.
sublayer
=
clones
(
SubLayerConnection
(
d_model
,
dropout
),
3
)
def
forward
(
self
,
x
,
feature
,
src_mask
,
tgt_mask
):
x
=
self
.
sublayer
[
0
](
x
,
lambda
x
:
self
.
self_attn
(
x
,
x
,
x
,
tgt_mask
))
x
=
self
.
sublayer
[
1
](
x
,
lambda
x
:
self
.
src_attn
(
x
,
feature
,
feature
,
src_mask
))
return
self
.
sublayer
[
2
](
x
,
self
.
feed_forward
)
class
MultiHeadAttention
(
nn
.
Layer
):
def
__init__
(
self
,
headers
,
d_model
,
dropout
):
super
(
MultiHeadAttention
,
self
).
__init__
()
assert
d_model
%
headers
==
0
self
.
d_k
=
int
(
d_model
/
headers
)
self
.
headers
=
headers
self
.
linears
=
clones
(
nn
.
Linear
(
d_model
,
d_model
),
4
)
self
.
attn
=
None
self
.
dropout
=
nn
.
Dropout
(
dropout
)
def
forward
(
self
,
query
,
key
,
value
,
mask
=
None
):
B
=
query
.
shape
[
0
]
# 1) Do all the linear projections in batch from d_model => h x d_k
query
,
key
,
value
=
\
[
l
(
x
).
reshape
([
B
,
0
,
self
.
headers
,
self
.
d_k
]).
transpose
([
0
,
2
,
1
,
3
])
for
l
,
x
in
zip
(
self
.
linears
,
(
query
,
key
,
value
))]
# 2) Apply attention on all the projected vectors in batch
x
,
self
.
attn
=
self_attention
(
query
,
key
,
value
,
mask
=
mask
,
dropout
=
self
.
dropout
)
x
=
x
.
transpose
([
0
,
2
,
1
,
3
]).
reshape
([
B
,
0
,
self
.
headers
*
self
.
d_k
])
return
self
.
linears
[
-
1
](
x
)
class
FeedForward
(
nn
.
Layer
):
def
__init__
(
self
,
d_model
,
d_ff
,
dropout
):
super
(
FeedForward
,
self
).
__init__
()
self
.
w_1
=
nn
.
Linear
(
d_model
,
d_ff
)
self
.
w_2
=
nn
.
Linear
(
d_ff
,
d_model
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
def
forward
(
self
,
x
):
return
self
.
w_2
(
self
.
dropout
(
F
.
relu
(
self
.
w_1
(
x
))))
class
SubLayerConnection
(
nn
.
Layer
):
"""
A residual connection followed by a layer norm.
Note for code simplicity the norm is first as opposed to last.
"""
def
__init__
(
self
,
size
,
dropout
):
super
(
SubLayerConnection
,
self
).
__init__
()
self
.
norm
=
nn
.
LayerNorm
(
size
)
self
.
dropout
=
nn
.
Dropout
(
dropout
)
def
forward
(
self
,
x
,
sublayer
):
return
x
+
self
.
dropout
(
sublayer
(
self
.
norm
(
x
)))
def
masked_fill
(
x
,
mask
,
value
):
mask
=
mask
.
astype
(
x
.
dtype
)
return
x
*
paddle
.
logical_not
(
mask
).
astype
(
x
.
dtype
)
+
mask
*
value
def
self_attention
(
query
,
key
,
value
,
mask
=
None
,
dropout
=
None
):
"""
Compute 'Scale Dot Product Attention'
"""
d_k
=
value
.
shape
[
-
1
]
score
=
paddle
.
matmul
(
query
,
key
.
transpose
([
0
,
1
,
3
,
2
])
/
math
.
sqrt
(
d_k
))
if
mask
is
not
None
:
# score = score.masked_fill(mask == 0, -1e9) # b, h, L, L
score
=
masked_fill
(
score
,
mask
==
0
,
-
6.55e4
)
# for fp16
p_attn
=
F
.
softmax
(
score
,
axis
=-
1
)
if
dropout
is
not
None
:
p_attn
=
dropout
(
p_attn
)
return
paddle
.
matmul
(
p_attn
,
value
),
p_attn
def
clones
(
module
,
N
):
""" Produce N identical layers """
return
nn
.
LayerList
([
copy
.
deepcopy
(
module
)
for
_
in
range
(
N
)])
class
Embeddings
(
nn
.
Layer
):
def
__init__
(
self
,
d_model
,
vocab
):
super
(
Embeddings
,
self
).
__init__
()
self
.
lut
=
nn
.
Embedding
(
vocab
,
d_model
)
self
.
d_model
=
d_model
def
forward
(
self
,
*
input
):
x
=
input
[
0
]
return
self
.
lut
(
x
)
*
math
.
sqrt
(
self
.
d_model
)
class
PositionalEncoding
(
nn
.
Layer
):
""" Implement the PE function. """
def
__init__
(
self
,
d_model
,
dropout
=
0.
,
max_len
=
5000
):
super
(
PositionalEncoding
,
self
).
__init__
()
self
.
dropout
=
nn
.
Dropout
(
p
=
dropout
)
# Compute the positional encodings once in log space.
pe
=
paddle
.
zeros
([
max_len
,
d_model
])
position
=
paddle
.
arange
(
0
,
max_len
).
unsqueeze
(
1
).
astype
(
'float32'
)
div_term
=
paddle
.
exp
(
paddle
.
arange
(
0
,
d_model
,
2
)
*
-
math
.
log
(
10000.0
)
/
d_model
)
pe
[:,
0
::
2
]
=
paddle
.
sin
(
position
*
div_term
)
pe
[:,
1
::
2
]
=
paddle
.
cos
(
position
*
div_term
)
pe
=
pe
.
unsqueeze
(
0
)
self
.
register_buffer
(
'pe'
,
pe
)
def
forward
(
self
,
feat
,
**
kwargs
):
feat
=
feat
+
self
.
pe
[:,
:
paddle
.
shape
(
feat
)[
1
]]
# pe 1*5000*512
return
self
.
dropout
(
feat
)
ppocr/optimizer/learning_rate.py
浏览文件 @
a0c33908
...
@@ -308,3 +308,46 @@ class Const(object):
...
@@ -308,3 +308,46 @@ class Const(object):
end_lr
=
self
.
learning_rate
,
end_lr
=
self
.
learning_rate
,
last_epoch
=
self
.
last_epoch
)
last_epoch
=
self
.
last_epoch
)
return
learning_rate
return
learning_rate
class
MultiStepDecay
(
object
):
"""
Piecewise learning rate decay
Args:
step_each_epoch(int): steps each epoch
learning_rate (float): The initial learning rate. It is a python float number.
step_size (int): the interval to update.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
def
__init__
(
self
,
learning_rate
,
milestones
,
step_each_epoch
,
gamma
,
warmup_epoch
=
0
,
last_epoch
=-
1
,
**
kwargs
):
super
(
MultiStepDecay
,
self
).
__init__
()
self
.
milestones
=
[
step_each_epoch
*
e
for
e
in
milestones
]
self
.
learning_rate
=
learning_rate
self
.
gamma
=
gamma
self
.
last_epoch
=
last_epoch
self
.
warmup_epoch
=
round
(
warmup_epoch
*
step_each_epoch
)
def
__call__
(
self
):
learning_rate
=
lr
.
MultiStepDecay
(
learning_rate
=
self
.
learning_rate
,
milestones
=
self
.
milestones
,
gamma
=
self
.
gamma
,
last_epoch
=
self
.
last_epoch
)
if
self
.
warmup_epoch
>
0
:
learning_rate
=
lr
.
LinearWarmup
(
learning_rate
=
learning_rate
,
warmup_steps
=
self
.
warmup_epoch
,
start_lr
=
0.0
,
end_lr
=
self
.
learning_rate
,
last_epoch
=
self
.
last_epoch
)
return
learning_rate
\ No newline at end of file
ppocr/postprocess/__init__.py
浏览文件 @
a0c33908
...
@@ -26,8 +26,9 @@ from .east_postprocess import EASTPostProcess
...
@@ -26,8 +26,9 @@ from .east_postprocess import EASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.fce_postprocess
import
FCEPostProcess
from
.fce_postprocess
import
FCEPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
,
\
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
,
\
DistillationCTCLabelDecode
,
TableLabelDecode
,
NRTRLabelDecode
,
SARLabelDecode
,
\
DistillationCTCLabelDecode
,
NRTRLabelDecode
,
SARLabelDecode
,
\
SEEDLabelDecode
,
PRENLabelDecode
SEEDLabelDecode
,
PRENLabelDecode
from
.table_postprocess
import
TableMasterLabelDecode
,
TableLabelDecode
from
.cls_postprocess
import
ClsPostProcess
from
.cls_postprocess
import
ClsPostProcess
from
.pg_postprocess
import
PGPostProcess
from
.pg_postprocess
import
PGPostProcess
from
.vqa_token_ser_layoutlm_postprocess
import
VQASerTokenLayoutLMPostProcess
from
.vqa_token_ser_layoutlm_postprocess
import
VQASerTokenLayoutLMPostProcess
...
@@ -42,7 +43,7 @@ def build_post_process(config, global_config=None):
...
@@ -42,7 +43,7 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'DistillationDBPostProcess'
,
'NRTRLabelDecode'
,
'SARLabelDecode'
,
'SEEDLabelDecode'
,
'VQASerTokenLayoutLMPostProcess'
,
'SEEDLabelDecode'
,
'VQASerTokenLayoutLMPostProcess'
,
'VQAReTokenLayoutLMPostProcess'
,
'PRENLabelDecode'
,
'VQAReTokenLayoutLMPostProcess'
,
'PRENLabelDecode'
,
'DistillationSARLabelDecode'
'DistillationSARLabelDecode'
,
'TableMasterLabelDecode'
]
]
if
config
[
'name'
]
==
'PSEPostProcess'
:
if
config
[
'name'
]
==
'PSEPostProcess'
:
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
a0c33908
...
@@ -444,146 +444,6 @@ class SRNLabelDecode(BaseRecLabelDecode):
...
@@ -444,146 +444,6 @@ class SRNLabelDecode(BaseRecLabelDecode):
return
idx
return
idx
class
TableLabelDecode
(
object
):
""" """
def
__init__
(
self
,
character_dict_path
,
**
kwargs
):
list_character
,
list_elem
=
self
.
load_char_elem_dict
(
character_dict_path
)
list_character
=
self
.
add_special_char
(
list_character
)
list_elem
=
self
.
add_special_char
(
list_elem
)
self
.
dict_character
=
{}
self
.
dict_idx_character
=
{}
for
i
,
char
in
enumerate
(
list_character
):
self
.
dict_idx_character
[
i
]
=
char
self
.
dict_character
[
char
]
=
i
self
.
dict_elem
=
{}
self
.
dict_idx_elem
=
{}
for
i
,
elem
in
enumerate
(
list_elem
):
self
.
dict_idx_elem
[
i
]
=
elem
self
.
dict_elem
[
elem
]
=
i
def
load_char_elem_dict
(
self
,
character_dict_path
):
list_character
=
[]
list_elem
=
[]
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
substr
=
lines
[
0
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
).
split
(
"
\t
"
)
character_num
=
int
(
substr
[
0
])
elem_num
=
int
(
substr
[
1
])
for
cno
in
range
(
1
,
1
+
character_num
):
character
=
lines
[
cno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
list_character
.
append
(
character
)
for
eno
in
range
(
1
+
character_num
,
1
+
character_num
+
elem_num
):
elem
=
lines
[
eno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
list_elem
.
append
(
elem
)
return
list_character
,
list_elem
def
add_special_char
(
self
,
list_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
list_character
=
[
self
.
beg_str
]
+
list_character
+
[
self
.
end_str
]
return
list_character
def
__call__
(
self
,
preds
):
structure_probs
=
preds
[
'structure_probs'
]
loc_preds
=
preds
[
'loc_preds'
]
if
isinstance
(
structure_probs
,
paddle
.
Tensor
):
structure_probs
=
structure_probs
.
numpy
()
if
isinstance
(
loc_preds
,
paddle
.
Tensor
):
loc_preds
=
loc_preds
.
numpy
()
structure_idx
=
structure_probs
.
argmax
(
axis
=
2
)
structure_probs
=
structure_probs
.
max
(
axis
=
2
)
structure_str
,
structure_pos
,
result_score_list
,
result_elem_idx_list
=
self
.
decode
(
structure_idx
,
structure_probs
,
'elem'
)
res_html_code_list
=
[]
res_loc_list
=
[]
batch_num
=
len
(
structure_str
)
for
bno
in
range
(
batch_num
):
res_loc
=
[]
for
sno
in
range
(
len
(
structure_str
[
bno
])):
text
=
structure_str
[
bno
][
sno
]
if
text
in
[
'<td>'
,
'<td'
]:
pos
=
structure_pos
[
bno
][
sno
]
res_loc
.
append
(
loc_preds
[
bno
,
pos
])
res_html_code
=
''
.
join
(
structure_str
[
bno
])
res_loc
=
np
.
array
(
res_loc
)
res_html_code_list
.
append
(
res_html_code
)
res_loc_list
.
append
(
res_loc
)
return
{
'res_html_code'
:
res_html_code_list
,
'res_loc'
:
res_loc_list
,
'res_score_list'
:
result_score_list
,
'res_elem_idx_list'
:
result_elem_idx_list
,
'structure_str_list'
:
structure_str
}
def
decode
(
self
,
text_index
,
structure_probs
,
char_or_elem
):
"""convert text-label into text-index.
"""
if
char_or_elem
==
"char"
:
current_dict
=
self
.
dict_idx_character
else
:
current_dict
=
self
.
dict_idx_elem
ignored_tokens
=
self
.
get_ignored_tokens
(
'elem'
)
beg_idx
,
end_idx
=
ignored_tokens
result_list
=
[]
result_pos_list
=
[]
result_score_list
=
[]
result_elem_idx_list
=
[]
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
elem_pos_list
=
[]
elem_idx_list
=
[]
score_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
tmp_elem_idx
=
int
(
text_index
[
batch_idx
][
idx
])
if
idx
>
0
and
tmp_elem_idx
==
end_idx
:
break
if
tmp_elem_idx
in
ignored_tokens
:
continue
char_list
.
append
(
current_dict
[
tmp_elem_idx
])
elem_pos_list
.
append
(
idx
)
score_list
.
append
(
structure_probs
[
batch_idx
,
idx
])
elem_idx_list
.
append
(
tmp_elem_idx
)
result_list
.
append
(
char_list
)
result_pos_list
.
append
(
elem_pos_list
)
result_score_list
.
append
(
score_list
)
result_elem_idx_list
.
append
(
elem_idx_list
)
return
result_list
,
result_pos_list
,
result_score_list
,
result_elem_idx_list
def
get_ignored_tokens
(
self
,
char_or_elem
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
,
char_or_elem
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
,
char_or_elem
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
,
char_or_elem
):
if
char_or_elem
==
"char"
:
if
beg_or_end
==
"beg"
:
idx
=
self
.
dict_character
[
self
.
beg_str
]
elif
beg_or_end
==
"end"
:
idx
=
self
.
dict_character
[
self
.
end_str
]
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of char"
\
%
beg_or_end
elif
char_or_elem
==
"elem"
:
if
beg_or_end
==
"beg"
:
idx
=
self
.
dict_elem
[
self
.
beg_str
]
elif
beg_or_end
==
"end"
:
idx
=
self
.
dict_elem
[
self
.
end_str
]
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of elem"
\
%
beg_or_end
else
:
assert
False
,
"Unsupport type %s in char_or_elem"
\
%
char_or_elem
return
idx
class
SARLabelDecode
(
BaseRecLabelDecode
):
class
SARLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
""" Convert between text-label and text-index """
...
...
ppocr/postprocess/table_postprocess.py
0 → 100644
浏览文件 @
a0c33908
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
paddle
from
.rec_postprocess
import
AttnLabelDecode
class
TableLabelDecode
(
AttnLabelDecode
):
""" """
def
__init__
(
self
,
character_dict_path
,
**
kwargs
):
super
(
TableLabelDecode
,
self
).
__init__
(
character_dict_path
)
self
.
td_token
=
[
'<td>'
,
'<td'
,
'<eb></eb>'
,
'<td></td>'
]
def
__call__
(
self
,
preds
,
batch
=
None
):
structure_probs
=
preds
[
'structure_probs'
]
bbox_preds
=
preds
[
'loc_preds'
]
if
isinstance
(
structure_probs
,
paddle
.
Tensor
):
structure_probs
=
structure_probs
.
numpy
()
if
isinstance
(
bbox_preds
,
paddle
.
Tensor
):
bbox_preds
=
bbox_preds
.
numpy
()
shape_list
=
batch
[
-
1
]
result
=
self
.
decode
(
structure_probs
,
bbox_preds
,
shape_list
)
if
len
(
batch
)
==
1
:
# only contains shape
return
result
label_decode_result
=
self
.
decode_label
(
batch
)
return
result
,
label_decode_result
def
decode
(
self
,
structure_probs
,
bbox_preds
,
shape_list
):
"""convert text-label into text-index.
"""
ignored_tokens
=
self
.
get_ignored_tokens
()
end_idx
=
self
.
dict
[
self
.
end_str
]
structure_idx
=
structure_probs
.
argmax
(
axis
=
2
)
structure_probs
=
structure_probs
.
max
(
axis
=
2
)
structure_batch_list
=
[]
bbox_batch_list
=
[]
batch_size
=
len
(
structure_idx
)
for
batch_idx
in
range
(
batch_size
):
structure_list
=
[]
bbox_list
=
[]
score_list
=
[]
for
idx
in
range
(
len
(
structure_idx
[
batch_idx
])):
char_idx
=
int
(
structure_idx
[
batch_idx
][
idx
])
if
idx
>
0
and
char_idx
==
end_idx
:
break
if
char_idx
in
ignored_tokens
:
continue
text
=
self
.
character
[
char_idx
]
if
text
in
self
.
td_token
:
bbox
=
bbox_preds
[
batch_idx
,
idx
]
bbox
=
self
.
_bbox_decode
(
bbox
,
shape_list
[
batch_idx
])
bbox_list
.
append
(
bbox
)
structure_list
.
append
(
text
)
score_list
.
append
(
structure_probs
[
batch_idx
,
idx
])
structure_batch_list
.
append
([
structure_list
,
np
.
mean
(
score_list
)])
bbox_batch_list
.
append
(
np
.
array
(
bbox_list
))
result
=
{
'bbox_batch_list'
:
bbox_batch_list
,
'structure_batch_list'
:
structure_batch_list
,
}
return
result
def
decode_label
(
self
,
batch
):
"""convert text-label into text-index.
"""
structure_idx
=
batch
[
1
]
gt_bbox_list
=
batch
[
2
]
shape_list
=
batch
[
-
1
]
ignored_tokens
=
self
.
get_ignored_tokens
()
end_idx
=
self
.
dict
[
self
.
end_str
]
structure_batch_list
=
[]
bbox_batch_list
=
[]
batch_size
=
len
(
structure_idx
)
for
batch_idx
in
range
(
batch_size
):
structure_list
=
[]
bbox_list
=
[]
for
idx
in
range
(
len
(
structure_idx
[
batch_idx
])):
char_idx
=
int
(
structure_idx
[
batch_idx
][
idx
])
if
idx
>
0
and
char_idx
==
end_idx
:
break
if
char_idx
in
ignored_tokens
:
continue
structure_list
.
append
(
self
.
character
[
char_idx
])
bbox
=
gt_bbox_list
[
batch_idx
][
idx
]
if
bbox
.
sum
()
!=
0
:
bbox
=
self
.
_bbox_decode
(
bbox
,
shape_list
[
batch_idx
])
bbox_list
.
append
(
bbox
)
structure_batch_list
.
append
(
structure_list
)
bbox_batch_list
.
append
(
bbox_list
)
result
=
{
'bbox_batch_list'
:
bbox_batch_list
,
'structure_batch_list'
:
structure_batch_list
,
}
return
result
def
_bbox_decode
(
self
,
bbox
,
shape
):
h
,
w
,
ratio_h
,
ratio_w
,
pad_h
,
pad_w
=
shape
src_h
=
h
/
ratio_h
src_w
=
w
/
ratio_w
bbox
[
0
::
2
]
*=
src_w
bbox
[
1
::
2
]
*=
src_h
return
bbox
class
TableMasterLabelDecode
(
TableLabelDecode
):
""" """
def
__init__
(
self
,
character_dict_path
,
box_shape
=
'ori'
,
**
kwargs
):
super
(
TableMasterLabelDecode
,
self
).
__init__
(
character_dict_path
)
self
.
box_shape
=
box_shape
assert
box_shape
in
[
'ori'
,
'pad'
],
'The shape used for box normalization must be ori or pad'
def
add_special_char
(
self
,
dict_character
):
self
.
beg_str
=
'<SOS>'
self
.
end_str
=
'<EOS>'
self
.
unknown_str
=
'<UKN>'
self
.
pad_str
=
'<PAD>'
dict_character
=
dict_character
dict_character
=
dict_character
+
[
self
.
unknown_str
,
self
.
beg_str
,
self
.
end_str
,
self
.
pad_str
]
return
dict_character
def
get_ignored_tokens
(
self
):
pad_idx
=
self
.
dict
[
self
.
pad_str
]
start_idx
=
self
.
dict
[
self
.
beg_str
]
end_idx
=
self
.
dict
[
self
.
end_str
]
unknown_idx
=
self
.
dict
[
self
.
unknown_str
]
return
[
start_idx
,
end_idx
,
pad_idx
,
unknown_idx
]
def
_bbox_decode
(
self
,
bbox
,
shape
):
h
,
w
,
ratio_h
,
ratio_w
,
pad_h
,
pad_w
=
shape
if
self
.
box_shape
==
'pad'
:
h
,
w
=
pad_h
,
pad_w
bbox
[
0
::
2
]
*=
w
bbox
[
1
::
2
]
*=
h
bbox
[
0
::
2
]
/=
ratio_w
bbox
[
1
::
2
]
/=
ratio_h
return
bbox
ppocr/utils/dict/table_master_structure_dict.txt
0 → 100644
浏览文件 @
a0c33908
<thead>
<tr>
<td></td>
</tr>
</thead>
<tbody>
<eb></eb>
</tbody>
<td
colspan="5"
>
</td>
colspan="2"
colspan="3"
<eb2></eb2>
<eb1></eb1>
rowspan="2"
colspan="4"
colspan="6"
rowspan="3"
colspan="9"
colspan="10"
colspan="7"
rowspan="4"
rowspan="5"
rowspan="9"
colspan="8"
rowspan="8"
rowspan="6"
rowspan="7"
rowspan="10"
<eb3></eb3>
<eb4></eb4>
<eb5></eb5>
<eb6></eb6>
<eb7></eb7>
<eb8></eb8>
<eb9></eb9>
<eb10></eb10>
ppocr/utils/dict/table_structure_dict.txt
浏览文件 @
a0c33908
此差异已折叠。
点击以展开。
ppstructure/table/predict_structure.py
浏览文件 @
a0c33908
...
@@ -23,6 +23,7 @@ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
...
@@ -23,6 +23,7 @@ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
import
time
import
time
import
json
import
tools.infer.utility
as
utility
import
tools.infer.utility
as
utility
from
ppocr.data
import
create_operators
,
transform
from
ppocr.data
import
create_operators
,
transform
...
@@ -34,32 +35,50 @@ from ppstructure.utility import parse_args
...
@@ -34,32 +35,50 @@ from ppstructure.utility import parse_args
logger
=
get_logger
()
logger
=
get_logger
()
class
TableStructurer
(
object
):
def
build_pre_process_list
(
args
):
def
__init__
(
self
,
args
):
resize_op
=
{
'ResizeTableImage'
:
{
'max_len'
:
args
.
table_max_len
,
}}
pre_process_list
=
[
{
pad_op
=
{
'Resize
TableImage'
:
{
'Padding
TableImage'
:
{
'max_len'
:
args
.
table_max_len
'size'
:
[
args
.
table_max_len
,
args
.
table_max_len
]
}
}
},
{
}
normalize_op
=
{
'NormalizeImage'
:
{
'NormalizeImage'
:
{
'std'
:
[
0.229
,
0.224
,
0.225
],
'std'
:
[
0.229
,
0.224
,
0.225
]
if
'mean'
:
[
0.485
,
0.456
,
0.406
],
args
.
table_algorithm
not
in
[
'TableMaster'
]
else
[
0.5
,
0.5
,
0.5
],
'mean'
:
[
0.485
,
0.456
,
0.406
]
if
args
.
table_algorithm
not
in
[
'TableMaster'
]
else
[
0.5
,
0.5
,
0.5
],
'scale'
:
'1./255.'
,
'scale'
:
'1./255.'
,
'order'
:
'hwc'
'order'
:
'hwc'
}
}
},
{
'PaddingTableImage'
:
None
},
{
'ToCHWImage'
:
None
},
{
'KeepKeys'
:
{
'keep_keys'
:
[
'image'
]
}
}
}]
to_chw_op
=
{
'ToCHWImage'
:
None
}
keep_keys_op
=
{
'KeepKeys'
:
{
'keep_keys'
:
[
'image'
,
'shape'
]}}
if
args
.
table_algorithm
not
in
[
'TableMaster'
]:
pre_process_list
=
[
resize_op
,
normalize_op
,
pad_op
,
to_chw_op
,
keep_keys_op
]
else
:
pre_process_list
=
[
resize_op
,
pad_op
,
normalize_op
,
to_chw_op
,
keep_keys_op
]
return
pre_process_list
class
TableStructurer
(
object
):
def
__init__
(
self
,
args
):
pre_process_list
=
build_pre_process_list
(
args
)
if
args
.
table_algorithm
not
in
[
'TableMaster'
]:
postprocess_params
=
{
postprocess_params
=
{
'name'
:
'TableLabelDecode'
,
'name'
:
'TableLabelDecode'
,
"character_dict_path"
:
args
.
table_char_dict_path
,
"character_dict_path"
:
args
.
table_char_dict_path
,
}
}
else
:
postprocess_params
=
{
'name'
:
'TableMasterLabelDecode'
,
"character_dict_path"
:
args
.
table_char_dict_path
,
'box_shape'
:
'pad'
}
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
...
@@ -88,27 +107,30 @@ class TableStructurer(object):
...
@@ -88,27 +107,30 @@ class TableStructurer(object):
preds
[
'structure_probs'
]
=
outputs
[
1
]
preds
[
'structure_probs'
]
=
outputs
[
1
]
preds
[
'loc_preds'
]
=
outputs
[
0
]
preds
[
'loc_preds'
]
=
outputs
[
0
]
post_result
=
self
.
postprocess_op
(
preds
)
shape_list
=
np
.
expand_dims
(
data
[
-
1
],
axis
=
0
)
post_result
=
self
.
postprocess_op
(
preds
,
[
shape_list
])
structure_str_list
=
post_result
[
'structure_str_list'
]
res_loc
=
post_result
[
'res_loc'
]
structure_str_list
=
post_result
[
'structure_batch_list'
][
0
]
imgh
,
imgw
=
ori_im
.
shape
[
0
:
2
]
bbox_list
=
post_result
[
'bbox_batch_list'
][
0
]
res_loc_final
=
[]
structure_str_list
=
structure_str_list
[
0
]
for
rno
in
range
(
len
(
res_loc
[
0
])):
x0
,
y0
,
x1
,
y1
=
res_loc
[
0
][
rno
]
left
=
max
(
int
(
imgw
*
x0
),
0
)
top
=
max
(
int
(
imgh
*
y0
),
0
)
right
=
min
(
int
(
imgw
*
x1
),
imgw
-
1
)
bottom
=
min
(
int
(
imgh
*
y1
),
imgh
-
1
)
res_loc_final
.
append
([
left
,
top
,
right
,
bottom
])
structure_str_list
=
structure_str_list
[
0
][:
-
1
]
structure_str_list
=
[
structure_str_list
=
[
'<html>'
,
'<body>'
,
'<table>'
'<html>'
,
'<body>'
,
'<table>'
]
+
structure_str_list
+
[
'</table>'
,
'</body>'
,
'</html>'
]
]
+
structure_str_list
+
[
'</table>'
,
'</body>'
,
'</html>'
]
elapse
=
time
.
time
()
-
starttime
elapse
=
time
.
time
()
-
starttime
return
(
structure_str_list
,
res_loc_final
),
elapse
return
structure_str_list
,
bbox_list
,
elapse
def
draw_rectangle
(
img_path
,
boxes
,
use_xywh
=
False
):
img
=
cv2
.
imread
(
img_path
)
img_show
=
img
.
copy
()
for
box
in
boxes
.
astype
(
int
):
if
use_xywh
:
x
,
y
,
w
,
h
=
box
x1
,
y1
,
x2
,
y2
=
x
-
w
//
2
,
y
-
h
//
2
,
x
+
w
//
2
,
y
+
h
//
2
else
:
x1
,
y1
,
x2
,
y2
=
box
cv2
.
rectangle
(
img_show
,
(
x1
,
y1
),
(
x2
,
y2
),
(
255
,
0
,
0
),
2
)
return
img_show
def
main
(
args
):
def
main
(
args
):
...
@@ -116,6 +138,11 @@ def main(args):
...
@@ -116,6 +138,11 @@ def main(args):
table_structurer
=
TableStructurer
(
args
)
table_structurer
=
TableStructurer
(
args
)
count
=
0
count
=
0
total_time
=
0
total_time
=
0
use_xywh
=
args
.
table_algorithm
in
[
'TableMaster'
]
os
.
makedirs
(
args
.
output
,
exist_ok
=
True
)
with
open
(
os
.
path
.
join
(
args
.
output
,
'infer.txt'
),
mode
=
'w'
,
encoding
=
'utf-8'
)
as
f_w
:
for
image_file
in
image_file_list
:
for
image_file
in
image_file_list
:
img
,
flag
=
check_and_read_gif
(
image_file
)
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
if
not
flag
:
...
@@ -123,10 +150,19 @@ def main(args):
...
@@ -123,10 +150,19 @@ def main(args):
if
img
is
None
:
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
continue
structure_res
,
elapse
=
table_structurer
(
img
)
structure_str_list
,
bbox_list
,
elapse
=
table_structurer
(
img
)
logger
.
info
(
"result: {}"
.
format
(
structure_res
))
bbox_list_str
=
json
.
dumps
(
bbox_list
.
tolist
())
logger
.
info
(
"result: {}, {}"
.
format
(
structure_str_list
,
bbox_list_str
))
f_w
.
write
(
"result: {}, {}
\n
"
.
format
(
structure_str_list
,
bbox_list_str
))
img
=
draw_rectangle
(
image_file
,
bbox_list
,
use_xywh
)
img_save_path
=
os
.
path
.
join
(
args
.
output
,
os
.
path
.
basename
(
image_file
))
cv2
.
imwrite
(
img_save_path
,
img
)
logger
.
info
(
"save vis result to {}"
.
format
(
img_save_path
))
if
count
>
0
:
if
count
>
0
:
total_time
+=
elapse
total_time
+=
elapse
count
+=
1
count
+=
1
...
...
ppstructure/utility.py
浏览文件 @
a0c33908
...
@@ -25,6 +25,7 @@ def init_args():
...
@@ -25,6 +25,7 @@ def init_args():
parser
.
add_argument
(
"--output"
,
type
=
str
,
default
=
'./output'
)
parser
.
add_argument
(
"--output"
,
type
=
str
,
default
=
'./output'
)
# params for table structure
# params for table structure
parser
.
add_argument
(
"--table_max_len"
,
type
=
int
,
default
=
488
)
parser
.
add_argument
(
"--table_max_len"
,
type
=
int
,
default
=
488
)
parser
.
add_argument
(
"--table_algorithm"
,
type
=
str
,
default
=
'TableAttn'
)
parser
.
add_argument
(
"--table_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--table_model_dir"
,
type
=
str
)
parser
.
add_argument
(
parser
.
add_argument
(
"--table_char_dict_path"
,
"--table_char_dict_path"
,
...
...
tools/export_model.py
浏览文件 @
a0c33908
...
@@ -88,6 +88,8 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
...
@@ -88,6 +88,8 @@ def export_single_model(model, arch_config, save_path, logger, quanter=None):
infer_shape
=
[
1
,
32
,
100
]
infer_shape
=
[
1
,
32
,
100
]
elif
arch_config
[
"model_type"
]
==
"table"
:
elif
arch_config
[
"model_type"
]
==
"table"
:
infer_shape
=
[
3
,
488
,
488
]
infer_shape
=
[
3
,
488
,
488
]
if
arch_config
[
"algorithm"
]
==
"TableMaster"
:
infer_shape
=
[
3
,
480
,
480
]
model
=
to_static
(
model
=
to_static
(
model
,
model
,
input_spec
=
[
input_spec
=
[
...
...
tools/infer_table.py
浏览文件 @
a0c33908
...
@@ -40,6 +40,7 @@ import tools.program as program
...
@@ -40,6 +40,7 @@ import tools.program as program
import
cv2
import
cv2
@
paddle
.
no_grad
()
def
main
(
config
,
device
,
logger
,
vdl_writer
):
def
main
(
config
,
device
,
logger
,
vdl_writer
):
global_config
=
config
[
'Global'
]
global_config
=
config
[
'Global'
]
...
@@ -53,27 +54,31 @@ def main(config, device, logger, vdl_writer):
...
@@ -53,27 +54,31 @@ def main(config, device, logger, vdl_writer):
getattr
(
post_process_class
,
'character'
))
getattr
(
post_process_class
,
'character'
))
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
algorithm
=
config
[
'Architecture'
][
'algorithm'
]
use_xywh
=
algorithm
in
[
'TableMaster'
]
load_model
(
config
,
model
)
load_model
(
config
,
model
)
# create data ops
# create data ops
transforms
=
[]
transforms
=
[]
use_padding
=
False
for
op
in
config
[
'Eval'
][
'dataset'
][
'transforms'
]:
for
op
in
config
[
'Eval'
][
'dataset'
][
'transforms'
]:
op_name
=
list
(
op
)[
0
]
op_name
=
list
(
op
)[
0
]
if
'
Label
'
in
op_name
:
if
'
Encode
'
in
op_name
:
continue
continue
if
op_name
==
'KeepKeys'
:
if
op_name
==
'KeepKeys'
:
op
[
op_name
][
'keep_keys'
]
=
[
'image'
]
op
[
op_name
][
'keep_keys'
]
=
[
'image'
,
'shape'
]
if
op_name
==
"ResizeTableImage"
:
use_padding
=
True
padding_max_len
=
op
[
'ResizeTableImage'
][
'max_len'
]
transforms
.
append
(
op
)
transforms
.
append
(
op
)
global_config
[
'infer_mode'
]
=
True
global_config
[
'infer_mode'
]
=
True
ops
=
create_operators
(
transforms
,
global_config
)
ops
=
create_operators
(
transforms
,
global_config
)
save_res_path
=
config
[
'Global'
][
'save_res_path'
]
os
.
makedirs
(
save_res_path
,
exist_ok
=
True
)
model
.
eval
()
model
.
eval
()
with
open
(
os
.
path
.
join
(
save_res_path
,
'infer.txt'
),
mode
=
'w'
,
encoding
=
'utf-8'
)
as
f_w
:
for
file
in
get_image_file_list
(
config
[
'Global'
][
'infer_img'
]):
for
file
in
get_image_file_list
(
config
[
'Global'
][
'infer_img'
]):
logger
.
info
(
"infer_img: {}"
.
format
(
file
))
logger
.
info
(
"infer_img: {}"
.
format
(
file
))
with
open
(
file
,
'rb'
)
as
f
:
with
open
(
file
,
'rb'
)
as
f
:
...
@@ -81,27 +86,44 @@ def main(config, device, logger, vdl_writer):
...
@@ -81,27 +86,44 @@ def main(config, device, logger, vdl_writer):
data
=
{
'image'
:
img
}
data
=
{
'image'
:
img
}
batch
=
transform
(
data
,
ops
)
batch
=
transform
(
data
,
ops
)
images
=
np
.
expand_dims
(
batch
[
0
],
axis
=
0
)
images
=
np
.
expand_dims
(
batch
[
0
],
axis
=
0
)
shape_list
=
np
.
expand_dims
(
batch
[
1
],
axis
=
0
)
images
=
paddle
.
to_tensor
(
images
)
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
)
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
)
post_result
=
post_process_class
(
preds
,
[
shape_list
])
res_html_code
=
post_result
[
'res_html_code'
]
res_loc
=
post_result
[
'res_loc'
]
structure_str_list
=
post_result
[
'structure_batch_list'
][
0
]
img
=
cv2
.
imread
(
file
)
bbox_list
=
post_result
[
'bbox_batch_list'
][
0
]
imgh
,
imgw
=
img
.
shape
[
0
:
2
]
structure_str_list
=
structure_str_list
[
0
]
res_loc_final
=
[]
structure_str_list
=
[
for
rno
in
range
(
len
(
res_loc
[
0
])):
'<html>'
,
'<body>'
,
'<table>'
x0
,
y0
,
x1
,
y1
=
res_loc
[
0
][
rno
]
]
+
structure_str_list
+
[
'</table>'
,
'</body>'
,
'</html>'
]
left
=
max
(
int
(
imgw
*
x0
),
0
)
bbox_list_str
=
json
.
dumps
(
bbox_list
.
tolist
())
top
=
max
(
int
(
imgh
*
y0
),
0
)
right
=
min
(
int
(
imgw
*
x1
),
imgw
-
1
)
logger
.
info
(
"result: {}, {}"
.
format
(
structure_str_list
,
bottom
=
min
(
int
(
imgh
*
y1
),
imgh
-
1
)
bbox_list_str
))
cv2
.
rectangle
(
img
,
(
left
,
top
),
(
right
,
bottom
),
(
0
,
0
,
255
),
2
)
f_w
.
write
(
"result: {}, {}
\n
"
.
format
(
structure_str_list
,
res_loc_final
.
append
([
left
,
top
,
right
,
bottom
])
bbox_list_str
))
res_loc_str
=
json
.
dumps
(
res_loc_final
)
logger
.
info
(
"result: {}, {}"
.
format
(
res_html_code
,
res_loc_final
))
img
=
draw_rectangle
(
file
,
bbox_list
,
use_xywh
)
cv2
.
imwrite
(
os
.
path
.
join
(
save_res_path
,
os
.
path
.
basename
(
file
)),
img
)
logger
.
info
(
"success!"
)
logger
.
info
(
"success!"
)
def
draw_rectangle
(
img_path
,
boxes
,
use_xywh
=
False
):
img
=
cv2
.
imread
(
img_path
)
img_show
=
img
.
copy
()
for
box
in
boxes
.
astype
(
int
):
if
use_xywh
:
x
,
y
,
w
,
h
=
box
x1
,
y1
,
x2
,
y2
=
x
-
w
//
2
,
y
-
h
//
2
,
x
+
w
//
2
,
y
+
h
//
2
else
:
x1
,
y1
,
x2
,
y2
=
box
cv2
.
rectangle
(
img_show
,
(
x1
,
y1
),
(
x2
,
y2
),
(
255
,
0
,
0
),
2
)
return
img_show
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
main
(
config
,
device
,
logger
,
vdl_writer
)
main
(
config
,
device
,
logger
,
vdl_writer
)
tools/program.py
浏览文件 @
a0c33908
...
@@ -274,8 +274,11 @@ def train(config,
...
@@ -274,8 +274,11 @@ def train(config,
if
cal_metric_during_train
and
epoch
%
calc_epoch_interval
==
0
:
# only rec and cls need
if
cal_metric_during_train
and
epoch
%
calc_epoch_interval
==
0
:
# only rec and cls need
batch
=
[
item
.
numpy
()
for
item
in
batch
]
batch
=
[
item
.
numpy
()
for
item
in
batch
]
if
model_type
in
[
'
table'
,
'
kie'
]:
if
model_type
in
[
'kie'
]:
eval_class
(
preds
,
batch
)
eval_class
(
preds
,
batch
)
elif
model_type
in
[
'table'
]:
post_result
=
post_process_class
(
preds
,
batch
)
eval_class
(
post_result
,
batch
)
else
:
else
:
if
config
[
'Loss'
][
'name'
]
in
[
'MultiLoss'
,
'MultiLoss_v2'
if
config
[
'Loss'
][
'name'
]
in
[
'MultiLoss'
,
'MultiLoss_v2'
]:
# for multi head loss
]:
# for multi head loss
...
@@ -302,7 +305,8 @@ def train(config,
...
@@ -302,7 +305,8 @@ def train(config,
train_stats
.
update
(
stats
)
train_stats
.
update
(
stats
)
if
log_writer
is
not
None
and
dist
.
get_rank
()
==
0
:
if
log_writer
is
not
None
and
dist
.
get_rank
()
==
0
:
log_writer
.
log_metrics
(
metrics
=
train_stats
.
get
(),
prefix
=
"TRAIN"
,
step
=
global_step
)
log_writer
.
log_metrics
(
metrics
=
train_stats
.
get
(),
prefix
=
"TRAIN"
,
step
=
global_step
)
if
dist
.
get_rank
()
==
0
and
(
if
dist
.
get_rank
()
==
0
and
(
(
global_step
>
0
and
global_step
%
print_batch_step
==
0
)
or
(
global_step
>
0
and
global_step
%
print_batch_step
==
0
)
or
...
@@ -349,7 +353,8 @@ def train(config,
...
@@ -349,7 +353,8 @@ def train(config,
# logger metric
# logger metric
if
log_writer
is
not
None
:
if
log_writer
is
not
None
:
log_writer
.
log_metrics
(
metrics
=
cur_metric
,
prefix
=
"EVAL"
,
step
=
global_step
)
log_writer
.
log_metrics
(
metrics
=
cur_metric
,
prefix
=
"EVAL"
,
step
=
global_step
)
if
cur_metric
[
main_indicator
]
>=
best_model_dict
[
if
cur_metric
[
main_indicator
]
>=
best_model_dict
[
main_indicator
]:
main_indicator
]:
...
@@ -372,11 +377,18 @@ def train(config,
...
@@ -372,11 +377,18 @@ def train(config,
logger
.
info
(
best_str
)
logger
.
info
(
best_str
)
# logger best metric
# logger best metric
if
log_writer
is
not
None
:
if
log_writer
is
not
None
:
log_writer
.
log_metrics
(
metrics
=
{
log_writer
.
log_metrics
(
"best_{}"
.
format
(
main_indicator
):
best_model_dict
[
main_indicator
]
metrics
=
{
},
prefix
=
"EVAL"
,
step
=
global_step
)
"best_{}"
.
format
(
main_indicator
):
best_model_dict
[
main_indicator
]
log_writer
.
log_model
(
is_best
=
True
,
prefix
=
"best_accuracy"
,
metadata
=
best_model_dict
)
},
prefix
=
"EVAL"
,
step
=
global_step
)
log_writer
.
log_model
(
is_best
=
True
,
prefix
=
"best_accuracy"
,
metadata
=
best_model_dict
)
reader_start
=
time
.
time
()
reader_start
=
time
.
time
()
if
dist
.
get_rank
()
==
0
:
if
dist
.
get_rank
()
==
0
:
...
@@ -408,7 +420,8 @@ def train(config,
...
@@ -408,7 +420,8 @@ def train(config,
epoch
=
epoch
,
epoch
=
epoch
,
global_step
=
global_step
)
global_step
=
global_step
)
if
log_writer
is
not
None
:
if
log_writer
is
not
None
:
log_writer
.
log_model
(
is_best
=
False
,
prefix
=
'iter_epoch_{}'
.
format
(
epoch
))
log_writer
.
log_model
(
is_best
=
False
,
prefix
=
'iter_epoch_{}'
.
format
(
epoch
))
best_str
=
'best metric, {}'
.
format
(
', '
.
join
(
best_str
=
'best metric, {}'
.
format
(
', '
.
join
(
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
best_model_dict
.
items
()]))
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
best_model_dict
.
items
()]))
...
@@ -446,7 +459,6 @@ def eval(model,
...
@@ -446,7 +459,6 @@ def eval(model,
preds
=
model
(
batch
)
preds
=
model
(
batch
)
else
:
else
:
preds
=
model
(
images
)
preds
=
model
(
images
)
batch_numpy
=
[]
batch_numpy
=
[]
for
item
in
batch
:
for
item
in
batch
:
if
isinstance
(
item
,
paddle
.
Tensor
):
if
isinstance
(
item
,
paddle
.
Tensor
):
...
@@ -456,9 +468,9 @@ def eval(model,
...
@@ -456,9 +468,9 @@ def eval(model,
# Obtain usable results from post-processing methods
# Obtain usable results from post-processing methods
total_time
+=
time
.
time
()
-
start
total_time
+=
time
.
time
()
-
start
# Evaluate the results of the current batch
# Evaluate the results of the current batch
if
model_type
in
[
'
table'
,
'
kie'
]:
if
model_type
in
[
'kie'
]:
eval_class
(
preds
,
batch_numpy
)
eval_class
(
preds
,
batch_numpy
)
elif
model_type
in
[
'vqa'
]:
elif
model_type
in
[
'
table'
,
'
vqa'
]:
post_result
=
post_process_class
(
preds
,
batch_numpy
)
post_result
=
post_process_class
(
preds
,
batch_numpy
)
eval_class
(
post_result
,
batch_numpy
)
eval_class
(
post_result
,
batch_numpy
)
else
:
else
:
...
@@ -559,7 +571,8 @@ def preprocess(is_train=False):
...
@@ -559,7 +571,8 @@ def preprocess(is_train=False):
assert
alg
in
[
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
,
'CLS'
,
'PGNet'
,
'Distillation'
,
'NRTR'
,
'TableAttn'
,
'SAR'
,
'PSE'
,
'SEED'
,
'SDMGR'
,
'LayoutXLM'
,
'LayoutLM'
,
'PREN'
,
'FCE'
,
'SVTR'
'SEED'
,
'SDMGR'
,
'LayoutXLM'
,
'LayoutLM'
,
'PREN'
,
'FCE'
,
'SVTR'
,
'TableMaster'
]
]
device
=
'cpu'
device
=
'cpu'
...
@@ -578,7 +591,8 @@ def preprocess(is_train=False):
...
@@ -578,7 +591,8 @@ def preprocess(is_train=False):
vdl_writer_path
=
'{}/vdl/'
.
format
(
save_model_dir
)
vdl_writer_path
=
'{}/vdl/'
.
format
(
save_model_dir
)
log_writer
=
VDLLogger
(
save_model_dir
)
log_writer
=
VDLLogger
(
save_model_dir
)
loggers
.
append
(
log_writer
)
loggers
.
append
(
log_writer
)
if
(
'use_wandb'
in
config
[
'Global'
]
and
config
[
'Global'
][
'use_wandb'
])
or
'wandb'
in
config
:
if
(
'use_wandb'
in
config
[
'Global'
]
and
config
[
'Global'
][
'use_wandb'
])
or
'wandb'
in
config
:
save_dir
=
config
[
'Global'
][
'save_model_dir'
]
save_dir
=
config
[
'Global'
][
'save_model_dir'
]
wandb_writer_path
=
"{}/wandb"
.
format
(
save_dir
)
wandb_writer_path
=
"{}/wandb"
.
format
(
save_dir
)
if
"wandb"
in
config
:
if
"wandb"
in
config
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录