Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
43a03a0c
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 2 年 前同步成功
通知
118
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
43a03a0c
编写于
4月 20, 2022
作者:
W
Walter
提交者:
GitHub
4月 20, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1853 from HydrogenSulfate/multi_optim
support for multi optimizer case
上级
42cb6435
15242df1
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
149 addition
and
60 deletion
+149
-60
ppcls/engine/engine.py
ppcls/engine/engine.py
+25
-11
ppcls/engine/train/train.py
ppcls/engine/train/train.py
+11
-5
ppcls/engine/train/utils.py
ppcls/engine/train/utils.py
+10
-6
ppcls/loss/__init__.py
ppcls/loss/__init__.py
+1
-0
ppcls/optimizer/__init__.py
ppcls/optimizer/__init__.py
+76
-26
ppcls/utils/save_load.py
ppcls/utils/save_load.py
+26
-12
未找到文件。
ppcls/engine/engine.py
浏览文件 @
43a03a0c
...
@@ -214,16 +214,19 @@ class Engine(object):
...
@@ -214,16 +214,19 @@ class Engine(object):
if
self
.
config
[
"Global"
][
"pretrained_model"
]
is
not
None
:
if
self
.
config
[
"Global"
][
"pretrained_model"
]
is
not
None
:
if
self
.
config
[
"Global"
][
"pretrained_model"
].
startswith
(
"http"
):
if
self
.
config
[
"Global"
][
"pretrained_model"
].
startswith
(
"http"
):
load_dygraph_pretrain_from_url
(
load_dygraph_pretrain_from_url
(
self
.
model
,
self
.
config
[
"Global"
][
"pretrained_model"
])
[
self
.
model
,
getattr
(
self
,
'train_loss_func'
,
None
)],
self
.
config
[
"Global"
][
"pretrained_model"
])
else
:
else
:
load_dygraph_pretrain
(
load_dygraph_pretrain
(
self
.
model
,
self
.
config
[
"Global"
][
"pretrained_model"
])
[
self
.
model
,
getattr
(
self
,
'train_loss_func'
,
None
)],
self
.
config
[
"Global"
][
"pretrained_model"
])
# build optimizer
# build optimizer
if
self
.
mode
==
'train'
:
if
self
.
mode
==
'train'
:
self
.
optimizer
,
self
.
lr_sch
=
build_optimizer
(
self
.
optimizer
,
self
.
lr_sch
=
build_optimizer
(
self
.
config
[
"Optimizer"
],
self
.
config
[
"Global"
][
"epochs"
],
self
.
config
,
self
.
config
[
"Global"
][
"epochs"
],
len
(
self
.
train_dataloader
),
[
self
.
model
])
len
(
self
.
train_dataloader
),
[
self
.
model
,
self
.
train_loss_func
])
# for amp training
# for amp training
if
self
.
amp
:
if
self
.
amp
:
...
@@ -241,6 +244,11 @@ class Engine(object):
...
@@ -241,6 +244,11 @@ class Engine(object):
optimizers
=
self
.
optimizer
,
optimizers
=
self
.
optimizer
,
level
=
amp_level
,
level
=
amp_level
,
save_dtype
=
'float32'
)
save_dtype
=
'float32'
)
if
len
(
self
.
train_loss_func
.
parameters
())
>
0
:
self
.
train_loss_func
=
paddle
.
amp
.
decorate
(
models
=
self
.
train_loss_func
,
level
=
amp_level
,
save_dtype
=
'float32'
)
# for distributed
# for distributed
world_size
=
dist
.
get_world_size
()
world_size
=
dist
.
get_world_size
()
...
@@ -251,7 +259,10 @@ class Engine(object):
...
@@ -251,7 +259,10 @@ class Engine(object):
if
self
.
config
[
"Global"
][
"distributed"
]:
if
self
.
config
[
"Global"
][
"distributed"
]:
dist
.
init_parallel_env
()
dist
.
init_parallel_env
()
self
.
model
=
paddle
.
DataParallel
(
self
.
model
)
self
.
model
=
paddle
.
DataParallel
(
self
.
model
)
if
self
.
mode
==
'train'
and
len
(
self
.
train_loss_func
.
parameters
(
))
>
0
:
self
.
train_loss_func
=
paddle
.
DataParallel
(
self
.
train_loss_func
)
# build postprocess for infer
# build postprocess for infer
if
self
.
mode
==
'infer'
:
if
self
.
mode
==
'infer'
:
self
.
preprocess_func
=
create_operators
(
self
.
config
[
"Infer"
][
self
.
preprocess_func
=
create_operators
(
self
.
config
[
"Infer"
][
...
@@ -279,9 +290,9 @@ class Engine(object):
...
@@ -279,9 +290,9 @@ class Engine(object):
# global iter counter
# global iter counter
self
.
global_step
=
0
self
.
global_step
=
0
if
self
.
config
[
"Global"
][
"checkpoints"
]
is
not
None
:
if
self
.
config
.
Global
.
checkpoints
is
not
None
:
metric_info
=
init_model
(
self
.
config
[
"Global"
]
,
self
.
model
,
metric_info
=
init_model
(
self
.
config
.
Global
,
self
.
model
,
self
.
optimizer
)
self
.
optimizer
,
self
.
train_loss_func
)
if
metric_info
is
not
None
:
if
metric_info
is
not
None
:
best_metric
.
update
(
metric_info
)
best_metric
.
update
(
metric_info
)
...
@@ -317,7 +328,8 @@ class Engine(object):
...
@@ -317,7 +328,8 @@ class Engine(object):
best_metric
,
best_metric
,
self
.
output_dir
,
self
.
output_dir
,
model_name
=
self
.
config
[
"Arch"
][
"name"
],
model_name
=
self
.
config
[
"Arch"
][
"name"
],
prefix
=
"best_model"
)
prefix
=
"best_model"
,
loss
=
self
.
train_loss_func
)
logger
.
info
(
"[Eval][Epoch {}][best metric: {}]"
.
format
(
logger
.
info
(
"[Eval][Epoch {}][best metric: {}]"
.
format
(
epoch_id
,
best_metric
[
"metric"
]))
epoch_id
,
best_metric
[
"metric"
]))
logger
.
scaler
(
logger
.
scaler
(
...
@@ -336,7 +348,8 @@ class Engine(object):
...
@@ -336,7 +348,8 @@ class Engine(object):
"epoch"
:
epoch_id
},
"epoch"
:
epoch_id
},
self
.
output_dir
,
self
.
output_dir
,
model_name
=
self
.
config
[
"Arch"
][
"name"
],
model_name
=
self
.
config
[
"Arch"
][
"name"
],
prefix
=
"epoch_{}"
.
format
(
epoch_id
))
prefix
=
"epoch_{}"
.
format
(
epoch_id
),
loss
=
self
.
train_loss_func
)
# save the latest model
# save the latest model
save_load
.
save_model
(
save_load
.
save_model
(
self
.
model
,
self
.
model
,
...
@@ -344,7 +357,8 @@ class Engine(object):
...
@@ -344,7 +357,8 @@ class Engine(object):
"epoch"
:
epoch_id
},
"epoch"
:
epoch_id
},
self
.
output_dir
,
self
.
output_dir
,
model_name
=
self
.
config
[
"Arch"
][
"name"
],
model_name
=
self
.
config
[
"Arch"
][
"name"
],
prefix
=
"latest"
)
prefix
=
"latest"
,
loss
=
self
.
train_loss_func
)
if
self
.
vdl_writer
is
not
None
:
if
self
.
vdl_writer
is
not
None
:
self
.
vdl_writer
.
close
()
self
.
vdl_writer
.
close
()
...
...
ppcls/engine/train/train.py
浏览文件 @
43a03a0c
...
@@ -53,16 +53,22 @@ def train_epoch(engine, epoch_id, print_batch_step):
...
@@ -53,16 +53,22 @@ def train_epoch(engine, epoch_id, print_batch_step):
out
=
forward
(
engine
,
batch
)
out
=
forward
(
engine
,
batch
)
loss_dict
=
engine
.
train_loss_func
(
out
,
batch
[
1
])
loss_dict
=
engine
.
train_loss_func
(
out
,
batch
[
1
])
# step opt
and lr
# step opt
if
engine
.
amp
:
if
engine
.
amp
:
scaled
=
engine
.
scaler
.
scale
(
loss_dict
[
"loss"
])
scaled
=
engine
.
scaler
.
scale
(
loss_dict
[
"loss"
])
scaled
.
backward
()
scaled
.
backward
()
engine
.
scaler
.
minimize
(
engine
.
optimizer
,
scaled
)
for
i
in
range
(
len
(
engine
.
optimizer
)):
engine
.
scaler
.
minimize
(
engine
.
optimizer
[
i
],
scaled
)
else
:
else
:
loss_dict
[
"loss"
].
backward
()
loss_dict
[
"loss"
].
backward
()
engine
.
optimizer
.
step
()
for
i
in
range
(
len
(
engine
.
optimizer
)):
engine
.
optimizer
.
clear_grad
()
engine
.
optimizer
[
i
].
step
()
engine
.
lr_sch
.
step
()
# clear grad
for
i
in
range
(
len
(
engine
.
optimizer
)):
engine
.
optimizer
[
i
].
clear_grad
()
# step lr
for
i
in
range
(
len
(
engine
.
lr_sch
)):
engine
.
lr_sch
[
i
].
step
()
# below code just for logging
# below code just for logging
# update metric_for_logger
# update metric_for_logger
...
...
ppcls/engine/train/utils.py
浏览文件 @
43a03a0c
...
@@ -38,7 +38,10 @@ def update_loss(trainer, loss_dict, batch_size):
...
@@ -38,7 +38,10 @@ def update_loss(trainer, loss_dict, batch_size):
def
log_info
(
trainer
,
batch_size
,
epoch_id
,
iter_id
):
def
log_info
(
trainer
,
batch_size
,
epoch_id
,
iter_id
):
lr_msg
=
"lr: {:.5f}"
.
format
(
trainer
.
lr_sch
.
get_lr
())
lr_msg
=
", "
.
join
([
"lr_{}: {:.8f}"
.
format
(
i
+
1
,
lr
.
get_lr
())
for
i
,
lr
in
enumerate
(
trainer
.
lr_sch
)
])
metric_msg
=
", "
.
join
([
metric_msg
=
", "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
trainer
.
output_info
[
key
].
avg
)
"{}: {:.5f}"
.
format
(
key
,
trainer
.
output_info
[
key
].
avg
)
for
key
in
trainer
.
output_info
for
key
in
trainer
.
output_info
...
@@ -59,11 +62,12 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
...
@@ -59,11 +62,12 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
len
(
trainer
.
train_dataloader
),
lr_msg
,
metric_msg
,
time_msg
,
ips_msg
,
len
(
trainer
.
train_dataloader
),
lr_msg
,
metric_msg
,
time_msg
,
ips_msg
,
eta_msg
))
eta_msg
))
logger
.
scaler
(
for
i
,
lr
in
enumerate
(
trainer
.
lr_sch
):
name
=
"lr"
,
logger
.
scaler
(
value
=
trainer
.
lr_sch
.
get_lr
(),
name
=
"lr_{}"
.
format
(
i
+
1
),
step
=
trainer
.
global_step
,
value
=
lr
.
get_lr
(),
writer
=
trainer
.
vdl_writer
)
step
=
trainer
.
global_step
,
writer
=
trainer
.
vdl_writer
)
for
key
in
trainer
.
output_info
:
for
key
in
trainer
.
output_info
:
logger
.
scaler
(
logger
.
scaler
(
name
=
"train_{}"
.
format
(
key
),
name
=
"train_{}"
.
format
(
key
),
...
...
ppcls/loss/__init__.py
浏览文件 @
43a03a0c
...
@@ -47,6 +47,7 @@ class CombinedLoss(nn.Layer):
...
@@ -47,6 +47,7 @@ class CombinedLoss(nn.Layer):
param
.
keys
())
param
.
keys
())
self
.
loss_weight
.
append
(
param
.
pop
(
"weight"
))
self
.
loss_weight
.
append
(
param
.
pop
(
"weight"
))
self
.
loss_func
.
append
(
eval
(
name
)(
**
param
))
self
.
loss_func
.
append
(
eval
(
name
)(
**
param
))
self
.
loss_func
=
nn
.
LayerList
(
self
.
loss_func
)
def
__call__
(
self
,
input
,
batch
):
def
__call__
(
self
,
input
,
batch
):
loss_dict
=
{}
loss_dict
=
{}
...
...
ppcls/optimizer/__init__.py
浏览文件 @
43a03a0c
...
@@ -18,6 +18,7 @@ from __future__ import print_function
...
@@ -18,6 +18,7 @@ from __future__ import print_function
import
copy
import
copy
import
paddle
import
paddle
from
typing
import
Dict
,
List
from
ppcls.utils
import
logger
from
ppcls.utils
import
logger
...
@@ -44,29 +45,78 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
...
@@ -44,29 +45,78 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
# model_list is None in static graph
# model_list is None in static graph
def
build_optimizer
(
config
,
epochs
,
step_each_epoch
,
model_list
=
None
):
def
build_optimizer
(
config
,
epochs
,
step_each_epoch
,
model_list
=
None
):
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
# step1 build lr
optim_config
=
config
[
"Optimizer"
]
lr
=
build_lr_scheduler
(
config
.
pop
(
'lr'
),
epochs
,
step_each_epoch
)
if
isinstance
(
optim_config
,
dict
):
logger
.
debug
(
"build lr ({}) success.."
.
format
(
lr
))
# convert {'name': xxx, **optim_cfg} to [{name: {scope: xxx, **optim_cfg}}]
# step2 build regularization
optim_name
=
optim_config
.
pop
(
"name"
)
if
'regularizer'
in
config
and
config
[
'regularizer'
]
is
not
None
:
optim_config
:
List
[
Dict
[
str
,
Dict
]]
=
[{
if
'weight_decay'
in
config
:
optim_name
:
{
logger
.
warning
(
'scope'
:
"all"
,
"ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config.
\"
weight_decay
\"
has been ignored."
**
)
optim_config
reg_config
=
config
.
pop
(
'regularizer'
)
}
reg_name
=
reg_config
.
pop
(
'name'
)
+
'Decay'
}]
reg
=
getattr
(
paddle
.
regularizer
,
reg_name
)(
**
reg_config
)
optim_list
=
[]
config
[
"weight_decay"
]
=
reg
lr_list
=
[]
logger
.
debug
(
"build regularizer ({}) success.."
.
format
(
reg
))
"""NOTE:
# step3 build optimizer
Currently only support optim objets below.
optim_name
=
config
.
pop
(
'name'
)
1. single optimizer config.
if
'clip_norm'
in
config
:
2. next level uner Arch, such as Arch.backbone, Arch.neck, Arch.head.
clip_norm
=
config
.
pop
(
'clip_norm'
)
3. loss which has parameters, such as CenterLoss.
grad_clip
=
paddle
.
nn
.
ClipGradByNorm
(
clip_norm
=
clip_norm
)
"""
else
:
for
optim_item
in
optim_config
:
grad_clip
=
None
# optim_cfg = {optim_name: {scope: xxx, **optim_cfg}}
optim
=
getattr
(
optimizer
,
optim_name
)(
learning_rate
=
lr
,
# step1 build lr
grad_clip
=
grad_clip
,
optim_name
=
list
(
optim_item
.
keys
())[
0
]
# get optim_name
**
config
)(
model_list
=
model_list
)
optim_scope
=
optim_item
[
optim_name
].
pop
(
'scope'
)
# get optim_scope
logger
.
debug
(
"build optimizer ({}) success.."
.
format
(
optim
))
optim_cfg
=
optim_item
[
optim_name
]
# get optim_cfg
return
optim
,
lr
lr
=
build_lr_scheduler
(
optim_cfg
.
pop
(
'lr'
),
epochs
,
step_each_epoch
)
logger
.
debug
(
"build lr ({}) for scope ({}) success.."
.
format
(
lr
,
optim_scope
))
# step2 build regularization
if
'regularizer'
in
optim_cfg
and
optim_cfg
[
'regularizer'
]
is
not
None
:
if
'weight_decay'
in
optim_cfg
:
logger
.
warning
(
"ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config.
\"
weight_decay
\"
has been ignored."
)
reg_config
=
optim_cfg
.
pop
(
'regularizer'
)
reg_name
=
reg_config
.
pop
(
'name'
)
+
'Decay'
reg
=
getattr
(
paddle
.
regularizer
,
reg_name
)(
**
reg_config
)
optim_cfg
[
"weight_decay"
]
=
reg
logger
.
debug
(
"build regularizer ({}) for scope ({}) success.."
.
format
(
reg
,
optim_scope
))
# step3 build optimizer
if
'clip_norm'
in
optim_cfg
:
clip_norm
=
optim_cfg
.
pop
(
'clip_norm'
)
grad_clip
=
paddle
.
nn
.
ClipGradByNorm
(
clip_norm
=
clip_norm
)
else
:
grad_clip
=
None
optim_model
=
[]
for
i
in
range
(
len
(
model_list
)):
if
len
(
model_list
[
i
].
parameters
())
==
0
:
continue
if
optim_scope
==
"all"
:
# optimizer for all
optim_model
.
append
(
model_list
[
i
])
else
:
if
optim_scope
.
endswith
(
"Loss"
):
# optimizer for loss
for
m
in
model_list
[
i
].
sublayers
(
True
):
if
m
.
__class_name
==
optim_scope
:
optim_model
.
append
(
m
)
else
:
# opmizer for module in model, such as backbone, neck, head...
if
hasattr
(
model_list
[
i
],
optim_scope
):
optim_model
.
append
(
getattr
(
model_list
[
i
],
optim_scope
))
assert
len
(
optim_model
)
==
1
,
\
"Invalid optim model for optim scope({}), number of optim_model={}"
.
format
(
optim_scope
,
len
(
optim_model
))
optim
=
getattr
(
optimizer
,
optim_name
)(
learning_rate
=
lr
,
grad_clip
=
grad_clip
,
**
optim_cfg
)(
model_list
=
optim_model
)
logger
.
debug
(
"build optimizer ({}) for scope ({}) success.."
.
format
(
optim
,
optim_scope
))
optim_list
.
append
(
optim
)
lr_list
.
append
(
lr
)
return
optim_list
,
lr_list
ppcls/utils/save_load.py
浏览文件 @
43a03a0c
...
@@ -18,9 +18,6 @@ from __future__ import print_function
...
@@ -18,9 +18,6 @@ from __future__ import print_function
import
errno
import
errno
import
os
import
os
import
re
import
shutil
import
tempfile
import
paddle
import
paddle
from
ppcls.utils
import
logger
from
ppcls.utils
import
logger
...
@@ -47,10 +44,15 @@ def _mkdir_if_not_exist(path):
...
@@ -47,10 +44,15 @@ def _mkdir_if_not_exist(path):
def
load_dygraph_pretrain
(
model
,
path
=
None
):
def
load_dygraph_pretrain
(
model
,
path
=
None
):
if
not
(
os
.
path
.
isdir
(
path
)
or
os
.
path
.
exists
(
path
+
'.pdparams'
)):
if
not
(
os
.
path
.
isdir
(
path
)
or
os
.
path
.
exists
(
path
+
'.pdparams'
)):
raise
ValueError
(
"Model pretrain path {} does not "
raise
ValueError
(
"Model pretrain path {}
.pdparams
does not "
"exists."
.
format
(
path
))
"exists."
.
format
(
path
))
param_state_dict
=
paddle
.
load
(
path
+
".pdparams"
)
param_state_dict
=
paddle
.
load
(
path
+
".pdparams"
)
model
.
set_dict
(
param_state_dict
)
if
isinstance
(
model
,
list
):
for
m
in
model
:
if
hasattr
(
m
,
'set_dict'
):
m
.
set_dict
(
param_state_dict
)
else
:
model
.
set_dict
(
param_state_dict
)
return
return
...
@@ -85,7 +87,7 @@ def load_distillation_model(model, pretrained_model):
...
@@ -85,7 +87,7 @@ def load_distillation_model(model, pretrained_model):
pretrained_model
))
pretrained_model
))
def
init_model
(
config
,
net
,
optimizer
=
None
):
def
init_model
(
config
,
net
,
optimizer
=
None
,
loss
:
paddle
.
nn
.
Layer
=
None
):
"""
"""
load model from checkpoint or pretrained_model
load model from checkpoint or pretrained_model
"""
"""
...
@@ -95,11 +97,15 @@ def init_model(config, net, optimizer=None):
...
@@ -95,11 +97,15 @@ def init_model(config, net, optimizer=None):
"Given dir {}.pdparams not exist."
.
format
(
checkpoints
)
"Given dir {}.pdparams not exist."
.
format
(
checkpoints
)
assert
os
.
path
.
exists
(
checkpoints
+
".pdopt"
),
\
assert
os
.
path
.
exists
(
checkpoints
+
".pdopt"
),
\
"Given dir {}.pdopt not exist."
.
format
(
checkpoints
)
"Given dir {}.pdopt not exist."
.
format
(
checkpoints
)
para_dict
=
paddle
.
load
(
checkpoints
+
".pdparams"
)
# load state dict
opti_dict
=
paddle
.
load
(
checkpoints
+
".pdopt"
)
opti_dict
=
paddle
.
load
(
checkpoints
+
".pdopt"
)
para_dict
=
paddle
.
load
(
checkpoints
+
".pdparams"
)
metric_dict
=
paddle
.
load
(
checkpoints
+
".pdstates"
)
metric_dict
=
paddle
.
load
(
checkpoints
+
".pdstates"
)
net
.
set_dict
(
para_dict
)
# set state dict
optimizer
.
set_state_dict
(
opti_dict
)
net
.
set_state_dict
(
para_dict
)
loss
.
set_state_dict
(
para_dict
)
for
i
in
range
(
len
(
optimizer
)):
optimizer
[
i
].
set_state_dict
(
opti_dict
)
logger
.
info
(
"Finish load checkpoints from {}"
.
format
(
checkpoints
))
logger
.
info
(
"Finish load checkpoints from {}"
.
format
(
checkpoints
))
return
metric_dict
return
metric_dict
...
@@ -120,7 +126,8 @@ def save_model(net,
...
@@ -120,7 +126,8 @@ def save_model(net,
metric_info
,
metric_info
,
model_path
,
model_path
,
model_name
=
""
,
model_name
=
""
,
prefix
=
'ppcls'
):
prefix
=
'ppcls'
,
loss
:
paddle
.
nn
.
Layer
=
None
):
"""
"""
save model to the target path
save model to the target path
"""
"""
...
@@ -130,7 +137,14 @@ def save_model(net,
...
@@ -130,7 +137,14 @@ def save_model(net,
_mkdir_if_not_exist
(
model_path
)
_mkdir_if_not_exist
(
model_path
)
model_path
=
os
.
path
.
join
(
model_path
,
prefix
)
model_path
=
os
.
path
.
join
(
model_path
,
prefix
)
paddle
.
save
(
net
.
state_dict
(),
model_path
+
".pdparams"
)
params_state_dict
=
net
.
state_dict
()
paddle
.
save
(
optimizer
.
state_dict
(),
model_path
+
".pdopt"
)
loss_state_dict
=
loss
.
state_dict
()
keys_inter
=
set
(
params_state_dict
.
keys
())
&
set
(
loss_state_dict
.
keys
())
assert
len
(
keys_inter
)
==
0
,
\
f
"keys in model and loss state_dict must be unique, but got intersection
{
keys_inter
}
"
params_state_dict
.
update
(
loss_state_dict
)
paddle
.
save
(
params_state_dict
,
model_path
+
".pdparams"
)
paddle
.
save
([
opt
.
state_dict
()
for
opt
in
optimizer
],
model_path
+
".pdopt"
)
paddle
.
save
(
metric_info
,
model_path
+
".pdstates"
)
paddle
.
save
(
metric_info
,
model_path
+
".pdstates"
)
logger
.
info
(
"Already save model in {}"
.
format
(
model_path
))
logger
.
info
(
"Already save model in {}"
.
format
(
model_path
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录