Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
7dc56191
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7dc56191
编写于
9月 09, 2021
作者:
littletomatodonkey
提交者:
GitHub
9月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix rec distillation (#3994)
* fix rec distillation * add dist cfg * fix yaml
上级
51f4a2c3
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
255 addition
and
95 deletion
+255
-95
configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
+20
-68
configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml
configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml
+160
-0
doc/doc_ch/knowledge_distillation.md
doc/doc_ch/knowledge_distillation.md
+34
-1
ppocr/losses/basic_loss.py
ppocr/losses/basic_loss.py
+18
-11
ppocr/losses/combined_loss.py
ppocr/losses/combined_loss.py
+9
-5
ppocr/losses/distillation_loss.py
ppocr/losses/distillation_loss.py
+8
-6
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+6
-4
未找到文件。
configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
浏览文件 @
7dc56191
...
@@ -4,7 +4,7 @@ Global:
...
@@ -4,7 +4,7 @@ Global:
epoch_num
:
800
epoch_num
:
800
log_smooth_window
:
20
log_smooth_window
:
20
print_batch_step
:
10
print_batch_step
:
10
save_model_dir
:
./output/rec_
chinese_lite_distillation_v2.1
save_model_dir
:
./output/rec_
mobile_pp-OCRv2
save_epoch_step
:
3
save_epoch_step
:
3
eval_batch_step
:
[
0
,
2000
]
eval_batch_step
:
[
0
,
2000
]
cal_metric_during_train
:
true
cal_metric_during_train
:
true
...
@@ -19,7 +19,7 @@ Global:
...
@@ -19,7 +19,7 @@ Global:
infer_mode
:
false
infer_mode
:
false
use_space_char
:
true
use_space_char
:
true
distributed
:
true
distributed
:
true
save_res_path
:
./output/rec/predicts_
chinese_lite_distillation_v2.1
.txt
save_res_path
:
./output/rec/predicts_
mobile_pp-OCRv2
.txt
Optimizer
:
Optimizer
:
...
@@ -35,79 +35,32 @@ Optimizer:
...
@@ -35,79 +35,32 @@ Optimizer:
name
:
L2
name
:
L2
factor
:
2.0e-05
factor
:
2.0e-05
Architecture
:
Architecture
:
model_type
:
&model_type
"
rec"
model_type
:
rec
name
:
DistillationModel
algorithm
:
CRNN
algorithm
:
Distillation
Transform
:
Models
:
Backbone
:
Teacher
:
name
:
MobileNetV1Enhance
pretrained
:
scale
:
0.5
freeze_params
:
false
Neck
:
return_all_feats
:
true
name
:
SequenceEncoder
model_type
:
*model_type
encoder_type
:
rnn
algorithm
:
CRNN
hidden_size
:
64
Transform
:
Head
:
Backbone
:
name
:
CTCHead
name
:
MobileNetV1Enhance
mid_channels
:
96
scale
:
0.5
fc_decay
:
0.00002
Neck
:
name
:
SequenceEncoder
encoder_type
:
rnn
hidden_size
:
64
Head
:
name
:
CTCHead
mid_channels
:
96
fc_decay
:
0.00002
Student
:
pretrained
:
freeze_params
:
false
return_all_feats
:
true
model_type
:
*model_type
algorithm
:
CRNN
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
Neck
:
name
:
SequenceEncoder
encoder_type
:
rnn
hidden_size
:
64
Head
:
name
:
CTCHead
mid_channels
:
96
fc_decay
:
0.00002
Loss
:
Loss
:
name
:
CombinedLoss
name
:
CTCLoss
loss_config_list
:
-
DistillationCTCLoss
:
weight
:
1.0
model_name_list
:
[
"
Student"
,
"
Teacher"
]
key
:
head_out
-
DistillationDMLLoss
:
weight
:
1.0
act
:
"
softmax"
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
key
:
head_out
-
DistillationDistanceLoss
:
weight
:
1.0
mode
:
"
l2"
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
key
:
backbone_out
PostProcess
:
PostProcess
:
name
:
DistillationCTCLabelDecode
name
:
CTCLabelDecode
model_name
:
[
"
Student"
,
"
Teacher"
]
key
:
head_out
Metric
:
Metric
:
name
:
DistillationMetric
name
:
RecMetric
base_metric_name
:
RecMetric
main_indicator
:
acc
main_indicator
:
acc
key
:
"
Student"
Train
:
Train
:
dataset
:
dataset
:
...
@@ -132,7 +85,6 @@ Train:
...
@@ -132,7 +85,6 @@ Train:
shuffle
:
true
shuffle
:
true
batch_size_per_card
:
128
batch_size_per_card
:
128
drop_last
:
true
drop_last
:
true
num_sections
:
1
num_workers
:
8
num_workers
:
8
Eval
:
Eval
:
dataset
:
dataset
:
...
...
configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml
0 → 100644
浏览文件 @
7dc56191
Global
:
debug
:
false
use_gpu
:
true
epoch_num
:
800
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/rec_pp-OCRv2_distillation
save_epoch_step
:
3
eval_batch_step
:
[
0
,
2000
]
cal_metric_during_train
:
true
pretrained_model
:
checkpoints
:
save_inference_dir
:
use_visualdl
:
false
infer_img
:
doc/imgs_words/ch/word_1.jpg
character_dict_path
:
ppocr/utils/ppocr_keys_v1.txt
character_type
:
ch
max_text_length
:
25
infer_mode
:
false
use_space_char
:
true
distributed
:
true
save_res_path
:
./output/rec/predicts_pp-OCRv2_distillation.txt
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
lr
:
name
:
Piecewise
decay_epochs
:
[
700
,
800
]
values
:
[
0.001
,
0.0001
]
warmup_epoch
:
5
regularizer
:
name
:
L2
factor
:
2.0e-05
Architecture
:
model_type
:
&model_type
"
rec"
name
:
DistillationModel
algorithm
:
Distillation
Models
:
Teacher
:
pretrained
:
freeze_params
:
false
return_all_feats
:
true
model_type
:
*model_type
algorithm
:
CRNN
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
Neck
:
name
:
SequenceEncoder
encoder_type
:
rnn
hidden_size
:
64
Head
:
name
:
CTCHead
mid_channels
:
96
fc_decay
:
0.00002
Student
:
pretrained
:
freeze_params
:
false
return_all_feats
:
true
model_type
:
*model_type
algorithm
:
CRNN
Transform
:
Backbone
:
name
:
MobileNetV1Enhance
scale
:
0.5
Neck
:
name
:
SequenceEncoder
encoder_type
:
rnn
hidden_size
:
64
Head
:
name
:
CTCHead
mid_channels
:
96
fc_decay
:
0.00002
Loss
:
name
:
CombinedLoss
loss_config_list
:
-
DistillationCTCLoss
:
weight
:
1.0
model_name_list
:
[
"
Student"
,
"
Teacher"
]
key
:
head_out
-
DistillationDMLLoss
:
weight
:
1.0
act
:
"
softmax"
use_log
:
true
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
key
:
head_out
-
DistillationDistanceLoss
:
weight
:
1.0
mode
:
"
l2"
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
key
:
backbone_out
PostProcess
:
name
:
DistillationCTCLabelDecode
model_name
:
[
"
Student"
,
"
Teacher"
]
key
:
head_out
Metric
:
name
:
DistillationMetric
base_metric_name
:
RecMetric
main_indicator
:
acc
key
:
"
Student"
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/
label_file_list
:
-
./train_data/train_list.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
RecAug
:
-
CTCLabelEncode
:
-
RecResizeImg
:
image_shape
:
[
3
,
32
,
320
]
-
KeepKeys
:
keep_keys
:
-
image
-
label
-
length
loader
:
shuffle
:
true
batch_size_per_card
:
128
drop_last
:
true
num_sections
:
1
num_workers
:
8
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data
label_file_list
:
-
./train_data/val_list.txt
transforms
:
-
DecodeImage
:
img_mode
:
BGR
channel_first
:
false
-
CTCLabelEncode
:
-
RecResizeImg
:
image_shape
:
[
3
,
32
,
320
]
-
KeepKeys
:
keep_keys
:
-
image
-
label
-
length
loader
:
shuffle
:
false
drop_last
:
false
batch_size_per_card
:
128
num_workers
:
8
doc/doc_ch/knowledge_distillation.md
浏览文件 @
7dc56191
...
@@ -39,7 +39,7 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要
...
@@ -39,7 +39,7 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要
### 2.1 识别配置文件解析
### 2.1 识别配置文件解析
配置文件在
[
ch_PP-OCRv2_rec
.yml
](
../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec
.yml
)
。
配置文件在
[
ch_PP-OCRv2_rec
_distillation.yml
](
../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation
.yml
)
。
#### 2.1.1 模型结构
#### 2.1.1 模型结构
...
@@ -246,6 +246,39 @@ Metric:
...
@@ -246,6 +246,39 @@ Metric:
关于
`DistillationMetric`
更加具体的实现可以参考:
[
distillation_metric.py
](
../../ppocr/metrics/distillation_metric.py#L24
)
。
关于
`DistillationMetric`
更加具体的实现可以参考:
[
distillation_metric.py
](
../../ppocr/metrics/distillation_metric.py#L24
)
。
#### 2.1.5 蒸馏模型微调
对蒸馏得到的识别蒸馏进行微调有2种方式。
(1)基于知识蒸馏的微调:这种情况比较简单,下载预训练模型,在
[
ch_PP-OCRv2_rec_distillation.yml
](
../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml
)
中配置好预训练模型路径以及自己的数据路径,即可进行模型微调训练。
(2)微调时不使用知识蒸馏:这种情况,需要首先将预训练模型中的学生模型参数提取出来,具体步骤如下。
*
首先下载预训练模型并解压。
```
shell
# 下面预训练模型并解压
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar
tar
-xf
ch_PP-OCRv2_rec_train.tar
```
*
然后使用python,对其中的学生模型参数进行提取
```
python
import
paddle
# 加载预训练模型
all_params
=
paddle
.
load
(
"ch_PP-OCRv2_rec_train/best_accuracy.pdparams"
)
# 查看权重参数的keys
print
(
all_params
.
keys
())
# 学生模型的权重提取
s_params
=
{
key
[
len
(
"Student."
):]:
all_params
[
key
]
for
key
in
all_params
if
"Student."
in
key
}
# 查看学生模型权重参数的keys
print
(
s_params
.
keys
())
# 保存
paddle
.
save
(
s_params
,
"ch_PP-OCRv2_rec_train/student.pdparams"
)
```
转化完成之后,使用
[
ch_PP-OCRv2_rec.yml
](
../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
)
,修改预训练模型的路径(为导出的
`student.pdparams`
模型路径)以及自己的数据路径,即可进行模型微调。
### 2.2 检测配置文件解析
### 2.2 检测配置文件解析
*
coming soon!
*
coming soon!
ppocr/losses/basic_loss.py
浏览文件 @
7dc56191
...
@@ -56,31 +56,34 @@ class CELoss(nn.Layer):
...
@@ -56,31 +56,34 @@ class CELoss(nn.Layer):
class
KLJSLoss
(
object
):
class
KLJSLoss
(
object
):
def
__init__
(
self
,
mode
=
'kl'
):
def
__init__
(
self
,
mode
=
'kl'
):
assert
mode
in
[
'kl'
,
'js'
,
'KL'
,
'JS'
],
"mode can only be one of ['kl', 'js', 'KL', 'JS']"
assert
mode
in
[
'kl'
,
'js'
,
'KL'
,
'JS'
],
"mode can only be one of ['kl', 'js', 'KL', 'JS']"
self
.
mode
=
mode
self
.
mode
=
mode
def
__call__
(
self
,
p1
,
p2
,
reduction
=
"mean"
):
def
__call__
(
self
,
p1
,
p2
,
reduction
=
"mean"
):
loss
=
paddle
.
multiply
(
p2
,
paddle
.
log
(
(
p2
+
1e-5
)
/
(
p1
+
1e-5
)
+
1e-5
))
loss
=
paddle
.
multiply
(
p2
,
paddle
.
log
(
(
p2
+
1e-5
)
/
(
p1
+
1e-5
)
+
1e-5
))
if
self
.
mode
.
lower
()
==
"js"
:
if
self
.
mode
.
lower
()
==
"js"
:
loss
+=
paddle
.
multiply
(
p1
,
paddle
.
log
((
p1
+
1e-5
)
/
(
p2
+
1e-5
)
+
1e-5
))
loss
+=
paddle
.
multiply
(
p1
,
paddle
.
log
((
p1
+
1e-5
)
/
(
p2
+
1e-5
)
+
1e-5
))
loss
*=
0.5
loss
*=
0.5
if
reduction
==
"mean"
:
if
reduction
==
"mean"
:
loss
=
paddle
.
mean
(
loss
,
axis
=
[
1
,
2
])
loss
=
paddle
.
mean
(
loss
,
axis
=
[
1
,
2
])
elif
reduction
==
"none"
or
reduction
is
None
:
elif
reduction
==
"none"
or
reduction
is
None
:
return
loss
return
loss
else
:
else
:
loss
=
paddle
.
sum
(
loss
,
axis
=
[
1
,
2
])
loss
=
paddle
.
sum
(
loss
,
axis
=
[
1
,
2
])
return
loss
return
loss
class
DMLLoss
(
nn
.
Layer
):
class
DMLLoss
(
nn
.
Layer
):
"""
"""
DMLLoss
DMLLoss
"""
"""
def
__init__
(
self
,
act
=
None
):
def
__init__
(
self
,
act
=
None
,
use_log
=
False
):
super
().
__init__
()
super
().
__init__
()
if
act
is
not
None
:
if
act
is
not
None
:
assert
act
in
[
"softmax"
,
"sigmoid"
]
assert
act
in
[
"softmax"
,
"sigmoid"
]
...
@@ -90,20 +93,24 @@ class DMLLoss(nn.Layer):
...
@@ -90,20 +93,24 @@ class DMLLoss(nn.Layer):
self
.
act
=
nn
.
Sigmoid
()
self
.
act
=
nn
.
Sigmoid
()
else
:
else
:
self
.
act
=
None
self
.
act
=
None
self
.
use_log
=
use_log
self
.
jskl_loss
=
KLJSLoss
(
mode
=
"js"
)
self
.
jskl_loss
=
KLJSLoss
(
mode
=
"js"
)
def
forward
(
self
,
out1
,
out2
):
def
forward
(
self
,
out1
,
out2
):
if
self
.
act
is
not
None
:
if
self
.
act
is
not
None
:
out1
=
self
.
act
(
out1
)
out1
=
self
.
act
(
out1
)
out2
=
self
.
act
(
out2
)
out2
=
self
.
act
(
out2
)
if
len
(
out1
.
shape
)
<
2
:
if
self
.
use_log
:
# for recognition distillation, log is needed for feature map
log_out1
=
paddle
.
log
(
out1
)
log_out1
=
paddle
.
log
(
out1
)
log_out2
=
paddle
.
log
(
out2
)
log_out2
=
paddle
.
log
(
out2
)
loss
=
(
F
.
kl_div
(
loss
=
(
F
.
kl_div
(
log_out1
,
out2
,
reduction
=
'batchmean'
)
+
F
.
kl_div
(
log_out1
,
out2
,
reduction
=
'batchmean'
)
+
F
.
kl_div
(
log_out2
,
out1
,
reduction
=
'batchmean'
))
/
2.0
log_out2
,
out1
,
reduction
=
'batchmean'
))
/
2.0
else
:
else
:
# for detection distillation log is not needed
loss
=
self
.
jskl_loss
(
out1
,
out2
)
loss
=
self
.
jskl_loss
(
out1
,
out2
)
return
loss
return
loss
...
...
ppocr/losses/combined_loss.py
浏览文件 @
7dc56191
...
@@ -49,11 +49,15 @@ class CombinedLoss(nn.Layer):
...
@@ -49,11 +49,15 @@ class CombinedLoss(nn.Layer):
loss
=
loss_func
(
input
,
batch
,
**
kargs
)
loss
=
loss_func
(
input
,
batch
,
**
kargs
)
if
isinstance
(
loss
,
paddle
.
Tensor
):
if
isinstance
(
loss
,
paddle
.
Tensor
):
loss
=
{
"loss_{}_{}"
.
format
(
str
(
loss
),
idx
):
loss
}
loss
=
{
"loss_{}_{}"
.
format
(
str
(
loss
),
idx
):
loss
}
weight
=
self
.
loss_weight
[
idx
]
weight
=
self
.
loss_weight
[
idx
]
for
key
in
loss
.
keys
():
if
key
==
"loss"
:
loss
=
{
key
:
loss
[
key
]
*
weight
for
key
in
loss
}
loss_all
+=
loss
[
key
]
*
weight
else
:
if
"loss"
in
loss
:
loss_dict
[
"{}_{}"
.
format
(
key
,
idx
)]
=
loss
[
key
]
loss_all
+=
loss
[
"loss"
]
else
:
loss_all
+=
paddle
.
add_n
(
list
(
loss
.
values
()))
loss_dict
.
update
(
loss
)
loss_dict
[
"loss"
]
=
loss_all
loss_dict
[
"loss"
]
=
loss_all
return
loss_dict
return
loss_dict
ppocr/losses/distillation_loss.py
浏览文件 @
7dc56191
...
@@ -44,20 +44,22 @@ class DistillationDMLLoss(DMLLoss):
...
@@ -44,20 +44,22 @@ class DistillationDMLLoss(DMLLoss):
def
__init__
(
self
,
def
__init__
(
self
,
model_name_pairs
=
[],
model_name_pairs
=
[],
act
=
None
,
act
=
None
,
use_log
=
False
,
key
=
None
,
key
=
None
,
maps_name
=
None
,
maps_name
=
None
,
name
=
"dml"
):
name
=
"dml"
):
super
().
__init__
(
act
=
act
)
super
().
__init__
(
act
=
act
,
use_log
=
use_log
)
assert
isinstance
(
model_name_pairs
,
list
)
assert
isinstance
(
model_name_pairs
,
list
)
self
.
key
=
key
self
.
key
=
key
self
.
model_name_pairs
=
self
.
_check_model_name_pairs
(
model_name_pairs
)
self
.
model_name_pairs
=
self
.
_check_model_name_pairs
(
model_name_pairs
)
self
.
name
=
name
self
.
name
=
name
self
.
maps_name
=
self
.
_check_maps_name
(
maps_name
)
self
.
maps_name
=
self
.
_check_maps_name
(
maps_name
)
def
_check_model_name_pairs
(
self
,
model_name_pairs
):
def
_check_model_name_pairs
(
self
,
model_name_pairs
):
if
not
isinstance
(
model_name_pairs
,
list
):
if
not
isinstance
(
model_name_pairs
,
list
):
return
[]
return
[]
elif
isinstance
(
model_name_pairs
[
0
],
list
)
and
isinstance
(
model_name_pairs
[
0
][
0
],
str
):
elif
isinstance
(
model_name_pairs
[
0
],
list
)
and
isinstance
(
model_name_pairs
[
0
][
0
],
str
):
return
model_name_pairs
return
model_name_pairs
else
:
else
:
return
[
model_name_pairs
]
return
[
model_name_pairs
]
...
@@ -112,9 +114,9 @@ class DistillationDMLLoss(DMLLoss):
...
@@ -112,9 +114,9 @@ class DistillationDMLLoss(DMLLoss):
loss_dict
[
"{}_{}_{}_{}_{}"
.
format
(
key
,
pair
[
loss_dict
[
"{}_{}_{}_{}_{}"
.
format
(
key
,
pair
[
0
],
pair
[
1
],
map_name
,
idx
)]
=
loss
[
key
]
0
],
pair
[
1
],
map_name
,
idx
)]
=
loss
[
key
]
else
:
else
:
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
self
.
maps_name
[
_c
],
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
self
.
maps_name
[
idx
)]
=
loss
_c
],
idx
)]
=
loss
loss_dict
=
_sum_loss
(
loss_dict
)
loss_dict
=
_sum_loss
(
loss_dict
)
return
loss_dict
return
loss_dict
...
...
ppocr/utils/save_load.py
浏览文件 @
7dc56191
...
@@ -108,14 +108,15 @@ def load_dygraph_params(config, model, logger, optimizer):
...
@@ -108,14 +108,15 @@ def load_dygraph_params(config, model, logger, optimizer):
for
k1
,
k2
in
zip
(
state_dict
.
keys
(),
params
.
keys
()):
for
k1
,
k2
in
zip
(
state_dict
.
keys
(),
params
.
keys
()):
if
list
(
state_dict
[
k1
].
shape
)
==
list
(
params
[
k2
].
shape
):
if
list
(
state_dict
[
k1
].
shape
)
==
list
(
params
[
k2
].
shape
):
new_state_dict
[
k1
]
=
params
[
k2
]
new_state_dict
[
k1
]
=
params
[
k2
]
else
:
else
:
logger
.
info
(
logger
.
info
(
f
"The shape of model params
{
k1
}
{
state_dict
[
k1
].
shape
}
not matched with loaded params
{
k2
}
{
params
[
k2
].
shape
}
!"
f
"The shape of model params
{
k1
}
{
state_dict
[
k1
].
shape
}
not matched with loaded params
{
k2
}
{
params
[
k2
].
shape
}
!"
)
)
model
.
set_state_dict
(
new_state_dict
)
model
.
set_state_dict
(
new_state_dict
)
logger
.
info
(
f
"loaded pretrained_model successful from
{
pm
}
"
)
logger
.
info
(
f
"loaded pretrained_model successful from
{
pm
}
"
)
return
{}
return
{}
def
load_pretrained_params
(
model
,
path
):
def
load_pretrained_params
(
model
,
path
):
if
path
is
None
:
if
path
is
None
:
return
False
return
False
...
@@ -138,6 +139,7 @@ def load_pretrained_params(model, path):
...
@@ -138,6 +139,7 @@ def load_pretrained_params(model, path):
print
(
f
"load pretrain successful from
{
path
}
"
)
print
(
f
"load pretrain successful from
{
path
}
"
)
return
model
return
model
def
save_model
(
model
,
def
save_model
(
model
,
optimizer
,
optimizer
,
model_path
,
model_path
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录