Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
ed02b91d
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
1 年多 前同步成功
通知
1532
Star
32963
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ed02b91d
编写于
6月 02, 2021
作者:
littletomatodonkey
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add distillation function
上级
551a6827
变更
17
显示空白变更内容
内联
并排
Showing
17 changed file
with
405 addition
and
79 deletion
+405
-79
ppocr/losses/__init__.py
ppocr/losses/__init__.py
+24
-15
ppocr/losses/basic_loss.py
ppocr/losses/basic_loss.py
+101
-0
ppocr/losses/cls_loss.py
ppocr/losses/cls_loss.py
+1
-1
ppocr/losses/combined_loss.py
ppocr/losses/combined_loss.py
+57
-0
ppocr/losses/distillation_loss.py
ppocr/losses/distillation_loss.py
+76
-0
ppocr/losses/rec_ctc_loss.py
ppocr/losses/rec_ctc_loss.py
+1
-1
ppocr/modeling/architectures/__init__.py
ppocr/modeling/architectures/__init__.py
+12
-4
ppocr/modeling/architectures/base_model.py
ppocr/modeling/architectures/base_model.py
+0
-1
ppocr/modeling/architectures/distillation_model.py
ppocr/modeling/architectures/distillation_model.py
+65
-0
ppocr/modeling/backbones/det_mobilenet_v3.py
ppocr/modeling/backbones/det_mobilenet_v3.py
+13
-32
ppocr/modeling/backbones/rec_mobilenet_v3.py
ppocr/modeling/backbones/rec_mobilenet_v3.py
+3
-6
ppocr/modeling/heads/rec_ctc_head.py
ppocr/modeling/heads/rec_ctc_head.py
+5
-8
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+9
-8
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+25
-0
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+4
-1
tools/program.py
tools/program.py
+1
-1
tools/train.py
tools/train.py
+8
-1
未找到文件。
ppocr/losses/__init__.py
浏览文件 @
ed02b91d
...
...
@@ -13,28 +13,37 @@
# limitations under the License.
import
copy
import
paddle
import
paddle.nn
as
nn
# det loss
from
.det_db_loss
import
DBLoss
from
.det_east_loss
import
EASTLoss
from
.det_sast_loss
import
SASTLoss
def
build_loss
(
config
):
# det loss
from
.det_db_loss
import
DBLoss
from
.det_east_loss
import
EASTLoss
from
.det_sast_loss
import
SASTLoss
# rec loss
from
.rec_ctc_loss
import
CTCLoss
from
.rec_att_loss
import
AttentionLoss
from
.rec_srn_loss
import
SRNLoss
# cls loss
from
.cls_loss
import
ClsLoss
# e2e loss
from
.e2e_pg_loss
import
PGLoss
# rec loss
from
.rec_ctc_loss
import
CTCLoss
from
.rec_att_loss
import
AttentionLoss
from
.rec_srn_loss
import
SRNLoss
# basic loss function
from
.basic_loss
import
DistanceLoss
# cls loss
from
.cls_loss
import
Cls
Loss
# combined loss function
from
.combined_loss
import
Combined
Loss
# e2e loss
from
.e2e_pg_loss
import
PGLoss
def
build_loss
(
config
):
support_dict
=
[
'DBLoss'
,
'EASTLoss'
,
'SASTLoss'
,
'CTCLoss'
,
'ClsLoss'
,
'AttentionLoss'
,
'SRNLoss'
,
'PGLoss'
]
'SRNLoss'
,
'PGLoss'
,
'CombinedLoss'
]
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'name'
)
assert
module_name
in
support_dict
,
Exception
(
'loss only support {}'
.
format
(
...
...
ppocr/losses/basic_loss.py
0 → 100644
浏览文件 @
ed02b91d
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.nn
import
L1Loss
from
paddle.nn
import
MSELoss
as
L2Loss
from
paddle.nn
import
SmoothL1Loss
class
CELoss
(
nn
.
Layer
):
def
__init__
(
self
,
name
=
"loss_ce"
,
epsilon
=
None
):
super
().
__init__
()
self
.
name
=
name
if
epsilon
is
not
None
and
(
epsilon
<=
0
or
epsilon
>=
1
):
epsilon
=
None
self
.
epsilon
=
epsilon
def
_labelsmoothing
(
self
,
target
,
class_num
):
if
target
.
shape
[
-
1
]
!=
class_num
:
one_hot_target
=
F
.
one_hot
(
target
,
class_num
)
else
:
one_hot_target
=
target
soft_target
=
F
.
label_smooth
(
one_hot_target
,
epsilon
=
self
.
epsilon
)
soft_target
=
paddle
.
reshape
(
soft_target
,
shape
=
[
-
1
,
class_num
])
return
soft_target
def
forward
(
self
,
x
,
label
):
loss_dict
=
{}
if
self
.
epsilon
is
not
None
:
class_num
=
x
.
shape
[
-
1
]
label
=
self
.
_labelsmoothing
(
label
,
class_num
)
x
=
-
F
.
log_softmax
(
x
,
axis
=-
1
)
loss
=
paddle
.
sum
(
x
*
label
,
axis
=-
1
)
else
:
if
label
.
shape
[
-
1
]
==
x
.
shape
[
-
1
]:
label
=
F
.
softmax
(
label
,
axis
=-
1
)
soft_label
=
True
else
:
soft_label
=
False
loss
=
F
.
cross_entropy
(
x
,
label
=
label
,
soft_label
=
soft_label
)
loss_dict
[
self
.
name
]
=
paddle
.
mean
(
loss
)
return
loss_dict
class
DMLLoss
(
nn
.
Layer
):
"""
DMLLoss
"""
def
__init__
(
self
,
name
=
"loss_dml"
):
super
().
__init__
()
self
.
name
=
name
def
forward
(
self
,
out1
,
out2
):
loss_dict
=
{}
soft_out1
=
F
.
softmax
(
out1
,
axis
=-
1
)
log_soft_out1
=
paddle
.
log
(
soft_out1
)
soft_out2
=
F
.
softmax
(
out2
,
axis
=-
1
)
log_soft_out2
=
paddle
.
log
(
soft_out2
)
loss
=
(
F
.
kl_div
(
log_soft_out1
,
soft_out2
,
reduction
=
'batchmean'
)
+
F
.
kl_div
(
log_soft_out2
,
soft_out1
,
reduction
=
'batchmean'
))
/
2.0
loss_dict
[
self
.
name
]
=
loss
return
loss_dict
class
DistanceLoss
(
nn
.
Layer
):
"""
DistanceLoss:
mode: loss mode
name: loss key in the output dict
"""
def
__init__
(
self
,
mode
=
"l2"
,
name
=
"loss_dist"
,
**
kargs
):
assert
mode
in
[
"l1"
,
"l2"
,
"smooth_l1"
]
if
mode
==
"l1"
:
self
.
loss_func
=
nn
.
L1Loss
(
**
kargs
)
elif
mode
==
"l1"
:
self
.
loss_func
=
nn
.
MSELoss
(
**
kargs
)
elif
mode
==
"smooth_l1"
:
self
.
loss_func
=
nn
.
SmoothL1Loss
(
**
kargs
)
self
.
name
=
"{}_{}"
.
format
(
name
,
mode
)
def
forward
(
self
,
x
,
y
):
return
{
self
.
name
:
self
.
loss_func
(
x
,
y
)}
ppocr/losses/cls_loss.py
浏览文件 @
ed02b91d
...
...
@@ -24,7 +24,7 @@ class ClsLoss(nn.Layer):
super
(
ClsLoss
,
self
).
__init__
()
self
.
loss_func
=
nn
.
CrossEntropyLoss
(
reduction
=
'mean'
)
def
__call__
(
self
,
predicts
,
batch
):
def
forward
(
self
,
predicts
,
batch
):
label
=
batch
[
1
]
loss
=
self
.
loss_func
(
input
=
predicts
,
label
=
label
)
return
{
'loss'
:
loss
}
ppocr/losses/combined_loss.py
0 → 100644
浏览文件 @
ed02b91d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn
as
nn
from
.distillation_loss
import
DistillationCTCLoss
from
.distillation_loss
import
DistillationDMLLoss
class
CombinedLoss
(
nn
.
Layer
):
"""
CombinedLoss:
a combionation of loss function
"""
def
__init__
(
self
,
loss_config_list
=
None
):
super
().
__init__
()
self
.
loss_func
=
[]
self
.
loss_weight
=
[]
assert
isinstance
(
loss_config_list
,
list
),
(
'operator config should be a list'
)
for
config
in
loss_config_list
:
assert
isinstance
(
config
,
dict
)
and
len
(
config
)
==
1
,
"yaml format error"
name
=
list
(
config
)[
0
]
param
=
config
[
name
]
assert
"weight"
in
param
,
"weight must be in param, but param just contains {}"
.
format
(
param
.
keys
())
self
.
loss_weight
.
append
(
param
.
pop
(
"weight"
))
self
.
loss_func
.
append
(
eval
(
name
)(
**
param
))
def
forward
(
self
,
input
,
batch
,
**
kargs
):
loss_dict
=
{}
for
idx
,
loss_func
in
enumerate
(
self
.
loss_func
):
loss
=
loss_func
(
input
,
batch
,
**
kargs
)
if
isinstance
(
loss
,
paddle
.
Tensor
):
loss
=
{
"loss_{}_{}"
.
format
(
str
(
loss
),
idx
):
loss
}
weight
=
self
.
loss_weight
[
idx
]
loss
=
{
"{}_{}"
.
format
(
key
,
idx
):
loss
[
key
]
*
weight
for
key
in
loss
}
loss_dict
.
update
(
loss
)
loss_dict
[
"loss"
]
=
paddle
.
add_n
(
list
(
loss_dict
.
values
()))
return
loss_dict
ppocr/losses/distillation_loss.py
0 → 100644
浏览文件 @
ed02b91d
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
paddle
import
paddle.nn
as
nn
from
.rec_ctc_loss
import
CTCLoss
from
.basic_loss
import
DMLLoss
class
DistillationDMLLoss
(
DMLLoss
):
"""
"""
def
__init__
(
self
,
model_name_list1
=
[],
model_name_list2
=
[],
key
=
None
,
name
=
"loss_dml"
):
super
().
__init__
(
name
=
name
)
if
not
isinstance
(
model_name_list1
,
(
list
,
)):
model_name_list1
=
[
model_name_list1
]
if
not
isinstance
(
model_name_list2
,
(
list
,
)):
model_name_list2
=
[
model_name_list2
]
assert
len
(
model_name_list1
)
==
len
(
model_name_list2
)
self
.
model_name_list1
=
model_name_list1
self
.
model_name_list2
=
model_name_list2
self
.
key
=
key
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
for
idx
in
range
(
len
(
self
.
model_name_list1
)):
out1
=
predicts
[
self
.
model_name_list1
[
idx
]]
out2
=
predicts
[
self
.
model_name_list2
[
idx
]]
if
self
.
key
is
not
None
:
out1
=
out1
[
self
.
key
]
out2
=
out2
[
self
.
key
]
loss
=
super
().
forward
(
out1
,
out2
)
if
isinstance
(
loss
,
dict
):
assert
len
(
loss
)
==
1
loss
=
list
(
loss
.
values
())[
0
]
loss_dict
[
"{}_{}"
.
format
(
self
.
name
,
idx
)]
=
loss
return
loss_dict
class
DistillationCTCLoss
(
CTCLoss
):
def
__init__
(
self
,
model_name_list
=
[],
key
=
None
,
name
=
"loss_ctc"
):
super
().
__init__
()
self
.
model_name_list
=
model_name_list
self
.
key
=
key
self
.
name
=
name
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
for
model_name
in
self
.
model_name_list
:
out
=
predicts
[
model_name
]
if
self
.
key
is
not
None
:
out
=
out
[
self
.
key
]
loss
=
super
().
forward
(
out
,
batch
)
if
isinstance
(
loss
,
dict
):
assert
len
(
loss
)
==
1
loss
=
list
(
loss
.
values
())[
0
]
loss_dict
[
"{}_{}"
.
format
(
self
.
name
,
model_name
)]
=
loss
return
loss_dict
ppocr/losses/rec_ctc_loss.py
浏览文件 @
ed02b91d
...
...
@@ -25,7 +25,7 @@ class CTCLoss(nn.Layer):
super
(
CTCLoss
,
self
).
__init__
()
self
.
loss_func
=
nn
.
CTCLoss
(
blank
=
0
,
reduction
=
'none'
)
def
__call__
(
self
,
predicts
,
batch
):
def
forward
(
self
,
predicts
,
batch
):
predicts
=
predicts
.
transpose
((
1
,
0
,
2
))
N
,
B
,
_
=
predicts
.
shape
preds_lengths
=
paddle
.
to_tensor
([
N
]
*
B
,
dtype
=
'int64'
)
...
...
ppocr/modeling/architectures/__init__.py
浏览文件 @
ed02b91d
...
...
@@ -13,12 +13,20 @@
# limitations under the License.
import
copy
import
importlib
from
.base_model
import
BaseModel
from
.distillation_model
import
DistillationModel
__all__
=
[
'build_model'
]
def
build_model
(
config
):
from
.base_model
import
BaseModel
def
build_model
(
config
):
config
=
copy
.
deepcopy
(
config
)
module_class
=
BaseModel
(
config
)
return
module_class
\ No newline at end of file
if
not
"name"
in
config
:
arch
=
BaseModel
(
config
)
else
:
name
=
config
.
pop
(
"name"
)
mod
=
importlib
.
import_module
(
__name__
)
arch
=
getattr
(
mod
,
name
)(
config
)
return
arch
ppocr/modeling/architectures/base_model.py
浏览文件 @
ed02b91d
...
...
@@ -32,7 +32,6 @@ class BaseModel(nn.Layer):
config (dict): the super parameters for module.
"""
super
(
BaseModel
,
self
).
__init__
()
in_channels
=
config
.
get
(
'in_channels'
,
3
)
model_type
=
config
[
'model_type'
]
# build transfrom,
...
...
ppocr/modeling/architectures/distillation_model.py
0 → 100644
浏览文件 @
ed02b91d
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
paddle
import
nn
from
ppocr.modeling.transforms
import
build_transform
from
ppocr.modeling.backbones
import
build_backbone
from
ppocr.modeling.necks
import
build_neck
from
ppocr.modeling.heads
import
build_head
from
.base_model
import
BaseModel
from
ppocr.utils.save_load
import
load_dygraph_pretrain
__all__
=
[
'DistillationModel'
]
class
DistillationModel
(
nn
.
Layer
):
def
__init__
(
self
,
config
):
"""
the module for OCR distillation.
args:
config (dict): the super parameters for module.
"""
super
().
__init__
()
freeze_params
=
config
[
"freeze_params"
]
pretrained
=
config
[
"pretrained"
]
if
not
isinstance
(
freeze_params
,
list
):
freeze_params
=
[
freeze_params
]
assert
len
(
config
[
"Models"
])
==
len
(
freeze_params
)
if
not
isinstance
(
pretrained
,
list
):
pretrained
=
[
pretrained
]
*
len
(
config
[
"Models"
])
assert
len
(
config
[
"Models"
])
==
len
(
pretrained
)
self
.
model_dict
=
dict
()
index
=
0
for
key
in
config
[
"Models"
]:
model_config
=
config
[
"Models"
][
key
]
model
=
BaseModel
(
model_config
)
if
pretrained
[
index
]
is
not
None
:
load_dygraph_pretrain
(
model
,
path
=
pretrained
[
index
])
if
freeze_params
[
index
]:
for
param
in
model
.
parameters
():
param
.
trainable
=
False
self
.
model_dict
[
key
]
=
self
.
add_sublayer
(
key
,
model
)
index
+=
1
def
forward
(
self
,
x
):
result_dict
=
dict
()
for
key
in
self
.
model_dict
:
result_dict
[
key
]
=
self
.
model_dict
[
key
](
x
)
return
result_dict
ppocr/modeling/backbones/det_mobilenet_v3.py
浏览文件 @
ed02b91d
...
...
@@ -102,8 +102,7 @@ class MobileNetV3(nn.Layer):
padding
=
1
,
groups
=
1
,
if_act
=
True
,
act
=
'hardswish'
,
name
=
'conv1'
)
act
=
'hardswish'
)
self
.
stages
=
[]
self
.
out_channels
=
[]
...
...
@@ -125,8 +124,7 @@ class MobileNetV3(nn.Layer):
kernel_size
=
k
,
stride
=
s
,
use_se
=
se
,
act
=
nl
,
name
=
"conv"
+
str
(
i
+
2
)))
act
=
nl
))
inplanes
=
make_divisible
(
scale
*
c
)
i
+=
1
block_list
.
append
(
...
...
@@ -138,8 +136,7 @@ class MobileNetV3(nn.Layer):
padding
=
0
,
groups
=
1
,
if_act
=
True
,
act
=
'hardswish'
,
name
=
'conv_last'
))
act
=
'hardswish'
))
self
.
stages
.
append
(
nn
.
Sequential
(
*
block_list
))
self
.
out_channels
.
append
(
make_divisible
(
scale
*
cls_ch_squeeze
))
for
i
,
stage
in
enumerate
(
self
.
stages
):
...
...
@@ -163,8 +160,7 @@ class ConvBNLayer(nn.Layer):
padding
,
groups
=
1
,
if_act
=
True
,
act
=
None
,
name
=
None
):
act
=
None
):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
if_act
=
if_act
self
.
act
=
act
...
...
@@ -175,16 +171,9 @@ class ConvBNLayer(nn.Layer):
stride
=
stride
,
padding
=
padding
,
groups
=
groups
,
weight_attr
=
ParamAttr
(
name
=
name
+
'_weights'
),
bias_attr
=
False
)
self
.
bn
=
nn
.
BatchNorm
(
num_channels
=
out_channels
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
name
+
"_bn_scale"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"_bn_offset"
),
moving_mean_name
=
name
+
"_bn_mean"
,
moving_variance_name
=
name
+
"_bn_variance"
)
self
.
bn
=
nn
.
BatchNorm
(
num_channels
=
out_channels
,
act
=
None
)
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
...
...
@@ -209,8 +198,7 @@ class ResidualUnit(nn.Layer):
kernel_size
,
stride
,
use_se
,
act
=
None
,
name
=
''
):
act
=
None
):
super
(
ResidualUnit
,
self
).
__init__
()
self
.
if_shortcut
=
stride
==
1
and
in_channels
==
out_channels
self
.
if_se
=
use_se
...
...
@@ -222,8 +210,7 @@ class ResidualUnit(nn.Layer):
stride
=
1
,
padding
=
0
,
if_act
=
True
,
act
=
act
,
name
=
name
+
"_expand"
)
act
=
act
)
self
.
bottleneck_conv
=
ConvBNLayer
(
in_channels
=
mid_channels
,
out_channels
=
mid_channels
,
...
...
@@ -232,10 +219,9 @@ class ResidualUnit(nn.Layer):
padding
=
int
((
kernel_size
-
1
)
//
2
),
groups
=
mid_channels
,
if_act
=
True
,
act
=
act
,
name
=
name
+
"_depthwise"
)
act
=
act
)
if
self
.
if_se
:
self
.
mid_se
=
SEModule
(
mid_channels
,
name
=
name
+
"_se"
)
self
.
mid_se
=
SEModule
(
mid_channels
)
self
.
linear_conv
=
ConvBNLayer
(
in_channels
=
mid_channels
,
out_channels
=
out_channels
,
...
...
@@ -243,8 +229,7 @@ class ResidualUnit(nn.Layer):
stride
=
1
,
padding
=
0
,
if_act
=
False
,
act
=
None
,
name
=
name
+
"_linear"
)
act
=
None
)
def
forward
(
self
,
inputs
):
x
=
self
.
expand_conv
(
inputs
)
...
...
@@ -258,7 +243,7 @@ class ResidualUnit(nn.Layer):
class
SEModule
(
nn
.
Layer
):
def
__init__
(
self
,
in_channels
,
reduction
=
4
,
name
=
""
):
def
__init__
(
self
,
in_channels
,
reduction
=
4
):
super
(
SEModule
,
self
).
__init__
()
self
.
avg_pool
=
nn
.
AdaptiveAvgPool2D
(
1
)
self
.
conv1
=
nn
.
Conv2D
(
...
...
@@ -266,17 +251,13 @@ class SEModule(nn.Layer):
out_channels
=
in_channels
//
reduction
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
weight_attr
=
ParamAttr
(
name
=
name
+
"_1_weights"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"_1_offset"
))
padding
=
0
)
self
.
conv2
=
nn
.
Conv2D
(
in_channels
=
in_channels
//
reduction
,
out_channels
=
in_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
weight_attr
=
ParamAttr
(
name
+
"_2_weights"
),
bias_attr
=
ParamAttr
(
name
=
name
+
"_2_offset"
))
padding
=
0
)
def
forward
(
self
,
inputs
):
outputs
=
self
.
avg_pool
(
inputs
)
...
...
ppocr/modeling/backbones/rec_mobilenet_v3.py
浏览文件 @
ed02b91d
...
...
@@ -96,8 +96,7 @@ class MobileNetV3(nn.Layer):
padding
=
1
,
groups
=
1
,
if_act
=
True
,
act
=
'hardswish'
,
name
=
'conv1'
)
act
=
'hardswish'
)
i
=
0
block_list
=
[]
inplanes
=
make_divisible
(
inplanes
*
scale
)
...
...
@@ -110,8 +109,7 @@ class MobileNetV3(nn.Layer):
kernel_size
=
k
,
stride
=
s
,
use_se
=
se
,
act
=
nl
,
name
=
'conv'
+
str
(
i
+
2
)))
act
=
nl
))
inplanes
=
make_divisible
(
scale
*
c
)
i
+=
1
self
.
blocks
=
nn
.
Sequential
(
*
block_list
)
...
...
@@ -124,8 +122,7 @@ class MobileNetV3(nn.Layer):
padding
=
0
,
groups
=
1
,
if_act
=
True
,
act
=
'hardswish'
,
name
=
'conv_last'
)
act
=
'hardswish'
)
self
.
pool
=
nn
.
MaxPool2D
(
kernel_size
=
2
,
stride
=
2
,
padding
=
0
)
self
.
out_channels
=
make_divisible
(
scale
*
cls_ch_squeeze
)
...
...
ppocr/modeling/heads/rec_ctc_head.py
浏览文件 @
ed02b91d
...
...
@@ -23,14 +23,12 @@ from paddle import ParamAttr, nn
from
paddle.nn
import
functional
as
F
def
get_para_bias_attr
(
l2_decay
,
k
,
name
):
def
get_para_bias_attr
(
l2_decay
,
k
):
regularizer
=
paddle
.
regularizer
.
L2Decay
(
l2_decay
)
stdv
=
1.0
/
math
.
sqrt
(
k
*
1.0
)
initializer
=
nn
.
initializer
.
Uniform
(
-
stdv
,
stdv
)
weight_attr
=
ParamAttr
(
regularizer
=
regularizer
,
initializer
=
initializer
,
name
=
name
+
"_w_attr"
)
bias_attr
=
ParamAttr
(
regularizer
=
regularizer
,
initializer
=
initializer
,
name
=
name
+
"_b_attr"
)
weight_attr
=
ParamAttr
(
regularizer
=
regularizer
,
initializer
=
initializer
)
bias_attr
=
ParamAttr
(
regularizer
=
regularizer
,
initializer
=
initializer
)
return
[
weight_attr
,
bias_attr
]
...
...
@@ -38,13 +36,12 @@ class CTCHead(nn.Layer):
def
__init__
(
self
,
in_channels
,
out_channels
,
fc_decay
=
0.0004
,
**
kwargs
):
super
(
CTCHead
,
self
).
__init__
()
weight_attr
,
bias_attr
=
get_para_bias_attr
(
l2_decay
=
fc_decay
,
k
=
in_channels
,
name
=
'ctc_fc'
)
l2_decay
=
fc_decay
,
k
=
in_channels
)
self
.
fc
=
nn
.
Linear
(
in_channels
,
out_channels
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
,
name
=
'ctc_fc'
)
bias_attr
=
bias_attr
)
self
.
out_channels
=
out_channels
def
forward
(
self
,
x
,
labels
=
None
):
...
...
ppocr/postprocess/__init__.py
浏览文件 @
ed02b91d
...
...
@@ -21,18 +21,19 @@ import copy
__all__
=
[
'build_post_process'
]
from
.db_postprocess
import
DBPostProcess
from
.east_postprocess
import
EASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
,
DistillationCTCLabelDecode
from
.cls_postprocess
import
ClsPostProcess
from
.pg_postprocess
import
PGPostProcess
def
build_post_process
(
config
,
global_config
=
None
):
from
.db_postprocess
import
DBPostProcess
from
.east_postprocess
import
EASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
from
.cls_postprocess
import
ClsPostProcess
from
.pg_postprocess
import
PGPostProcess
def
build_post_process
(
config
,
global_config
=
None
):
support_dict
=
[
'DBPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
,
'PGPostProcess'
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
,
'PGPostProcess'
,
'DistillationCTCLabelDecode'
]
config
=
copy
.
deepcopy
(
config
)
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
ed02b91d
...
...
@@ -125,6 +125,31 @@ class CTCLabelDecode(BaseRecLabelDecode):
return
dict_character
class
DistillationCTCLabelDecode
(
CTCLabelDecode
):
"""
Convert
Convert between text-label and text-index
"""
def
__init__
(
self
,
character_dict_path
=
None
,
character_type
=
'ch'
,
use_space_char
=
False
,
model_name
=
"student"
,
key_out
=
None
,
**
kwargs
):
super
(
DistillationCTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
self
.
model_name
=
model_name
self
.
key_out
=
key_out
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
pred
=
preds
[
self
.
model_name
]
if
self
.
key_out
is
not
None
:
pred
=
pred
[
self
.
key_out
]
return
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
class
AttnLabelDecode
(
BaseRecLabelDecode
):
""" Convert between text-label and text-index """
...
...
ppocr/utils/save_load.py
浏览文件 @
ed02b91d
...
...
@@ -42,7 +42,10 @@ def _mkdir_if_not_exist(path, logger):
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
def
load_dygraph_pretrain
(
model
,
logger
,
path
=
None
,
load_static_weights
=
False
):
def
load_dygraph_pretrain
(
model
,
logger
=
None
,
path
=
None
,
load_static_weights
=
False
):
if
not
(
os
.
path
.
isdir
(
path
)
or
os
.
path
.
exists
(
path
+
'.pdparams'
)):
raise
ValueError
(
"Model pretrain path {} does not "
"exists."
.
format
(
path
))
...
...
tools/program.py
浏览文件 @
ed02b91d
...
...
@@ -386,7 +386,7 @@ def preprocess(is_train=False):
alg
=
config
[
'Architecture'
][
'algorithm'
]
assert
alg
in
[
'EAST'
,
'DB'
,
'SAST'
,
'Rosetta'
,
'CRNN'
,
'STARNet'
,
'RARE'
,
'SRN'
,
'CLS'
,
'PGNet'
'CLS'
,
'PGNet'
,
'Distillation'
]
device
=
'gpu:{}'
.
format
(
dist
.
ParallelEnv
().
dev_id
)
if
use_gpu
else
'cpu'
...
...
tools/train.py
浏览文件 @
ed02b91d
...
...
@@ -72,7 +72,14 @@ def main(config, device, logger, vdl_writer):
# for rec algorithm
if
hasattr
(
post_process_class
,
'character'
):
char_num
=
len
(
getattr
(
post_process_class
,
'character'
))
if
config
[
'Architecture'
][
"algorithm"
]
in
[
"Distillation"
,
]:
# distillation model
for
key
in
config
[
'Architecture'
][
"Models"
]:
config
[
'Architecture'
][
"Models"
][
key
][
"Head"
][
'out_channels'
]
=
char_num
else
:
# base rec model
config
[
'Architecture'
][
"Head"
][
'out_channels'
]
=
char_num
model
=
build_model
(
config
[
'Architecture'
])
if
config
[
'Global'
][
'distributed'
]:
model
=
paddle
.
DataParallel
(
model
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录