Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
5dfcc983
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5dfcc983
编写于
6月 27, 2022
作者:
文幕地方
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug
上级
9b1e9ae6
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
54 addition
and
34 deletion
+54
-34
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+8
-4
ppocr/losses/table_att_loss.py
ppocr/losses/table_att_loss.py
+33
-22
ppocr/metrics/table_metric.py
ppocr/metrics/table_metric.py
+4
-2
ppocr/modeling/heads/table_att_head.py
ppocr/modeling/heads/table_att_head.py
+9
-6
未找到文件。
ppocr/data/imaug/label_ops.py
浏览文件 @
5dfcc983
...
...
@@ -591,7 +591,7 @@ class TableLabelEncode(AttnLabelEncode):
replace_empty_cell_token
=
False
,
merge_no_span_structure
=
False
,
learn_empty_box
=
False
,
point_num
=
4
,
point_num
=
2
,
**
kwargs
):
self
.
max_text_len
=
max_text_length
self
.
lower
=
False
...
...
@@ -669,13 +669,15 @@ class TableLabelEncode(AttnLabelEncode):
# encode box
bboxes
=
np
.
zeros
(
(
self
.
_max_text_len
,
self
.
point_num
),
dtype
=
np
.
float32
)
(
self
.
_max_text_len
,
self
.
point_num
*
2
),
dtype
=
np
.
float32
)
bbox_masks
=
np
.
zeros
((
self
.
_max_text_len
,
1
),
dtype
=
np
.
float32
)
bbox_idx
=
0
for
i
,
token
in
enumerate
(
structure
):
if
self
.
idx2char
[
token
]
in
self
.
td_token
:
if
'bbox'
in
cells
[
bbox_idx
]:
if
'bbox'
in
cells
[
bbox_idx
]
and
len
(
cells
[
bbox_idx
][
'tokens'
])
>
0
:
bbox
=
cells
[
bbox_idx
][
'bbox'
].
copy
()
bbox
=
np
.
array
(
bbox
,
dtype
=
np
.
float32
).
reshape
(
-
1
)
bboxes
[
i
]
=
bbox
...
...
@@ -723,11 +725,13 @@ class TableMasterLabelEncode(TableLabelEncode):
replace_empty_cell_token
=
False
,
merge_no_span_structure
=
False
,
learn_empty_box
=
False
,
point_num
=
4
,
point_num
=
2
,
**
kwargs
):
super
(
TableMasterLabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
replace_empty_cell_token
,
merge_no_span_structure
,
learn_empty_box
,
point_num
,
**
kwargs
)
self
.
pad_idx
=
self
.
dict
[
self
.
pad_str
]
self
.
unknown_idx
=
self
.
dict
[
self
.
unknown_str
]
@
property
def
_max_text_len
(
self
):
...
...
ppocr/losses/table_att_loss.py
浏览文件 @
5dfcc983
...
...
@@ -21,15 +21,21 @@ from paddle import nn
from
paddle.nn
import
functional
as
F
from
paddle
import
fluid
class
TableAttentionLoss
(
nn
.
Layer
):
def
__init__
(
self
,
structure_weight
,
loc_weight
,
use_giou
=
False
,
giou_weight
=
1.0
,
**
kwargs
):
def
__init__
(
self
,
structure_weight
,
loc_weight
,
use_giou
=
False
,
giou_weight
=
1.0
,
**
kwargs
):
super
(
TableAttentionLoss
,
self
).
__init__
()
self
.
loss_func
=
nn
.
CrossEntropyLoss
(
weight
=
None
,
reduction
=
'none'
)
self
.
structure_weight
=
structure_weight
self
.
loc_weight
=
loc_weight
self
.
use_giou
=
use_giou
self
.
giou_weight
=
giou_weight
def
giou_loss
(
self
,
preds
,
bbox
,
eps
=
1e-7
,
reduction
=
'mean'
):
'''
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
...
...
@@ -48,9 +54,10 @@ class TableAttentionLoss(nn.Layer):
inters
=
iw
*
ih
# union
uni
=
(
preds
[:,
2
]
-
preds
[:,
0
]
+
1e-3
)
*
(
preds
[:,
3
]
-
preds
[:,
1
]
+
1e-3
)
+
(
bbox
[:,
2
]
-
bbox
[:,
0
]
+
1e-3
)
*
(
bbox
[:,
3
]
-
bbox
[:,
1
]
+
1e-3
)
-
inters
+
eps
uni
=
(
preds
[:,
2
]
-
preds
[:,
0
]
+
1e-3
)
*
(
preds
[:,
3
]
-
preds
[:,
1
]
+
1e-3
)
+
(
bbox
[:,
2
]
-
bbox
[:,
0
]
+
1e-3
)
*
(
bbox
[:,
3
]
-
bbox
[:,
1
]
+
1e-3
)
-
inters
+
eps
# ious
ious
=
inters
/
uni
...
...
@@ -80,30 +87,34 @@ class TableAttentionLoss(nn.Layer):
structure_probs
=
predicts
[
'structure_probs'
]
structure_targets
=
batch
[
1
].
astype
(
"int64"
)
structure_targets
=
structure_targets
[:,
1
:]
if
len
(
batch
)
==
6
:
structure_mask
=
batch
[
5
].
astype
(
"int64"
)
structure_mask
=
structure_mask
[:,
1
:]
structure_mask
=
paddle
.
reshape
(
structure_mask
,
[
-
1
])
structure_probs
=
paddle
.
reshape
(
structure_probs
,
[
-
1
,
structure_probs
.
shape
[
-
1
]])
structure_probs
=
paddle
.
reshape
(
structure_probs
,
[
-
1
,
structure_probs
.
shape
[
-
1
]])
structure_targets
=
paddle
.
reshape
(
structure_targets
,
[
-
1
])
structure_loss
=
self
.
loss_func
(
structure_probs
,
structure_targets
)
if
len
(
batch
)
==
6
:
structure_loss
=
structure_loss
*
structure_mask
# structure_loss = paddle.sum(structure_loss) * self.structure_weight
structure_loss
=
paddle
.
mean
(
structure_loss
)
*
self
.
structure_weight
loc_preds
=
predicts
[
'loc_preds'
]
loc_targets
=
batch
[
2
].
astype
(
"float32"
)
loc_targets_mask
=
batch
[
4
].
astype
(
"float32"
)
loc_targets_mask
=
batch
[
3
].
astype
(
"float32"
)
loc_targets
=
loc_targets
[:,
1
:,
:]
loc_targets_mask
=
loc_targets_mask
[:,
1
:,
:]
loc_loss
=
F
.
mse_loss
(
loc_preds
*
loc_targets_mask
,
loc_targets
)
*
self
.
loc_weight
loc_loss
=
F
.
mse_loss
(
loc_preds
*
loc_targets_mask
,
loc_targets
)
*
self
.
loc_weight
if
self
.
use_giou
:
loc_loss_giou
=
self
.
giou_loss
(
loc_preds
*
loc_targets_mask
,
loc_targets
)
*
self
.
giou_weight
loc_loss_giou
=
self
.
giou_loss
(
loc_preds
*
loc_targets_mask
,
loc_targets
)
*
self
.
giou_weight
total_loss
=
structure_loss
+
loc_loss
+
loc_loss_giou
return
{
'loss'
:
total_loss
,
"structure_loss"
:
structure_loss
,
"loc_loss"
:
loc_loss
,
"loc_loss_giou"
:
loc_loss_giou
}
return
{
'loss'
:
total_loss
,
"structure_loss"
:
structure_loss
,
"loc_loss"
:
loc_loss
,
"loc_loss_giou"
:
loc_loss_giou
}
else
:
total_loss
=
structure_loss
+
loc_loss
return
{
'loss'
:
total_loss
,
"structure_loss"
:
structure_loss
,
"loc_loss"
:
loc_loss
}
\ No newline at end of file
total_loss
=
structure_loss
+
loc_loss
return
{
'loss'
:
total_loss
,
"structure_loss"
:
structure_loss
,
"loc_loss"
:
loc_loss
}
ppocr/metrics/table_metric.py
浏览文件 @
5dfcc983
...
...
@@ -31,6 +31,8 @@ class TableStructureMetric(object):
gt_structure_batch_list
):
pred_str
=
''
.
join
(
pred
)
target_str
=
''
.
join
(
target
)
# pred_str = pred_str.replace('<thead>','').replace('</thead>','').replace('<tbody>','').replace('</tbody>','')
# target_str = target_str.replace('<thead>','').replace('</thead>','').replace('<tbody>','').replace('</tbody>','')
if
pred_str
==
target_str
:
correct_num
+=
1
all_num
+=
1
...
...
@@ -131,10 +133,10 @@ class TableMetric(object):
self
.
bbox_metric
.
reset
()
def
format_box
(
self
,
box
):
if
self
.
point_num
==
4
:
if
self
.
point_num
==
2
:
x1
,
y1
,
x2
,
y2
=
box
box
=
[[
x1
,
y1
],
[
x2
,
y1
],
[
x2
,
y2
],
[
x1
,
y2
]]
elif
self
.
point_num
==
8
:
elif
self
.
point_num
==
4
:
x1
,
y1
,
x2
,
y2
,
x3
,
y3
,
x4
,
y4
=
box
box
=
[[
x1
,
y1
],
[
x2
,
y2
],
[
x3
,
y3
],
[
x4
,
y4
]]
return
box
ppocr/modeling/heads/table_att_head.py
浏览文件 @
5dfcc983
...
...
@@ -31,16 +31,18 @@ class TableAttentionHead(nn.Layer):
loc_type
,
in_max_len
=
488
,
max_text_length
=
800
,
out_channels
=
30
,
point_num
=
2
,
**
kwargs
):
super
(
TableAttentionHead
,
self
).
__init__
()
self
.
input_size
=
in_channels
[
-
1
]
self
.
hidden_size
=
hidden_size
self
.
elem_num
=
30
self
.
out_channels
=
out_channels
self
.
max_text_length
=
max_text_length
self
.
structure_attention_cell
=
AttentionGRUCell
(
self
.
input_size
,
hidden_size
,
self
.
elem_num
,
use_gru
=
False
)
self
.
structure_generator
=
nn
.
Linear
(
hidden_size
,
self
.
elem_num
)
self
.
input_size
,
hidden_size
,
self
.
out_channels
,
use_gru
=
False
)
self
.
structure_generator
=
nn
.
Linear
(
hidden_size
,
self
.
out_channels
)
self
.
loc_type
=
loc_type
self
.
in_max_len
=
in_max_len
...
...
@@ -53,7 +55,8 @@ class TableAttentionHead(nn.Layer):
self
.
loc_fea_trans
=
nn
.
Linear
(
625
,
self
.
max_text_length
+
1
)
else
:
self
.
loc_fea_trans
=
nn
.
Linear
(
256
,
self
.
max_text_length
+
1
)
self
.
loc_generator
=
nn
.
Linear
(
self
.
input_size
+
hidden_size
,
4
)
self
.
loc_generator
=
nn
.
Linear
(
self
.
input_size
+
hidden_size
,
point_num
*
2
)
def
_char_to_onehot
(
self
,
input_char
,
onehot_dim
):
input_ont_hot
=
F
.
one_hot
(
input_char
,
onehot_dim
)
...
...
@@ -77,7 +80,7 @@ class TableAttentionHead(nn.Layer):
structure
=
targets
[
0
]
for
i
in
range
(
self
.
max_text_length
+
1
):
elem_onehots
=
self
.
_char_to_onehot
(
structure
[:,
i
],
onehot_dim
=
self
.
elem_num
)
structure
[:,
i
],
onehot_dim
=
self
.
out_channels
)
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
hidden
,
fea
,
elem_onehots
)
output_hiddens
.
append
(
paddle
.
unsqueeze
(
outputs
,
axis
=
1
))
...
...
@@ -104,7 +107,7 @@ class TableAttentionHead(nn.Layer):
i
=
0
while
i
<
max_text_length
+
1
:
elem_onehots
=
self
.
_char_to_onehot
(
temp_elem
,
onehot_dim
=
self
.
elem_num
)
temp_elem
,
onehot_dim
=
self
.
out_channels
)
(
outputs
,
hidden
),
alpha
=
self
.
structure_attention_cell
(
hidden
,
fea
,
elem_onehots
)
output_hiddens
.
append
(
paddle
.
unsqueeze
(
outputs
,
axis
=
1
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录