Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
16c247ac
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
16c247ac
编写于
6月 21, 2021
作者:
M
MissPenguin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine
上级
7c8b2c8d
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
40 addition
and
101 deletion
+40
-101
configs/table/table_mv3.yml
configs/table/table_mv3.yml
+12
-12
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+0
-20
ppocr/data/pubtab_dataset.py
ppocr/data/pubtab_dataset.py
+2
-20
ppocr/modeling/architectures/base_model.py
ppocr/modeling/architectures/base_model.py
+1
-1
ppocr/modeling/heads/table_att_head.py
ppocr/modeling/heads/table_att_head.py
+10
-12
ppocr/modeling/necks/table_fpn.py
ppocr/modeling/necks/table_fpn.py
+10
-19
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+0
-12
tools/export_model.py
tools/export_model.py
+2
-1
tools/infer_table.py
tools/infer_table.py
+1
-3
tools/program.py
tools/program.py
+2
-1
未找到文件。
configs/table/table_mv3.yml
浏览文件 @
16c247ac
Global
:
use_gpu
:
true
epoch_num
:
4
0
epoch_num
:
5
0
log_smooth_window
:
20
print_batch_step
:
5
save_model_dir
:
./output/table_mv3/
save_epoch_step
:
3
# evaluation is run every
5000 iterations after the 400
0th iteration
save_epoch_step
:
5
# evaluation is run every
400 iterations after the
0th iteration
eval_batch_step
:
[
0
,
400
]
# if pretrained_model is saved in static mode, load_static_weights must set to True
cal_metric_during_train
:
True
pretrained_model
:
checkpoints
:
...
...
@@ -18,19 +17,20 @@ Global:
character_dict_path
:
ppocr/utils/dict/table_structure_dict.txt
character_type
:
en
max_text_length
:
100
max_elem_length
:
8
00
max_elem_length
:
5
00
max_cell_num
:
500
infer_mode
:
False
process_total_num
:
0
process_cut_num
:
0
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
clip_norm
:
5.0
lr
:
learning_rate
:
0.00
0
1
learning_rate
:
0.001
regularizer
:
name
:
'
L2'
factor
:
0.00000
...
...
@@ -41,12 +41,12 @@ Architecture:
Backbone
:
name
:
MobileNetV3
scale
:
1.0
model_name
:
large
model_name
:
small
disable_se
:
True
Head
:
name
:
TableAttentionHead
# AttentionHead
hidden_size
:
256
#
name
:
TableAttentionHead
hidden_size
:
256
l2_decay
:
0.00001
# loc_type: 1
loc_type
:
2
Loss
:
...
...
@@ -86,7 +86,7 @@ Train:
shuffle
:
True
batch_size_per_card
:
32
drop_last
:
True
num_workers
:
4
num_workers
:
1
Eval
:
dataset
:
...
...
@@ -113,4 +113,4 @@ Eval:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
16
num_workers
:
4
num_workers
:
1
ppocr/data/imaug/label_ops.py
浏览文件 @
16c247ac
...
...
@@ -412,7 +412,6 @@ class TableLabelEncode(object):
return
None
elem_num
=
len
(
structure
)
structure
=
[
0
]
+
structure
+
[
len
(
self
.
dict_elem
)
-
1
]
# structure = [0] + structure + [0]
structure
=
structure
+
[
0
]
*
(
self
.
max_elem_length
+
2
-
len
(
structure
))
structure
=
np
.
array
(
structure
)
data
[
'structure'
]
=
structure
...
...
@@ -443,8 +442,6 @@ class TableLabelEncode(object):
if
cand_span_idx
<
(
self
.
max_elem_length
+
2
):
if
structure
[
cand_span_idx
]
in
span_idx_list
:
structure_mask
[
cand_span_idx
]
=
span_weight
# structure_mask[td_idx] = self.span_weight
# structure_mask[cand_span_idx] = self.span_weight
data
[
'bbox_list'
]
=
bbox_list
data
[
'bbox_list_mask'
]
=
bbox_list_mask
...
...
@@ -458,23 +455,6 @@ class TableLabelEncode(object):
self
.
max_elem_length
,
self
.
max_cell_num
,
elem_num
])
return
data
########
# for char decode
# cell_list = []
# for cell in cells:
# char_list = cell['tokens']
# cell = self.encode(char_list, 'char')
# if cell is None:
# return None
# cell = [0] + cell + [len(self.dict_character) - 1]
# cell = cell + [0] * (self.max_text_length + 2 - len(cell))
# cell_list.append(cell)
# cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2))
# cell_list = np.array(cell_list)
# cell_list_padding[0:cell_list.shape[0]] = cell_list
# data['cells'] = cell_list_padding
# return data
def
encode
(
self
,
text
,
char_or_elem
):
"""convert text-label into text-index.
"""
...
...
ppocr/data/pubtab_dataset.py
浏览文件 @
16c247ac
# copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -19,6 +19,7 @@ import json
from
.imaug
import
transform
,
create_operators
class
PubTabDataSet
(
Dataset
):
def
__init__
(
self
,
config
,
mode
,
logger
,
seed
=
None
):
super
(
PubTabDataSet
,
self
).
__init__
()
...
...
@@ -57,23 +58,6 @@ class PubTabDataSet(Dataset):
random
.
seed
(
self
.
seed
)
random
.
shuffle
(
self
.
data_lines
)
return
def
load_hard_select_prob
(
self
):
label_path
=
"./pretrained_model/teds_score_exp5_st2_train.txt"
img_select_prob
=
{}
with
open
(
label_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
for
lno
in
range
(
len
(
lines
)):
substr
=
lines
[
lno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
split
(
" "
)
img_name
=
substr
[
0
].
strip
(
":"
)
score
=
float
(
substr
[
1
])
if
score
<=
0.8
:
img_select_prob
[
img_name
]
=
self
.
hard_prob
[
0
]
elif
score
<=
0.98
:
img_select_prob
[
img_name
]
=
self
.
hard_prob
[
1
]
else
:
img_select_prob
[
img_name
]
=
self
.
hard_prob
[
2
]
return
img_select_prob
def
__getitem__
(
self
,
idx
):
try
:
...
...
@@ -93,8 +77,6 @@ class PubTabDataSet(Dataset):
table_type
=
"simple"
if
'colspan'
in
structure_str
or
'rowspan'
in
structure_str
:
table_type
=
"complex"
# if self.table_select_type != table_type:
# select_flag = False
if
table_type
==
"complex"
:
if
self
.
table_select_prob
<
random
.
uniform
(
0
,
1
):
select_flag
=
False
...
...
ppocr/modeling/architectures/base_model.py
浏览文件 @
16c247ac
# Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
ppocr/modeling/heads/table_att_head.py
浏览文件 @
16c247ac
...
...
@@ -21,13 +21,16 @@ import paddle.nn as nn
import
paddle.nn.functional
as
F
import
numpy
as
np
class
TableAttentionHead
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
hidden_size
,
loc_type
,
in_max_len
=
488
,
**
kwargs
):
super
(
TableAttentionHead
,
self
).
__init__
()
self
.
input_size
=
in_channels
[
-
1
]
self
.
hidden_size
=
hidden_size
self
.
char_num
=
280
self
.
elem_num
=
30
self
.
max_text_length
=
100
self
.
max_elem_length
=
500
self
.
max_cell_num
=
500
self
.
structure_attention_cell
=
AttentionGRUCell
(
self
.
input_size
,
hidden_size
,
self
.
elem_num
,
use_gru
=
False
)
...
...
@@ -39,11 +42,11 @@ class TableAttentionHead(nn.Layer):
self
.
loc_generator
=
nn
.
Linear
(
hidden_size
,
4
)
else
:
if
self
.
in_max_len
==
640
:
self
.
loc_fea_trans
=
nn
.
Linear
(
400
,
80
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
400
,
self
.
max_elem_length
+
1
)
elif
self
.
in_max_len
==
800
:
self
.
loc_fea_trans
=
nn
.
Linear
(
625
,
80
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
625
,
self
.
max_elem_length
+
1
)
else
:
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
80
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
self
.
max_elem_length
+
1
)
self
.
loc_generator
=
nn
.
Linear
(
self
.
input_size
+
hidden_size
,
4
)
def
_char_to_onehot
(
self
,
input_char
,
onehot_dim
):
...
...
@@ -61,18 +64,12 @@ class TableAttentionHead(nn.Layer):
fea
=
paddle
.
reshape
(
fea
,
[
fea
.
shape
[
0
],
fea
.
shape
[
1
],
last_shape
])
fea
=
fea
.
transpose
([
0
,
2
,
1
])
# (NTC)(batch, width, channels)
batch_size
=
fea
.
shape
[
0
]
#sp_tokens = targets[2].numpy()
#char_beg_idx, char_end_idx = sp_tokens[0, 0:2]
#elem_beg_idx, elem_end_idx = sp_tokens[0, 2:4]
#elem_char_idx1, elem_char_idx2 = sp_tokens[0, 4:6]
#max_text_length, max_elem_length, max_cell_num = sp_tokens[0, 6:9]
max_text_length
,
max_elem_length
,
max_cell_num
=
100
,
800
,
500
hidden
=
paddle
.
zeros
((
batch_size
,
self
.
hidden_size
))
output_hiddens
=
[]
if
mode
==
'Train'
and
targets
is
not
None
:
structure
=
targets
[
0
]
for
i
in
range
(
max_elem_length
+
1
):
for
i
in
range
(
self
.
max_elem_length
+
1
):
elem_onehots
=
self
.
_char_to_onehot
(
structure
[:,
i
],
onehot_dim
=
self
.
elem_num
)
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
...
...
@@ -97,7 +94,7 @@ class TableAttentionHead(nn.Layer):
elem_onehots
=
None
outputs
=
None
alpha
=
None
max_elem_length
=
paddle
.
to_tensor
(
max_elem_length
)
max_elem_length
=
paddle
.
to_tensor
(
self
.
max_elem_length
)
i
=
0
while
i
<
max_elem_length
+
1
:
elem_onehots
=
self
.
_char_to_onehot
(
...
...
@@ -124,6 +121,7 @@ class TableAttentionHead(nn.Layer):
loc_preds
=
F
.
sigmoid
(
loc_preds
)
return
{
'structure_probs'
:
structure_probs
,
'loc_preds'
:
loc_preds
}
class
AttentionGRUCell
(
nn
.
Layer
):
def
__init__
(
self
,
input_size
,
hidden_size
,
num_embeddings
,
use_gru
=
False
):
super
(
AttentionGRUCell
,
self
).
__init__
()
...
...
ppocr/modeling/necks/table_fpn.py
浏览文件 @
16c247ac
# copyright (c) 20
19
PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 20
21
PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -31,70 +31,61 @@ class TableFPN(nn.Layer):
in_channels
=
in_channels
[
0
],
out_channels
=
self
.
out_channels
,
kernel_size
=
1
,
weight_attr
=
ParamAttr
(
name
=
'conv2d_51.w_0'
,
initializer
=
weight_attr
),
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
)
self
.
in3_conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
[
1
],
out_channels
=
self
.
out_channels
,
kernel_size
=
1
,
stride
=
1
,
weight_attr
=
ParamAttr
(
name
=
'conv2d_50.w_0'
,
initializer
=
weight_attr
),
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
)
self
.
in4_conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
[
2
],
out_channels
=
self
.
out_channels
,
kernel_size
=
1
,
weight_attr
=
ParamAttr
(
name
=
'conv2d_49.w_0'
,
initializer
=
weight_attr
),
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
)
self
.
in5_conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
[
3
],
out_channels
=
self
.
out_channels
,
kernel_size
=
1
,
weight_attr
=
ParamAttr
(
name
=
'conv2d_48.w_0'
,
initializer
=
weight_attr
),
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
)
self
.
p5_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
3
,
padding
=
1
,
weight_attr
=
ParamAttr
(
name
=
'conv2d_52.w_0'
,
initializer
=
weight_attr
),
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
)
self
.
p4_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
3
,
padding
=
1
,
weight_attr
=
ParamAttr
(
name
=
'conv2d_53.w_0'
,
initializer
=
weight_attr
),
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
)
self
.
p3_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
3
,
padding
=
1
,
weight_attr
=
ParamAttr
(
name
=
'conv2d_54.w_0'
,
initializer
=
weight_attr
),
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
)
self
.
p2_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
3
,
padding
=
1
,
weight_attr
=
ParamAttr
(
name
=
'conv2d_55.w_0'
,
initializer
=
weight_attr
),
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
)
self
.
fuse_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
*
4
,
out_channels
=
512
,
kernel_size
=
3
,
padding
=
1
,
weight_attr
=
ParamAttr
(
name
=
'conv2d_fuse.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
)
def
forward
(
self
,
x
):
c2
,
c3
,
c4
,
c5
=
x
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
16c247ac
...
...
@@ -368,18 +368,6 @@ class TableLabelDecode(object):
self
.
end_str
=
"eos"
list_character
=
[
self
.
beg_str
]
+
list_character
+
[
self
.
end_str
]
return
list_character
def
get_sp_tokens
(
self
):
char_beg_idx
=
self
.
get_beg_end_flag_idx
(
'beg'
,
'char'
)
char_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'char'
)
elem_beg_idx
=
self
.
get_beg_end_flag_idx
(
'beg'
,
'elem'
)
elem_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'elem'
)
elem_char_idx1
=
self
.
dict_elem
[
'<td>'
]
elem_char_idx2
=
self
.
dict_elem
[
'<td'
]
sp_tokens
=
np
.
array
([
char_beg_idx
,
char_end_idx
,
elem_beg_idx
,
elem_end_idx
,
elem_char_idx1
,
elem_char_idx2
,
self
.
max_text_length
,
self
.
max_elem_length
,
self
.
max_cell_num
])
return
sp_tokens
def
__call__
(
self
,
preds
):
structure_probs
=
preds
[
'structure_probs'
]
...
...
tools/export_model.py
浏览文件 @
16c247ac
...
...
@@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
infer_shape
[
-
1
]
=
100
elif
arch_config
[
"model_type"
]
==
"table"
:
infer_shape
=
[
3
,
488
,
488
]
model
=
to_static
(
model
,
input_spec
=
[
...
...
tools/infer_table.py
浏览文件 @
16c247ac
...
...
@@ -79,11 +79,9 @@ def main(config, device, logger, vdl_writer):
img
=
f
.
read
()
data
=
{
'image'
:
img
}
batch
=
transform
(
data
,
ops
)
sp_tokens
=
post_process_class
.
get_sp_tokens
()
targets
=
[[],
[],
paddle
.
to_tensor
([
sp_tokens
])]
images
=
np
.
expand_dims
(
batch
[
0
],
axis
=
0
)
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
,
data
=
targets
,
mode
=
'Test'
)
preds
=
model
(
images
,
data
=
None
,
mode
=
'Test'
)
post_result
=
post_process_class
(
preds
)
res_html_code
=
post_result
[
'res_html_code'
]
res_loc
=
post_result
[
'res_loc'
]
...
...
tools/program.py
浏览文件 @
16c247ac
# Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -276,6 +276,7 @@ def train(config,
valid_dataloader
,
post_process_class
,
eval_class
,
"table"
,
use_srn
=
use_srn
)
cur_metric_str
=
'cur metric, {}'
.
format
(
', '
.
join
(
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metric
.
items
()]))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录