Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
16c247ac
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
16c247ac
编写于
6月 21, 2021
作者:
M
MissPenguin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine
上级
7c8b2c8d
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
40 addition
and
101 deletion
+40
-101
configs/table/table_mv3.yml
configs/table/table_mv3.yml
+12
-12
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+0
-20
ppocr/data/pubtab_dataset.py
ppocr/data/pubtab_dataset.py
+2
-20
ppocr/modeling/architectures/base_model.py
ppocr/modeling/architectures/base_model.py
+1
-1
ppocr/modeling/heads/table_att_head.py
ppocr/modeling/heads/table_att_head.py
+10
-12
ppocr/modeling/necks/table_fpn.py
ppocr/modeling/necks/table_fpn.py
+10
-19
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+0
-12
tools/export_model.py
tools/export_model.py
+2
-1
tools/infer_table.py
tools/infer_table.py
+1
-3
tools/program.py
tools/program.py
+2
-1
未找到文件。
configs/table/table_mv3.yml
浏览文件 @
16c247ac
Global
:
Global
:
use_gpu
:
true
use_gpu
:
true
epoch_num
:
4
0
epoch_num
:
5
0
log_smooth_window
:
20
log_smooth_window
:
20
print_batch_step
:
5
print_batch_step
:
5
save_model_dir
:
./output/table_mv3/
save_model_dir
:
./output/table_mv3/
save_epoch_step
:
3
save_epoch_step
:
5
# evaluation is run every
5000 iterations after the 400
0th iteration
# evaluation is run every
400 iterations after the
0th iteration
eval_batch_step
:
[
0
,
400
]
eval_batch_step
:
[
0
,
400
]
# if pretrained_model is saved in static mode, load_static_weights must set to True
cal_metric_during_train
:
True
cal_metric_during_train
:
True
pretrained_model
:
pretrained_model
:
checkpoints
:
checkpoints
:
...
@@ -18,19 +17,20 @@ Global:
...
@@ -18,19 +17,20 @@ Global:
character_dict_path
:
ppocr/utils/dict/table_structure_dict.txt
character_dict_path
:
ppocr/utils/dict/table_structure_dict.txt
character_type
:
en
character_type
:
en
max_text_length
:
100
max_text_length
:
100
max_elem_length
:
8
00
max_elem_length
:
5
00
max_cell_num
:
500
max_cell_num
:
500
infer_mode
:
False
infer_mode
:
False
process_total_num
:
0
process_total_num
:
0
process_cut_num
:
0
process_cut_num
:
0
Optimizer
:
Optimizer
:
name
:
Adam
name
:
Adam
beta1
:
0.9
beta1
:
0.9
beta2
:
0.999
beta2
:
0.999
clip_norm
:
5.0
clip_norm
:
5.0
lr
:
lr
:
learning_rate
:
0.00
0
1
learning_rate
:
0.001
regularizer
:
regularizer
:
name
:
'
L2'
name
:
'
L2'
factor
:
0.00000
factor
:
0.00000
...
@@ -41,12 +41,12 @@ Architecture:
...
@@ -41,12 +41,12 @@ Architecture:
Backbone
:
Backbone
:
name
:
MobileNetV3
name
:
MobileNetV3
scale
:
1.0
scale
:
1.0
model_name
:
large
model_name
:
small
disable_se
:
True
Head
:
Head
:
name
:
TableAttentionHead
# AttentionHead
name
:
TableAttentionHead
hidden_size
:
256
#
hidden_size
:
256
l2_decay
:
0.00001
l2_decay
:
0.00001
# loc_type: 1
loc_type
:
2
loc_type
:
2
Loss
:
Loss
:
...
@@ -86,7 +86,7 @@ Train:
...
@@ -86,7 +86,7 @@ Train:
shuffle
:
True
shuffle
:
True
batch_size_per_card
:
32
batch_size_per_card
:
32
drop_last
:
True
drop_last
:
True
num_workers
:
4
num_workers
:
1
Eval
:
Eval
:
dataset
:
dataset
:
...
@@ -113,4 +113,4 @@ Eval:
...
@@ -113,4 +113,4 @@ Eval:
shuffle
:
False
shuffle
:
False
drop_last
:
False
drop_last
:
False
batch_size_per_card
:
16
batch_size_per_card
:
16
num_workers
:
4
num_workers
:
1
ppocr/data/imaug/label_ops.py
浏览文件 @
16c247ac
...
@@ -412,7 +412,6 @@ class TableLabelEncode(object):
...
@@ -412,7 +412,6 @@ class TableLabelEncode(object):
return
None
return
None
elem_num
=
len
(
structure
)
elem_num
=
len
(
structure
)
structure
=
[
0
]
+
structure
+
[
len
(
self
.
dict_elem
)
-
1
]
structure
=
[
0
]
+
structure
+
[
len
(
self
.
dict_elem
)
-
1
]
# structure = [0] + structure + [0]
structure
=
structure
+
[
0
]
*
(
self
.
max_elem_length
+
2
-
len
(
structure
))
structure
=
structure
+
[
0
]
*
(
self
.
max_elem_length
+
2
-
len
(
structure
))
structure
=
np
.
array
(
structure
)
structure
=
np
.
array
(
structure
)
data
[
'structure'
]
=
structure
data
[
'structure'
]
=
structure
...
@@ -443,8 +442,6 @@ class TableLabelEncode(object):
...
@@ -443,8 +442,6 @@ class TableLabelEncode(object):
if
cand_span_idx
<
(
self
.
max_elem_length
+
2
):
if
cand_span_idx
<
(
self
.
max_elem_length
+
2
):
if
structure
[
cand_span_idx
]
in
span_idx_list
:
if
structure
[
cand_span_idx
]
in
span_idx_list
:
structure_mask
[
cand_span_idx
]
=
span_weight
structure_mask
[
cand_span_idx
]
=
span_weight
# structure_mask[td_idx] = self.span_weight
# structure_mask[cand_span_idx] = self.span_weight
data
[
'bbox_list'
]
=
bbox_list
data
[
'bbox_list'
]
=
bbox_list
data
[
'bbox_list_mask'
]
=
bbox_list_mask
data
[
'bbox_list_mask'
]
=
bbox_list_mask
...
@@ -458,23 +455,6 @@ class TableLabelEncode(object):
...
@@ -458,23 +455,6 @@ class TableLabelEncode(object):
self
.
max_elem_length
,
self
.
max_cell_num
,
elem_num
])
self
.
max_elem_length
,
self
.
max_cell_num
,
elem_num
])
return
data
return
data
########
# for char decode
# cell_list = []
# for cell in cells:
# char_list = cell['tokens']
# cell = self.encode(char_list, 'char')
# if cell is None:
# return None
# cell = [0] + cell + [len(self.dict_character) - 1]
# cell = cell + [0] * (self.max_text_length + 2 - len(cell))
# cell_list.append(cell)
# cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2))
# cell_list = np.array(cell_list)
# cell_list_padding[0:cell_list.shape[0]] = cell_list
# data['cells'] = cell_list_padding
# return data
def
encode
(
self
,
text
,
char_or_elem
):
def
encode
(
self
,
text
,
char_or_elem
):
"""convert text-label into text-index.
"""convert text-label into text-index.
"""
"""
...
...
ppocr/data/pubtab_dataset.py
浏览文件 @
16c247ac
# copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -19,6 +19,7 @@ import json
...
@@ -19,6 +19,7 @@ import json
from
.imaug
import
transform
,
create_operators
from
.imaug
import
transform
,
create_operators
class
PubTabDataSet
(
Dataset
):
class
PubTabDataSet
(
Dataset
):
def
__init__
(
self
,
config
,
mode
,
logger
,
seed
=
None
):
def
__init__
(
self
,
config
,
mode
,
logger
,
seed
=
None
):
super
(
PubTabDataSet
,
self
).
__init__
()
super
(
PubTabDataSet
,
self
).
__init__
()
...
@@ -58,23 +59,6 @@ class PubTabDataSet(Dataset):
...
@@ -58,23 +59,6 @@ class PubTabDataSet(Dataset):
random
.
shuffle
(
self
.
data_lines
)
random
.
shuffle
(
self
.
data_lines
)
return
return
def
load_hard_select_prob
(
self
):
label_path
=
"./pretrained_model/teds_score_exp5_st2_train.txt"
img_select_prob
=
{}
with
open
(
label_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
for
lno
in
range
(
len
(
lines
)):
substr
=
lines
[
lno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
split
(
" "
)
img_name
=
substr
[
0
].
strip
(
":"
)
score
=
float
(
substr
[
1
])
if
score
<=
0.8
:
img_select_prob
[
img_name
]
=
self
.
hard_prob
[
0
]
elif
score
<=
0.98
:
img_select_prob
[
img_name
]
=
self
.
hard_prob
[
1
]
else
:
img_select_prob
[
img_name
]
=
self
.
hard_prob
[
2
]
return
img_select_prob
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
try
:
try
:
data_line
=
self
.
data_lines
[
idx
]
data_line
=
self
.
data_lines
[
idx
]
...
@@ -93,8 +77,6 @@ class PubTabDataSet(Dataset):
...
@@ -93,8 +77,6 @@ class PubTabDataSet(Dataset):
table_type
=
"simple"
table_type
=
"simple"
if
'colspan'
in
structure_str
or
'rowspan'
in
structure_str
:
if
'colspan'
in
structure_str
or
'rowspan'
in
structure_str
:
table_type
=
"complex"
table_type
=
"complex"
# if self.table_select_type != table_type:
# select_flag = False
if
table_type
==
"complex"
:
if
table_type
==
"complex"
:
if
self
.
table_select_prob
<
random
.
uniform
(
0
,
1
):
if
self
.
table_select_prob
<
random
.
uniform
(
0
,
1
):
select_flag
=
False
select_flag
=
False
...
...
ppocr/modeling/architectures/base_model.py
浏览文件 @
16c247ac
# Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
...
ppocr/modeling/heads/table_att_head.py
浏览文件 @
16c247ac
...
@@ -21,13 +21,16 @@ import paddle.nn as nn
...
@@ -21,13 +21,16 @@ import paddle.nn as nn
import
paddle.nn.functional
as
F
import
paddle.nn.functional
as
F
import
numpy
as
np
import
numpy
as
np
class
TableAttentionHead
(
nn
.
Layer
):
class
TableAttentionHead
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
hidden_size
,
loc_type
,
in_max_len
=
488
,
**
kwargs
):
def
__init__
(
self
,
in_channels
,
hidden_size
,
loc_type
,
in_max_len
=
488
,
**
kwargs
):
super
(
TableAttentionHead
,
self
).
__init__
()
super
(
TableAttentionHead
,
self
).
__init__
()
self
.
input_size
=
in_channels
[
-
1
]
self
.
input_size
=
in_channels
[
-
1
]
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
char_num
=
280
self
.
elem_num
=
30
self
.
elem_num
=
30
self
.
max_text_length
=
100
self
.
max_elem_length
=
500
self
.
max_cell_num
=
500
self
.
structure_attention_cell
=
AttentionGRUCell
(
self
.
structure_attention_cell
=
AttentionGRUCell
(
self
.
input_size
,
hidden_size
,
self
.
elem_num
,
use_gru
=
False
)
self
.
input_size
,
hidden_size
,
self
.
elem_num
,
use_gru
=
False
)
...
@@ -39,11 +42,11 @@ class TableAttentionHead(nn.Layer):
...
@@ -39,11 +42,11 @@ class TableAttentionHead(nn.Layer):
self
.
loc_generator
=
nn
.
Linear
(
hidden_size
,
4
)
self
.
loc_generator
=
nn
.
Linear
(
hidden_size
,
4
)
else
:
else
:
if
self
.
in_max_len
==
640
:
if
self
.
in_max_len
==
640
:
self
.
loc_fea_trans
=
nn
.
Linear
(
400
,
80
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
400
,
self
.
max_elem_length
+
1
)
elif
self
.
in_max_len
==
800
:
elif
self
.
in_max_len
==
800
:
self
.
loc_fea_trans
=
nn
.
Linear
(
625
,
80
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
625
,
self
.
max_elem_length
+
1
)
else
:
else
:
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
80
1
)
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
self
.
max_elem_length
+
1
)
self
.
loc_generator
=
nn
.
Linear
(
self
.
input_size
+
hidden_size
,
4
)
self
.
loc_generator
=
nn
.
Linear
(
self
.
input_size
+
hidden_size
,
4
)
def
_char_to_onehot
(
self
,
input_char
,
onehot_dim
):
def
_char_to_onehot
(
self
,
input_char
,
onehot_dim
):
...
@@ -61,18 +64,12 @@ class TableAttentionHead(nn.Layer):
...
@@ -61,18 +64,12 @@ class TableAttentionHead(nn.Layer):
fea
=
paddle
.
reshape
(
fea
,
[
fea
.
shape
[
0
],
fea
.
shape
[
1
],
last_shape
])
fea
=
paddle
.
reshape
(
fea
,
[
fea
.
shape
[
0
],
fea
.
shape
[
1
],
last_shape
])
fea
=
fea
.
transpose
([
0
,
2
,
1
])
# (NTC)(batch, width, channels)
fea
=
fea
.
transpose
([
0
,
2
,
1
])
# (NTC)(batch, width, channels)
batch_size
=
fea
.
shape
[
0
]
batch_size
=
fea
.
shape
[
0
]
#sp_tokens = targets[2].numpy()
#char_beg_idx, char_end_idx = sp_tokens[0, 0:2]
#elem_beg_idx, elem_end_idx = sp_tokens[0, 2:4]
#elem_char_idx1, elem_char_idx2 = sp_tokens[0, 4:6]
#max_text_length, max_elem_length, max_cell_num = sp_tokens[0, 6:9]
max_text_length
,
max_elem_length
,
max_cell_num
=
100
,
800
,
500
hidden
=
paddle
.
zeros
((
batch_size
,
self
.
hidden_size
))
hidden
=
paddle
.
zeros
((
batch_size
,
self
.
hidden_size
))
output_hiddens
=
[]
output_hiddens
=
[]
if
mode
==
'Train'
and
targets
is
not
None
:
if
mode
==
'Train'
and
targets
is
not
None
:
structure
=
targets
[
0
]
structure
=
targets
[
0
]
for
i
in
range
(
max_elem_length
+
1
):
for
i
in
range
(
self
.
max_elem_length
+
1
):
elem_onehots
=
self
.
_char_to_onehot
(
elem_onehots
=
self
.
_char_to_onehot
(
structure
[:,
i
],
onehot_dim
=
self
.
elem_num
)
structure
[:,
i
],
onehot_dim
=
self
.
elem_num
)
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
...
@@ -97,7 +94,7 @@ class TableAttentionHead(nn.Layer):
...
@@ -97,7 +94,7 @@ class TableAttentionHead(nn.Layer):
elem_onehots
=
None
elem_onehots
=
None
outputs
=
None
outputs
=
None
alpha
=
None
alpha
=
None
max_elem_length
=
paddle
.
to_tensor
(
max_elem_length
)
max_elem_length
=
paddle
.
to_tensor
(
self
.
max_elem_length
)
i
=
0
i
=
0
while
i
<
max_elem_length
+
1
:
while
i
<
max_elem_length
+
1
:
elem_onehots
=
self
.
_char_to_onehot
(
elem_onehots
=
self
.
_char_to_onehot
(
...
@@ -124,6 +121,7 @@ class TableAttentionHead(nn.Layer):
...
@@ -124,6 +121,7 @@ class TableAttentionHead(nn.Layer):
loc_preds
=
F
.
sigmoid
(
loc_preds
)
loc_preds
=
F
.
sigmoid
(
loc_preds
)
return
{
'structure_probs'
:
structure_probs
,
'loc_preds'
:
loc_preds
}
return
{
'structure_probs'
:
structure_probs
,
'loc_preds'
:
loc_preds
}
class
AttentionGRUCell
(
nn
.
Layer
):
class
AttentionGRUCell
(
nn
.
Layer
):
def
__init__
(
self
,
input_size
,
hidden_size
,
num_embeddings
,
use_gru
=
False
):
def
__init__
(
self
,
input_size
,
hidden_size
,
num_embeddings
,
use_gru
=
False
):
super
(
AttentionGRUCell
,
self
).
__init__
()
super
(
AttentionGRUCell
,
self
).
__init__
()
...
...
ppocr/modeling/necks/table_fpn.py
浏览文件 @
16c247ac
# copyright (c) 20
19
PaddlePaddle Authors. All Rights Reserve.
# copyright (c) 20
21
PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -31,70 +31,61 @@ class TableFPN(nn.Layer):
...
@@ -31,70 +31,61 @@ class TableFPN(nn.Layer):
in_channels
=
in_channels
[
0
],
in_channels
=
in_channels
[
0
],
out_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
,
kernel_size
=
1
,
kernel_size
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_51.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
in3_conv
=
nn
.
Conv2D
(
self
.
in3_conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
[
1
],
in_channels
=
in_channels
[
1
],
out_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
,
kernel_size
=
1
,
kernel_size
=
1
,
stride
=
1
,
stride
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_50.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
in4_conv
=
nn
.
Conv2D
(
self
.
in4_conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
[
2
],
in_channels
=
in_channels
[
2
],
out_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
,
kernel_size
=
1
,
kernel_size
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_49.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
in5_conv
=
nn
.
Conv2D
(
self
.
in5_conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
[
3
],
in_channels
=
in_channels
[
3
],
out_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
,
kernel_size
=
1
,
kernel_size
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_48.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
p5_conv
=
nn
.
Conv2D
(
self
.
p5_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
,
in_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
//
4
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
3
,
kernel_size
=
3
,
padding
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_52.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
p4_conv
=
nn
.
Conv2D
(
self
.
p4_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
,
in_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
//
4
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
3
,
kernel_size
=
3
,
padding
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_53.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
p3_conv
=
nn
.
Conv2D
(
self
.
p3_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
,
in_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
//
4
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
3
,
kernel_size
=
3
,
padding
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_54.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
p2_conv
=
nn
.
Conv2D
(
self
.
p2_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
,
in_channels
=
self
.
out_channels
,
out_channels
=
self
.
out_channels
//
4
,
out_channels
=
self
.
out_channels
//
4
,
kernel_size
=
3
,
kernel_size
=
3
,
padding
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
name
=
'conv2d_55.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
fuse_conv
=
nn
.
Conv2D
(
self
.
fuse_conv
=
nn
.
Conv2D
(
in_channels
=
self
.
out_channels
*
4
,
in_channels
=
self
.
out_channels
*
4
,
out_channels
=
512
,
out_channels
=
512
,
kernel_size
=
3
,
kernel_size
=
3
,
padding
=
1
,
padding
=
1
,
weight_attr
=
ParamAttr
(
weight_attr
=
ParamAttr
(
initializer
=
weight_attr
),
bias_attr
=
False
)
name
=
'conv2d_fuse.w_0'
,
initializer
=
weight_attr
),
bias_attr
=
False
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
c2
,
c3
,
c4
,
c5
=
x
c2
,
c3
,
c4
,
c5
=
x
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
16c247ac
...
@@ -369,18 +369,6 @@ class TableLabelDecode(object):
...
@@ -369,18 +369,6 @@ class TableLabelDecode(object):
list_character
=
[
self
.
beg_str
]
+
list_character
+
[
self
.
end_str
]
list_character
=
[
self
.
beg_str
]
+
list_character
+
[
self
.
end_str
]
return
list_character
return
list_character
def
get_sp_tokens
(
self
):
char_beg_idx
=
self
.
get_beg_end_flag_idx
(
'beg'
,
'char'
)
char_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'char'
)
elem_beg_idx
=
self
.
get_beg_end_flag_idx
(
'beg'
,
'elem'
)
elem_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'elem'
)
elem_char_idx1
=
self
.
dict_elem
[
'<td>'
]
elem_char_idx2
=
self
.
dict_elem
[
'<td'
]
sp_tokens
=
np
.
array
([
char_beg_idx
,
char_end_idx
,
elem_beg_idx
,
elem_end_idx
,
elem_char_idx1
,
elem_char_idx2
,
self
.
max_text_length
,
self
.
max_elem_length
,
self
.
max_cell_num
])
return
sp_tokens
def
__call__
(
self
,
preds
):
def
__call__
(
self
,
preds
):
structure_probs
=
preds
[
'structure_probs'
]
structure_probs
=
preds
[
'structure_probs'
]
loc_preds
=
preds
[
'loc_preds'
]
loc_preds
=
preds
[
'loc_preds'
]
...
...
tools/export_model.py
浏览文件 @
16c247ac
...
@@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
...
@@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
)
infer_shape
[
-
1
]
=
100
infer_shape
[
-
1
]
=
100
elif
arch_config
[
"model_type"
]
==
"table"
:
infer_shape
=
[
3
,
488
,
488
]
model
=
to_static
(
model
=
to_static
(
model
,
model
,
input_spec
=
[
input_spec
=
[
...
...
tools/infer_table.py
浏览文件 @
16c247ac
...
@@ -79,11 +79,9 @@ def main(config, device, logger, vdl_writer):
...
@@ -79,11 +79,9 @@ def main(config, device, logger, vdl_writer):
img
=
f
.
read
()
img
=
f
.
read
()
data
=
{
'image'
:
img
}
data
=
{
'image'
:
img
}
batch
=
transform
(
data
,
ops
)
batch
=
transform
(
data
,
ops
)
sp_tokens
=
post_process_class
.
get_sp_tokens
()
targets
=
[[],
[],
paddle
.
to_tensor
([
sp_tokens
])]
images
=
np
.
expand_dims
(
batch
[
0
],
axis
=
0
)
images
=
np
.
expand_dims
(
batch
[
0
],
axis
=
0
)
images
=
paddle
.
to_tensor
(
images
)
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
,
data
=
targets
,
mode
=
'Test'
)
preds
=
model
(
images
,
data
=
None
,
mode
=
'Test'
)
post_result
=
post_process_class
(
preds
)
post_result
=
post_process_class
(
preds
)
res_html_code
=
post_result
[
'res_html_code'
]
res_html_code
=
post_result
[
'res_html_code'
]
res_loc
=
post_result
[
'res_loc'
]
res_loc
=
post_result
[
'res_loc'
]
...
...
tools/program.py
浏览文件 @
16c247ac
# Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -276,6 +276,7 @@ def train(config,
...
@@ -276,6 +276,7 @@ def train(config,
valid_dataloader
,
valid_dataloader
,
post_process_class
,
post_process_class
,
eval_class
,
eval_class
,
"table"
,
use_srn
=
use_srn
)
use_srn
=
use_srn
)
cur_metric_str
=
'cur metric, {}'
.
format
(
', '
.
join
(
cur_metric_str
=
'cur metric, {}'
.
format
(
', '
.
join
(
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metric
.
items
()]))
[
'{}: {}'
.
format
(
k
,
v
)
for
k
,
v
in
cur_metric
.
items
()]))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录