PaddlePaddle / PaddleClas

Commit 915dde17
Authored Mar 14, 2023 by Tingquan Gao
Revert "refactor: rm train and eval from engine"
This reverts commit
5a6fe171
.
上级
aa52682c
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
455 addition
and
488 deletion
+455
-488
ppcls/engine/engine.py                        +192    -8
ppcls/engine/evaluation/__init__.py             +4    -7
ppcls/engine/evaluation/classification.py     +156  -177
ppcls/engine/train/__init__.py                  +6    -9
ppcls/engine/train/classification.py            +0  -279
ppcls/engine/train/regular_train_epoch.py      +89    -0
ppcls/utils/save_load.py                        +8    -8
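With train and eval back on Engine, callers drive training the same way the standard entry points do. A minimal usage sketch (not part of this commit; the config path and the get_config call follow the usual tools/train.py flow and are assumptions):

from ppcls.utils import config as cfg_util
from ppcls.engine.engine import Engine

# load a YAML config (example path, replace with your own)
cfg = cfg_util.get_config(
    "ppcls/configs/ImageNet/ResNet/ResNet50.yaml", overrides=[], show=False)

engine = Engine(cfg, mode="train")
engine.train()        # per-epoch train_epoch_func, eval, checkpoint saving
metric = engine.eval()  # standalone evaluation, returns the first metric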
ppcls/engine/engine.py  (view file @ 915dde17)

@@ -22,17 +22,25 @@ from paddle import nn
 import numpy as np
 import random
+from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
 from ppcls.utils.logger import init_logger
 from ppcls.utils.config import print_config
+from ppcls.data import build_dataloader
 from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
+from ppcls.loss import build_loss
+from ppcls.metric import build_metrics
+from ppcls.optimizer import build_optimizer
+from ppcls.utils.ema import ExponentialMovingAverage
 from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
+from ppcls.utils.save_load import init_model, ModelSaver
 from ppcls.data.utils.get_image_list import get_image_list
 from ppcls.data.postprocess import build_postprocess
 from ppcls.data import create_operators
-from .train import build_train_func
+from .train import build_train_epoch_func
 from .evaluation import build_eval_func
+from ppcls.engine.train.utils import type_name
 from ppcls.engine import evaluation
 from ppcls.arch.gears.identity_head import IdentityHead

@@ -42,35 +50,186 @@ class Engine(object):
         assert mode in ["train", "eval", "infer", "export"]
         self.mode = mode
         self.config = config
+        self.start_eval_epoch = self.config["Global"].get("start_eval_epoch", 0) - 1
+        self.epochs = self.config["Global"].get("epochs", 1)

         # set seed
         self._init_seed()

         # init logger
-        log_file = os.path.join(self.config['Global']['output_dir'],
-                                self.config["Arch"]["name"], f"{mode}.log")
+        self.output_dir = self.config['Global']['output_dir']
+        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
+                                f"{mode}.log")
         init_logger(log_file=log_file)

+        # for visualdl
+        self.vdl_writer = self._init_vdl()
+
+        # init train_func and eval_func
+        self.train_epoch_func = build_train_epoch_func(self.config)
+        self.eval_func = build_eval_func(self.config)
+
         # set device
         self._init_device()

+        # gradient accumulation
+        self.update_freq = self.config["Global"].get("update_freq", 1)
+
+        # build dataloader
+        self.use_dali = self.config["Global"].get("use_dali", False)
+        self.dataloader_dict = build_dataloader(self.config, mode)
+
+        # build loss
+        self.train_loss_func, self.unlabel_train_loss_func, self.eval_loss_func = build_loss(self.config, self.mode)
+
+        # build metric
+        self.train_metric_func, self.eval_metric_func = build_metrics(self)
+
         # build model
         self.model = build_model(self.config, self.mode)

         # load_pretrain
         self._init_pretrained()

-        # init train_func and eval_func
-        self.eval = build_eval_func(
-            self.config, mode=self.mode, model=self.model)
-        self.train = build_train_func(
-            self.config, mode=self.mode, model=self.model, eval_func=self.eval)
+        # build optimizer
+        self.optimizer, self.lr_sch = build_optimizer(self)
+
+        # AMP training and evaluating
+        self._init_amp()

         # for distributed
         self._init_dist()

+        # build model saver
+        self.model_saver = ModelSaver(
+            self,
+            net_name="model",
+            loss_name="train_loss_func",
+            opt_name="optimizer",
+            model_ema_name="model_ema")
+
         print_config(config)

+    def train(self):
+        assert self.mode == "train"
+        print_batch_step = self.config['Global']['print_batch_step']
+        save_interval = self.config["Global"]["save_interval"]
+        best_metric = {
+            "metric": -1.0,
+            "epoch": 0,
+        }
+        # key:
+        # val: metrics list word
+        self.output_info = dict()
+        self.time_info = {
+            "batch_cost": AverageMeter(
+                "batch_cost", '.5f', postfix=" s,"),
+            "reader_cost": AverageMeter(
+                "reader_cost", ".5f", postfix=" s,"),
+        }
+        # build EMA model
+        self.model_ema = self._build_ema_model()
+        # TODO: mv best_metric_ema to best_metric dict
+        best_metric_ema = 0
+
+        self._init_checkpoints(best_metric)
+
+        # global iter counter
+        self.global_step = 0
+
+        for epoch_id in range(best_metric["epoch"] + 1, self.epochs + 1):
+            # for one epoch train
+            self.train_epoch_func(self, epoch_id, print_batch_step)
+
+            metric_msg = ", ".join(
+                [self.output_info[key].avg_info for key in self.output_info])
+            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
+                epoch_id, self.epochs, metric_msg))
+            self.output_info.clear()
+
+            acc = 0.0
+            if self.config["Global"]["eval_during_train"] and epoch_id % self.config["Global"]["eval_interval"] == 0 and epoch_id > self.start_eval_epoch:
+                acc = self.eval(epoch_id)
+
+                # step lr (by epoch) according to given metric, such as acc
+                for i in range(len(self.lr_sch)):
+                    if getattr(self.lr_sch[i], "by_epoch", False) and \
+                            type_name(self.lr_sch[i]) == "ReduceOnPlateau":
+                        self.lr_sch[i].step(acc)
+
+                if acc > best_metric["metric"]:
+                    best_metric["metric"] = acc
+                    best_metric["epoch"] = epoch_id
+                    self.model_saver.save(
+                        best_metric,
+                        prefix="best_model",
+                        save_student_model=True)
+                logger.info("[Eval][Epoch {}][best metric: {}]".format(
+                    epoch_id, best_metric["metric"]))
+                logger.scaler(
+                    name="eval_acc",
+                    value=acc,
+                    step=epoch_id,
+                    writer=self.vdl_writer)
+
+                self.model.train()
+
+                if self.model_ema:
+                    ori_model, self.model = self.model, self.model_ema.module
+                    acc_ema = self.eval(epoch_id)
+                    self.model = ori_model
+                    self.model_ema.module.eval()
+
+                    if acc_ema > best_metric_ema:
+                        best_metric_ema = acc_ema
+                        self.model_saver.save(
+                            {
+                                "metric": acc_ema,
+                                "epoch": epoch_id
+                            },
+                            prefix="best_model_ema")
+                    logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
+                        epoch_id, best_metric_ema))
+                    logger.scaler(
+                        name="eval_acc_ema",
+                        value=acc_ema,
+                        step=epoch_id,
+                        writer=self.vdl_writer)
+
+            # save model
+            if save_interval > 0 and epoch_id % save_interval == 0:
+                self.model_saver.save(
+                    {
+                        "metric": acc,
+                        "epoch": epoch_id
+                    },
+                    prefix=f"epoch_{epoch_id}")
+            # save the latest model
+            self.model_saver.save(
+                {
+                    "metric": acc,
+                    "epoch": epoch_id
+                }, prefix="latest")
+
+        if self.vdl_writer is not None:
+            self.vdl_writer.close()
+
+    @paddle.no_grad()
+    def eval(self, epoch_id=0):
+        assert self.mode in ["train", "eval"]
+        self.model.eval()
+        eval_result = self.eval_func(self, epoch_id)
+        self.model.train()
+        return eval_result
+
     @paddle.no_grad()
     def infer(self):
         assert self.mode == "infer" and self.eval_mode == "classification"

@@ -167,6 +326,15 @@ class Engine(object):
             f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
         )

+    def _init_vdl(self):
+        if self.config['Global'][
+                'use_visualdl'] and mode == "train" and dist.get_rank() == 0:
+            vdl_writer_path = os.path.join(self.output_dir, "vdl")
+            if not os.path.exists(vdl_writer_path):
+                os.makedirs(vdl_writer_path)
+            return LogWriter(logdir=vdl_writer_path)
+        return None
+
     def _init_seed(self):
         seed = self.config["Global"].get("seed", False)
         if dist.get_world_size() != 1:

@@ -287,6 +455,22 @@ class Engine(object):
             self.train_loss_func = paddle.DataParallel(
                 self.train_loss_func)

+    def _build_ema_model(self):
+        if "EMA" in self.config and self.mode == "train":
+            model_ema = ExponentialMovingAverage(
+                self.model, self.config['EMA'].get("decay", 0.9999))
+            return model_ema
+        else:
+            return None
+
+    def _init_checkpoints(self, best_metric):
+        if self.config["Global"].get("checkpoints", None) is not None:
+            metric_info = init_model(self.config.Global, self.model,
+                                     self.optimizer, self.train_loss_func,
+                                     self.model_ema)
+            if metric_info is not None:
+                best_metric.update(metric_info)
+
 class ExportModel(TheseusLayer):
     """
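The EMA branch restored into Engine.train() scores the shadow weights by temporarily swapping them in as self.model and reusing the normal eval path. A standalone sketch of that pattern (the function name and the try/finally are illustrative, not from the diff):

def eval_with_ema(engine, epoch_id):
    # swap the EMA shadow module in as the live model
    ori_model, engine.model = engine.model, engine.model_ema.module
    try:
        acc_ema = engine.eval(epoch_id)  # same evaluation path as usual
    finally:
        engine.model = ori_model         # always restore the live model
    engine.model_ema.module.eval()
    return acc_ema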
ppcls/engine/evaluation/__init__.py  (view file @ 915dde17)

@@ -12,18 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .classification import ClassEval
+from .classification import classification_eval
 from .retrieval import retrieval_eval
 from .adaface import adaface_eval

-def build_eval_func(config, mode, model):
-    if mode not in ["eval", "train"]:
-        return None
+def build_eval_func(config):
     eval_mode = config["Global"].get("eval_mode", None)
     if eval_mode is None:
         config["Global"]["eval_mode"] = "classification"
-        return ClassEval(config, mode, model)
+        return classification_eval
     else:
-        return getattr(sys.modules[__name__], eval_mode + "_eval")(config,
-                                                                   mode, model)
+        return getattr(sys.modules[__name__], eval_mode + "_eval")
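The restored build_eval_func returns a plain function resolved by name from the module, rather than constructing a ClassEval object. A self-contained sketch of that name-based dispatch (the stub functions are illustrative):

import sys

def classification_eval(engine, epoch_id=0):
    return 0.0  # stub standing in for the real evaluation loop

def retrieval_eval(engine, epoch_id=0):
    return 0.0  # stub

def build_eval_func(config):
    eval_mode = config["Global"].get("eval_mode", None) or "classification"
    config["Global"]["eval_mode"] = eval_mode
    # "retrieval" -> retrieval_eval, "classification" -> classification_eval, ...
    return getattr(sys.modules[__name__], eval_mode + "_eval")

fn = build_eval_func({"Global": {"eval_mode": "retrieval"}})  # -> retrieval_eval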
ppcls/engine/evaluation/classification.py  (view file @ 915dde17)

@@ -18,185 +18,164 @@ import time
 import platform
 import paddle
-from ...utils.misc import AverageMeter
-from ...utils import logger
-from ...data import build_dataloader
-from ...loss import build_loss
-from ...metric import build_metrics
+from ppcls.utils.misc import AverageMeter
+from ppcls.utils import logger

Removed (the ClassEval class introduced by the refactor):

class ClassEval(object):
    def __init__(self, config, mode, model):
        self.config = config
        self.model = model
        self.use_dali = self.config["Global"].get("use_dali", False)
        self.eval_metric_func = build_metrics(config, "eval")
        self.eval_dataloader = build_dataloader(config, "eval")
        self.eval_loss_func = build_loss(config, "eval")
        self.output_info = dict()

    @paddle.no_grad()
    def __call__(self, epoch_id=0):
        self.model.eval()
        # [removed body: the same evaluation loop as the restored
        # classification_eval below, written against self.* attributes and with
        # the AMP branches left commented out; it ends with self.model.train()
        # and returns eval_result]

Added (the module-level function restored by the revert):

def classification_eval(engine, epoch_id=0):
    if hasattr(engine.eval_metric_func, "reset"):
        engine.eval_metric_func.reset()
    output_info = dict()
    time_info = {
        "batch_cost": AverageMeter(
            "batch_cost", '.5f', postfix=" s,"),
        "reader_cost": AverageMeter(
            "reader_cost", ".5f", postfix=" s,"),
    }
    print_batch_step = engine.config["Global"]["print_batch_step"]

    tic = time.time()
    total_samples = engine.dataloader_dict["Eval"].total_samples
    accum_samples = 0
    max_iter = engine.dataloader_dict["Eval"].max_iter
    for iter_id, batch in enumerate(engine.dataloader_dict["Eval"]):
        if iter_id >= max_iter:
            break
        if iter_id == 5:
            for key in time_info:
                time_info[key].reset()

        time_info["reader_cost"].update(time.time() - tic)
        batch_size = batch[0].shape[0]
        batch[0] = paddle.to_tensor(batch[0])
        if not engine.config["Global"].get("use_multilabel", False):
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")

        # image input
        if engine.amp and engine.amp_eval:
            with paddle.amp.auto_cast(
                    custom_black_list={
                        "flatten_contiguous_range", "greater_than"
                    },
                    level=engine.amp_level):
                out = engine.model(batch)
        else:
            out = engine.model(batch)

        # just for DistributedBatchSampler issue: repeat sampling
        current_samples = batch_size * paddle.distributed.get_world_size()
        accum_samples += current_samples

        if isinstance(out, dict) and "Student" in out:
            out = out["Student"]
        if isinstance(out, dict) and "logits" in out:
            out = out["logits"]

        # gather Tensor when distributed
        if paddle.distributed.get_world_size() > 1:
            label_list = []
            device_id = paddle.distributed.ParallelEnv().device_id
            label = batch[1].cuda(device_id) if engine.config["Global"][
                "device"] == "gpu" else batch[1]
            paddle.distributed.all_gather(label_list, label)
            labels = paddle.concat(label_list, 0)

            if isinstance(out, list):
                preds = []
                for x in out:
                    pred_list = []
                    paddle.distributed.all_gather(pred_list, x)
                    pred_x = paddle.concat(pred_list, 0)
                    preds.append(pred_x)
            else:
                pred_list = []
                paddle.distributed.all_gather(pred_list, out)
                preds = paddle.concat(pred_list, 0)

            if accum_samples > total_samples and not engine.use_dali:
                if isinstance(preds, list):
                    preds = [
                        pred[:total_samples + current_samples - accum_samples]
                        for pred in preds
                    ]
                else:
                    preds = preds[:total_samples + current_samples -
                                  accum_samples]
                labels = labels[:total_samples + current_samples -
                                accum_samples]
                current_samples = total_samples + current_samples - accum_samples
        else:
            labels = batch[1]
            preds = out

        # calc loss
        if engine.eval_loss_func is not None:
            if engine.amp and engine.amp_eval:
                with paddle.amp.auto_cast(
                        custom_black_list={
                            "flatten_contiguous_range", "greater_than"
                        },
                        level=engine.amp_level):
                    loss_dict = engine.eval_loss_func(preds, labels)
            else:
                loss_dict = engine.eval_loss_func(preds, labels)

            for key in loss_dict:
                if key not in output_info:
                    output_info[key] = AverageMeter(key, '7.5f')
                output_info[key].update(float(loss_dict[key]), current_samples)

        # calc metric
        if engine.eval_metric_func is not None:
            engine.eval_metric_func(preds, labels)
        time_info["batch_cost"].update(time.time() - tic)

        if iter_id % print_batch_step == 0:
            time_msg = "s, ".join([
                "{}: {:.5f}".format(key, time_info[key].avg)
                for key in time_info
            ])

            ips_msg = "ips: {:.5f} images/sec".format(
                batch_size / time_info["batch_cost"].avg)

            if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
                metric_msg = ""
            else:
                metric_msg = ", ".join([
                    "{}: {:.5f}".format(key, output_info[key].val)
                    for key in output_info
                ])
                metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
            logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
                epoch_id, iter_id, max_iter, metric_msg, time_msg, ips_msg))

        tic = time.time()

    if engine.use_dali:
        engine.dataloader_dict["Eval"].reset()

    if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
        metric_msg = ", ".join([
            "evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
            format(*engine.eval_metric_func.attr_res())
        ])
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        # do not try to save best eval.model
        if engine.eval_metric_func is None:
            return -1
        # return 1st metric in the dict
        return engine.eval_metric_func.attr_res()[0]
    else:
        metric_msg = ", ".join([
            "{}: {:.5f}".format(key, output_info[key].avg)
            for key in output_info
        ])
        metric_msg += ", {}".format(engine.eval_metric_func.avg_info)
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        # do not try to save best eval.model
        if engine.eval_metric_func is None:
            return -1
        # return 1st metric in the dict
        return engine.eval_metric_func.avg
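The gather-and-trim logic above exists because DistributedBatchSampler pads the last batch, so the gathered predictions can exceed the dataset size. A standalone sketch of the trimming arithmetic with made-up numbers:

import paddle

total_samples = 10    # true dataset size
current_samples = 4   # world_size * batch_size gathered in the last step
accum_samples = 12    # samples seen after the last gather (includes padding)

# keep only the samples that belong to the real dataset: 10 + 4 - 12 = 2
keep = total_samples + current_samples - accum_samples
preds = paddle.rand([current_samples, 5])[:keep]
labels = paddle.arange(current_samples).reshape([-1, 1])[:keep]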
ppcls/engine/train/__init__.py  (view file @ 915dde17)

@@ -13,19 +13,16 @@
 # limitations under the License.
 from .train_metabin import train_epoch_metabin
-from .classification import ClassTrainer
+from .regular_train_epoch import regular_train_epoch
 from .train_fixmatch import train_epoch_fixmatch
 from .train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
 from .train_progressive import train_epoch_progressive

-def build_train_func(config, mode, model, eval_func):
-    if mode != "train":
-        return None
-    train_mode = config["Global"].get("task", None)
+def build_train_epoch_func(config):
+    train_mode = config["Global"].get("train_mode", None)
     if train_mode is None:
-        config["Global"]["task"] = "classification"
-        return ClassTrainer(config, mode, model, eval_func)
+        config["Global"]["train_mode"] = "regular_train"
+        return regular_train_epoch
     else:
-        return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(
-            config, mode, model, eval_func)
+        return getattr(sys.modules[__name__], "train_epoch_" + train_mode)
ppcls/engine/train/classification.py  (deleted, 100644 → 0; view file @ aa52682c)

The whole file, added by the refactor, is removed:

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import time
import paddle

from .utils import update_loss, update_metric, log_info
from ...utils import logger, profiler, type_name
from ...utils.misc import AverageMeter
from ...data import build_dataloader
from ...loss import build_loss
from ...metric import build_metrics
from ...optimizer import build_optimizer
from ...utils.ema import ExponentialMovingAverage
from ...utils.save_load import init_model, ModelSaver


class ClassTrainer(object):
    def __init__(self, config, mode, model, eval_func):
        self.config = config
        self.model = model
        self.eval = eval_func

        self.start_eval_epoch = self.config["Global"].get("start_eval_epoch", 0) - 1
        self.epochs = self.config["Global"].get("epochs", 1)
        self.print_batch_step = self.config['Global']['print_batch_step']
        self.save_interval = self.config["Global"]["save_interval"]
        self.output_dir = self.config['Global']['output_dir']

        # gradient accumulation
        self.update_freq = self.config["Global"].get("update_freq", 1)

        # AMP training and evaluating
        # self._init_amp()

        # build dataloader
        self.use_dali = self.config["Global"].get("use_dali", False)
        self.dataloader_dict = build_dataloader(self.config, mode)

        # build loss
        self.train_loss_func, self.unlabel_train_loss_func = build_loss(self.config, mode)

        # build metric
        self.train_metric_func = build_metrics(config, "train")

        # build optimizer
        self.optimizer, self.lr_sch = build_optimizer(
            self.config, self.dataloader_dict["Train"].max_iter,
            [self.model, self.train_loss_func], self.update_freq)

        # build model saver
        self.model_saver = ModelSaver(
            self,
            net_name="model",
            loss_name="train_loss_func",
            opt_name="optimizer",
            model_ema_name="model_ema")

        # build best metric
        self.best_metric = {
            "metric": -1.0,
            "epoch": 0,
        }

        # key:
        # val: metrics list word
        self.output_info = dict()
        self.time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }

        # build EMA model
        self.model_ema = self._build_ema_model()
        self._init_checkpoints()

        # for visualdl
        self.vdl_writer = self._init_vdl()

    def __call__(self):
        # global iter counter
        self.global_step = 0
        for epoch_id in range(self.best_metric["epoch"] + 1, self.epochs + 1):
            # for one epoch train
            self.train_epoch(epoch_id)
            # [removed body: line for line the epoch loop restored into
            # Engine.train() above, written against self.best_metric /
            # self.save_interval instead of local variables, including the
            # ReduceOnPlateau step, best-model and EMA checkpointing, and the
            # epoch/latest model saving]

    def train_epoch(self, epoch_id):
        tic = time.time()
        for iter_id in range(self.dataloader_dict["Train"].max_iter):
            batch = self.dataloader_dict["Train"].get_batch()
            # [removed body: the same iteration loop restored below as
            # regular_train_epoch(), written against self.* instead of
            # engine.*, with the AMP forward/backward branch left commented
            # out: profiler step, timer reset at iter 5, reader_cost update,
            # label reshape, forward, loss scaled by update_freq, backward,
            # optimizer step/clear_grad every update_freq iters, per-step and
            # per-epoch lr scheduling, EMA update, and logging]

    def __del__(self):
        if self.vdl_writer is not None:
            self.vdl_writer.close()

    def _init_vdl(self):
        if self.config['Global']['use_visualdl'] and dist.get_rank() == 0:
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            return LogWriter(logdir=vdl_writer_path)
        return None

    def _build_ema_model(self):
        if "EMA" in self.config and self.mode == "train":
            model_ema = ExponentialMovingAverage(
                self.model, self.config['EMA'].get("decay", 0.9999))
            self.best_metric["metric_ema"] = 0
            return model_ema
        else:
            return None

    def _init_checkpoints(self):
        if self.config["Global"].get("checkpoints", None) is not None:
            metric_info = init_model(self.config.Global, self.model,
                                     self.optimizer, self.train_loss_func,
                                     self.model_ema)
            if metric_info is not None:
                self.best_metric.update(metric_info)
ppcls/engine/train/regular_train_epoch.py  (new file, 0 → 100644; view file @ 915dde17)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import time
import paddle

from ppcls.engine.train.utils import update_loss, update_metric, log_info, type_name
from ppcls.utils import profiler


def regular_train_epoch(engine, epoch_id, print_batch_step):
    tic = time.time()

    for iter_id in range(engine.dataloader_dict["Train"].max_iter):
        batch = engine.dataloader_dict["Train"].get_batch()

        profiler.add_profiler_step(engine.config["profiler_options"])
        if iter_id == 5:
            for key in engine.time_info:
                engine.time_info[key].reset()
        engine.time_info["reader_cost"].update(time.time() - tic)

        batch_size = batch[0].shape[0]
        if not engine.config["Global"].get("use_multilabel", False):
            batch[1] = batch[1].reshape([batch_size, -1])
        engine.global_step += 1

        # forward & backward & step opt
        if engine.amp:
            with paddle.amp.auto_cast(
                    custom_black_list={
                        "flatten_contiguous_range", "greater_than"
                    },
                    level=engine.amp_level):
                out = engine.model(batch)
                loss_dict = engine.train_loss_func(out, batch[1])
            loss = loss_dict["loss"] / engine.update_freq
            scaled = engine.scaler.scale(loss)
            scaled.backward()
            if (iter_id + 1) % engine.update_freq == 0:
                for i in range(len(engine.optimizer)):
                    engine.scaler.minimize(engine.optimizer[i], scaled)
        else:
            out = engine.model(batch)
            loss_dict = engine.train_loss_func(out, batch[1])
            loss = loss_dict["loss"] / engine.update_freq
            loss.backward()
            if (iter_id + 1) % engine.update_freq == 0:
                for i in range(len(engine.optimizer)):
                    engine.optimizer[i].step()

        if (iter_id + 1) % engine.update_freq == 0:
            # clear grad
            for i in range(len(engine.optimizer)):
                engine.optimizer[i].clear_grad()
            # step lr(by step)
            for i in range(len(engine.lr_sch)):
                if not getattr(engine.lr_sch[i], "by_epoch", False):
                    engine.lr_sch[i].step()
            # update ema
            if engine.model_ema:
                engine.model_ema.update(engine.model)

        # below code just for logging
        # update metric_for_logger
        update_metric(engine, out, batch, batch_size)
        # update_loss_for_logger
        update_loss(engine, loss_dict, batch_size)
        engine.time_info["batch_cost"].update(time.time() - tic)
        if iter_id % print_batch_step == 0:
            log_info(engine, batch_size, epoch_id, iter_id)
        tic = time.time()

    # step lr(by epoch)
    for i in range(len(engine.lr_sch)):
        if getattr(engine.lr_sch[i], "by_epoch", False) and \
                type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
            engine.lr_sch[i].step()
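regular_train_epoch implements gradient accumulation by dividing each loss by update_freq and only stepping and clearing the optimizers every update_freq iterations. A minimal standalone sketch of the same pattern with a toy model (all names here are illustrative):

import paddle

model = paddle.nn.Linear(8, 2)
opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())
update_freq = 4

for iter_id in range(16):
    x = paddle.rand([4, 8])
    loss = model(x).mean() / update_freq  # scale so accumulated grads match a larger batch
    loss.backward()                       # gradients accumulate across iterations
    if (iter_id + 1) % update_freq == 0:
        opt.step()                        # apply the accumulated update
        opt.clear_grad()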
ppcls/utils/save_load.py  (view file @ 915dde17)

@@ -151,20 +151,20 @@ def _extract_student_weights(all_params, student_prefix="Student."):

 class ModelSaver(object):
     def __init__(self,
-                 trainer,
+                 engine,
                  net_name="model",
                  loss_name="train_loss_func",
                  opt_name="optimizer",
                  model_ema_name="model_ema"):
         # net, loss, opt, model_ema, output_dir,
-        self.trainer = trainer
+        self.engine = engine
         self.net_name = net_name
         self.loss_name = loss_name
         self.opt_name = opt_name
         self.model_ema_name = model_ema_name

-        arch_name = trainer.config["Arch"]["name"]
-        self.output_dir = os.path.join(trainer.output_dir, arch_name)
+        arch_name = engine.config["Arch"]["name"]
+        self.output_dir = os.path.join(engine.output_dir, arch_name)
         _mkdir_if_not_exist(self.output_dir)

     def save(self, metric_info, prefix='ppcls', save_student_model=False):

@@ -174,8 +174,8 @@ class ModelSaver(object):
         save_dir = os.path.join(self.output_dir, prefix)

-        params_state_dict = getattr(self.trainer, self.net_name).state_dict()
-        loss = getattr(self.trainer, self.loss_name)
+        params_state_dict = getattr(self.engine, self.net_name).state_dict()
+        loss = getattr(self.engine, self.loss_name)
         if loss is not None:
             loss_state_dict = loss.state_dict()
             keys_inter = set(params_state_dict.keys()) & set(

@@ -190,11 +190,11 @@ class ModelSaver(object):
             paddle.save(s_params, save_dir + "_student.pdparams")

         paddle.save(params_state_dict, save_dir + ".pdparams")
-        model_ema = getattr(self.trainer, self.model_ema_name)
+        model_ema = getattr(self.engine, self.model_ema_name)
         if model_ema is not None:
             paddle.save(model_ema.module.state_dict(),
                         save_dir + ".ema.pdparams")
-        optimizer = getattr(self.trainer, self.opt_name)
+        optimizer = getattr(self.engine, self.opt_name)
         paddle.save([opt.state_dict() for opt in optimizer],
                     save_dir + ".pdopt")
         paddle.save(metric_info, save_dir + ".pdstates")
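ModelSaver reaches back into the owning object by attribute name, which is why this diff is a pure trainer-to-engine rename. A rough sketch of that lookup pattern (class and method here are illustrative, with the actual paddle.save calls left as comments):

class ModelSaverSketch:
    def __init__(self, engine, net_name="model", opt_name="optimizer"):
        self.engine = engine        # object that owns the model and optimizers
        self.net_name = net_name
        self.opt_name = opt_name

    def save(self, prefix):
        # look the objects up by name on the engine, then serialize them
        params = getattr(self.engine, self.net_name).state_dict()
        opts = [opt.state_dict() for opt in getattr(self.engine, self.opt_name)]
        # paddle.save(params, prefix + ".pdparams")
        # paddle.save(opts, prefix + ".pdopt")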