Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
ab4db2ac
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ab4db2ac
编写于
6月 03, 2021
作者:
littletomatodonkey
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support dict output for basemodel
上级
e5d3a2d8
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
63 addition
and
8 deletion
+63
-8
configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml
...h_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml
+13
-3
ppocr/losses/basic_loss.py
ppocr/losses/basic_loss.py
+1
-0
ppocr/losses/combined_loss.py
ppocr/losses/combined_loss.py
+1
-0
ppocr/losses/distillation_loss.py
ppocr/losses/distillation_loss.py
+34
-0
ppocr/modeling/architectures/base_model.py
ppocr/modeling/architectures/base_model.py
+10
-1
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+4
-4
未找到文件。
configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml
浏览文件 @
ab4db2ac
...
...
@@ -39,6 +39,7 @@ Architecture:
Student
:
pretrained
:
null
freeze_params
:
false
return_all_feats
:
true
model_type
:
rec
algorithm
:
CRNN
Transform
:
...
...
@@ -57,6 +58,7 @@ Architecture:
Teacher
:
pretrained
:
null
freeze_params
:
false
return_all_feats
:
true
model_type
:
rec
algorithm
:
CRNN
Transform
:
...
...
@@ -80,18 +82,26 @@ Loss:
-
DistillationCTCLoss
:
weight
:
1.0
model_name_list
:
[
"
Student"
,
"
Teacher"
]
key
:
null
key
:
head_out
-
DistillationDMLLoss
:
weight
:
1.0
act
:
"
softmax"
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
key
:
null
key
:
head_out
-
DistillationDistanceLoss
:
weight
:
1.0
mode
:
"
l2"
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
key
:
backbone_out
PostProcess
:
name
:
DistillationCTCLabelDecode
model_name
:
"
Student"
key
_out
:
null
key
:
head_out
Metric
:
name
:
RecMetric
...
...
ppocr/losses/basic_loss.py
浏览文件 @
ab4db2ac
...
...
@@ -97,6 +97,7 @@ class DistanceLoss(nn.Layer):
"""
def
__init__
(
self
,
mode
=
"l2"
,
name
=
"loss_dist"
,
**
kargs
):
super
().
__init__
()
assert
mode
in
[
"l1"
,
"l2"
,
"smooth_l1"
]
if
mode
==
"l1"
:
self
.
loss_func
=
nn
.
L1Loss
(
**
kargs
)
...
...
ppocr/losses/combined_loss.py
浏览文件 @
ab4db2ac
...
...
@@ -17,6 +17,7 @@ import paddle.nn as nn
from
.distillation_loss
import
DistillationCTCLoss
from
.distillation_loss
import
DistillationDMLLoss
from
.distillation_loss
import
DistillationDistanceLoss
class
CombinedLoss
(
nn
.
Layer
):
...
...
ppocr/losses/distillation_loss.py
浏览文件 @
ab4db2ac
...
...
@@ -17,6 +17,7 @@ import paddle.nn as nn
from
.rec_ctc_loss
import
CTCLoss
from
.basic_loss
import
DMLLoss
from
.basic_loss
import
DistanceLoss
class
DistillationDMLLoss
(
DMLLoss
):
...
...
@@ -69,3 +70,36 @@ class DistillationCTCLoss(CTCLoss):
else
:
loss_dict
[
"{}_{}"
.
format
(
self
.
name
,
model_name
)]
=
loss
return
loss_dict
class
DistillationDistanceLoss
(
DistanceLoss
):
"""
"""
def
__init__
(
self
,
mode
=
"l2"
,
model_name_pairs
=
[],
key
=
None
,
name
=
"loss_distance"
,
**
kargs
):
super
().
__init__
(
mode
=
mode
,
name
=
name
)
assert
isinstance
(
model_name_pairs
,
list
)
self
.
key
=
key
self
.
model_name_pairs
=
model_name_pairs
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
for
idx
,
pair
in
enumerate
(
self
.
model_name_pairs
):
out1
=
predicts
[
pair
[
0
]]
out2
=
predicts
[
pair
[
1
]]
if
self
.
key
is
not
None
:
out1
=
out1
[
self
.
key
]
out2
=
out2
[
self
.
key
]
loss
=
super
().
forward
(
out1
,
out2
)
if
isinstance
(
loss
,
dict
):
for
key
in
loss
:
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
key
,
idx
)]
=
loss
[
key
]
else
:
loss_dict
[
"{}_{}"
.
format
(
self
.
name
,
idx
)]
=
loss
return
loss_dict
ppocr/modeling/architectures/base_model.py
浏览文件 @
ab4db2ac
...
...
@@ -67,14 +67,23 @@ class BaseModel(nn.Layer):
config
[
"Head"
][
'in_channels'
]
=
in_channels
self
.
head
=
build_head
(
config
[
"Head"
])
self
.
return_all_feats
=
config
.
get
(
"return_all_feats"
,
False
)
def
forward
(
self
,
x
,
data
=
None
):
y
=
dict
()
if
self
.
use_transform
:
x
=
self
.
transform
(
x
)
x
=
self
.
backbone
(
x
)
y
[
"backbone_out"
]
=
x
if
self
.
use_neck
:
x
=
self
.
neck
(
x
)
y
[
"neck_out"
]
=
x
if
data
is
None
:
x
=
self
.
head
(
x
)
else
:
x
=
self
.
head
(
x
,
data
)
return
x
y
[
"head_out"
]
=
x
if
self
.
return_all_feats
:
return
y
else
:
return
x
ppocr/postprocess/rec_postprocess.py
浏览文件 @
ab4db2ac
...
...
@@ -136,17 +136,17 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
character_type
=
'ch'
,
use_space_char
=
False
,
model_name
=
"student"
,
key
_out
=
None
,
key
=
None
,
**
kwargs
):
super
(
DistillationCTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
self
.
model_name
=
model_name
self
.
key
_out
=
key_out
self
.
key
=
key
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
pred
=
preds
[
self
.
model_name
]
if
self
.
key
_out
is
not
None
:
pred
=
pred
[
self
.
key
_out
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
return
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录