Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
6ce44198
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6ce44198
编写于
7月 07, 2021
作者:
L
LDOUBLEV
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug
上级
185d1e1f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
24 addition
and
25 deletion
+24
-25
configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml
configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml
+2
-2
ppocr/losses/combined_loss.py
ppocr/losses/combined_loss.py
+2
-2
ppocr/losses/distillation_loss.py
ppocr/losses/distillation_loss.py
+15
-7
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+2
-1
ppocr/postprocess/db_postprocess.py
ppocr/postprocess/db_postprocess.py
+3
-13
未找到文件。
configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml
浏览文件 @
6ce44198
...
@@ -88,7 +88,7 @@ Loss:
...
@@ -88,7 +88,7 @@ Loss:
-
DistillationDMLLoss
:
-
DistillationDMLLoss
:
model_name_pairs
:
model_name_pairs
:
-
[
"
Student"
,
"
Student2"
]
-
[
"
Student"
,
"
Student2"
]
maps_name
:
[
"
thrink_maps"
]
maps_name
:
"
thrink_maps"
weight
:
1.0
weight
:
1.0
act
:
"
softmax"
act
:
"
softmax"
model_name_pairs
:
[
"
Student"
,
"
Student2"
]
model_name_pairs
:
[
"
Student"
,
"
Student2"
]
...
@@ -96,7 +96,7 @@ Loss:
...
@@ -96,7 +96,7 @@ Loss:
-
DistillationDBLoss
:
-
DistillationDBLoss
:
weight
:
1.0
weight
:
1.0
model_name_list
:
[
"
Student"
,
"
Student2"
]
model_name_list
:
[
"
Student"
,
"
Student2"
]
key
:
maps
#
key: maps
name
:
DBLoss
name
:
DBLoss
balance_loss
:
true
balance_loss
:
true
main_loss_type
:
DiceLoss
main_loss_type
:
DiceLoss
...
...
ppocr/losses/combined_loss.py
浏览文件 @
6ce44198
...
@@ -50,11 +50,11 @@ class CombinedLoss(nn.Layer):
...
@@ -50,11 +50,11 @@ class CombinedLoss(nn.Layer):
if
isinstance
(
loss
,
paddle
.
Tensor
):
if
isinstance
(
loss
,
paddle
.
Tensor
):
loss
=
{
"loss_{}_{}"
.
format
(
str
(
loss
),
idx
):
loss
}
loss
=
{
"loss_{}_{}"
.
format
(
str
(
loss
),
idx
):
loss
}
weight
=
self
.
loss_weight
[
idx
]
weight
=
self
.
loss_weight
[
idx
]
for
key
in
loss
:
for
key
in
loss
.
keys
()
:
if
key
==
"loss"
:
if
key
==
"loss"
:
loss_all
+=
loss
[
key
]
*
weight
loss_all
+=
loss
[
key
]
*
weight
else
:
else
:
loss
[
"{}_{}"
.
format
(
key
,
idx
)]
=
loss
[
key
]
loss
_dict
[
"{}_{}"
.
format
(
key
,
idx
)]
=
loss
[
key
]
# loss[f"{key}_{idx}"] = loss[key]
# loss[f"{key}_{idx}"] = loss[key]
loss_dict
.
update
(
loss
)
loss_dict
.
update
(
loss
)
loss_dict
[
"loss"
]
=
loss_all
loss_dict
[
"loss"
]
=
loss_all
...
...
ppocr/losses/distillation_loss.py
浏览文件 @
6ce44198
...
@@ -24,7 +24,6 @@ from .det_db_loss import DBLoss
...
@@ -24,7 +24,6 @@ from .det_db_loss import DBLoss
from
.det_basic_loss
import
BalanceLoss
,
MaskL1Loss
,
DiceLoss
from
.det_basic_loss
import
BalanceLoss
,
MaskL1Loss
,
DiceLoss
def
_sum_loss
(
loss_dict
):
def
_sum_loss
(
loss_dict
):
if
"loss"
in
loss_dict
.
keys
():
if
"loss"
in
loss_dict
.
keys
():
return
loss_dict
return
loss_dict
...
@@ -51,9 +50,17 @@ class DistillationDMLLoss(DMLLoss):
...
@@ -51,9 +50,17 @@ class DistillationDMLLoss(DMLLoss):
super
().
__init__
(
act
=
act
)
super
().
__init__
(
act
=
act
)
assert
isinstance
(
model_name_pairs
,
list
)
assert
isinstance
(
model_name_pairs
,
list
)
self
.
key
=
key
self
.
key
=
key
self
.
model_name_pairs
=
model_name_pairs
self
.
model_name_pairs
=
self
.
_check_model_name_pairs
(
model_name_pairs
)
self
.
name
=
name
self
.
name
=
name
self
.
maps_name
=
maps_name
self
.
maps_name
=
maps_name
def
_check_model_name_pairs
(
self
,
model_name_pairs
):
if
not
isinstance
(
model_name_pairs
,
list
):
return
[]
elif
isinstance
(
model_name_pairs
[
0
],
list
)
and
isinstance
(
model_name_pairs
[
0
][
0
],
str
):
return
model_name_pairs
else
:
return
[
model_name_pairs
]
def
_check_maps_name
(
self
,
maps_name
):
def
_check_maps_name
(
self
,
maps_name
):
if
maps_name
is
None
:
if
maps_name
is
None
:
...
@@ -69,13 +76,14 @@ class DistillationDMLLoss(DMLLoss):
...
@@ -69,13 +76,14 @@ class DistillationDMLLoss(DMLLoss):
new_outs
=
{}
new_outs
=
{}
for
k
in
self
.
maps_name
:
for
k
in
self
.
maps_name
:
if
k
==
"thrink_maps"
:
if
k
==
"thrink_maps"
:
new_outs
[
k
]
=
paddle
.
slice
(
outs
,
axes
=
1
,
starts
=
0
,
ends
=
1
)
new_outs
[
k
]
=
paddle
.
slice
(
outs
,
axes
=
[
1
],
starts
=
[
0
],
ends
=
[
1
]
)
elif
k
==
"threshold_maps"
:
elif
k
==
"threshold_maps"
:
new_outs
[
k
]
=
paddle
.
slice
(
outs
,
axes
=
1
,
starts
=
1
,
ends
=
2
)
new_outs
[
k
]
=
paddle
.
slice
(
outs
,
axes
=
[
1
],
starts
=
[
1
],
ends
=
[
2
]
)
elif
k
==
"binary_maps"
:
elif
k
==
"binary_maps"
:
new_outs
[
k
]
=
paddle
.
slice
(
outs
,
axes
=
1
,
starts
=
2
,
ends
=
3
)
new_outs
[
k
]
=
paddle
.
slice
(
outs
,
axes
=
[
1
],
starts
=
[
2
],
ends
=
[
3
]
)
else
:
else
:
continue
continue
return
new_outs
def
forward
(
self
,
predicts
,
batch
):
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
loss_dict
=
dict
()
...
@@ -104,7 +112,7 @@ class DistillationDMLLoss(DMLLoss):
...
@@ -104,7 +112,7 @@ class DistillationDMLLoss(DMLLoss):
loss_dict
[
"{}_{}_{}_{}_{}"
.
format
(
key
,
pair
[
loss_dict
[
"{}_{}_{}_{}_{}"
.
format
(
key
,
pair
[
0
],
pair
[
1
],
map_name
,
idx
)]
=
loss
[
key
]
0
],
pair
[
1
],
map_name
,
idx
)]
=
loss
[
key
]
else
:
else
:
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
map
_name
,
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
self
.
maps
_name
,
idx
)]
=
loss
idx
)]
=
loss
loss_dict
=
_sum_loss
(
loss_dict
)
loss_dict
=
_sum_loss
(
loss_dict
)
...
@@ -151,7 +159,7 @@ class DistillationDBLoss(DBLoss):
...
@@ -151,7 +159,7 @@ class DistillationDBLoss(DBLoss):
self
.
name
=
name
self
.
name
=
name
self
.
key
=
None
self
.
key
=
None
def
forward
(
self
,
preicts
,
batch
):
def
forward
(
self
,
pre
d
icts
,
batch
):
loss_dict
=
{}
loss_dict
=
{}
for
idx
,
model_name
in
enumerate
(
self
.
model_name_list
):
for
idx
,
model_name
in
enumerate
(
self
.
model_name_list
):
out
=
predicts
[
model_name
]
out
=
predicts
[
model_name
]
...
...
ppocr/postprocess/__init__.py
浏览文件 @
6ce44198
...
@@ -34,7 +34,8 @@ def build_post_process(config, global_config=None):
...
@@ -34,7 +34,8 @@ def build_post_process(config, global_config=None):
support_dict
=
[
support_dict
=
[
'DBPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'DBPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
,
'PGPostProcess'
,
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
,
'PGPostProcess'
,
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'DistillationDBPostProcess'
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'DistillationDBPostProcess'
]
]
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
...
...
ppocr/postprocess/db_postprocess.py
浏览文件 @
6ce44198
...
@@ -200,12 +200,9 @@ class DistillationDBPostProcess(DBPostProcess):
...
@@ -200,12 +200,9 @@ class DistillationDBPostProcess(DBPostProcess):
use_dilation
=
False
,
use_dilation
=
False
,
score_mode
=
"fast"
,
score_mode
=
"fast"
,
**
kwargs
):
**
kwargs
):
super
(
DistillationDBPostProcess
,
self
).
__init__
(
thresh
,
super
(
DistillationDBPostProcess
,
self
).
__init__
(
box_thresh
,
thresh
,
box_thresh
,
max_candidates
,
unclip_ratio
,
use_dilation
,
max_candidates
,
score_mode
)
unclip_ratio
,
use_dilation
,
score_mode
)
if
not
isinstance
(
model_name
,
list
):
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
model_name
=
[
model_name
]
self
.
model_name
=
model_name
self
.
model_name
=
model_name
...
@@ -221,10 +218,3 @@ class DistillationDBPostProcess(DBPostProcess):
...
@@ -221,10 +218,3 @@ class DistillationDBPostProcess(DBPostProcess):
results
[
name
]
=
super
().
__call__
(
pred
,
shape_list
=
label
)
results
[
name
]
=
super
().
__call__
(
pred
,
shape_list
=
label
)
return
results
return
results
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录