PaddlePaddle / PaddleClas
Commit 32593b63
Authored by gaotingquan on Mar 07, 2023; committed by Wei Shengyu on Mar 10, 2023.
refactor
Parent: d3941dc1
Showing 8 changed files with 112 additions and 178 deletions (+112, -178):
ppcls/data/__init__.py                      +43  -83
ppcls/engine/engine.py                       +3   -1
ppcls/engine/evaluation/classification.py    +9   -9
ppcls/engine/train/__init__.py               +1   -1
ppcls/engine/train/classification.py        +16  -19
ppcls/engine/train/utils.py                  +4   -4
ppcls/loss/__init__.py                      +23  -43
ppcls/utils/save_load.py                    +13  -18
ppcls/data/__init__.py

```diff
@@ -88,25 +88,32 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
     random.seed(worker_seed)
 
 
-def build(config, mode, use_dali=False, seed=None):
+def build_dataloader(config, mode, seed=None):
     assert mode in [
         'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
     ], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
-    assert mode in config.keys(), "{} config not in yaml".format(mode)
+    assert mode in config["DataLoader"].keys(), "{} config not in yaml".format(mode)
+    dataloader_config = config["DataLoader"][mode]
+    class_num = config["Arch"].get("class_num", None)
+    epochs = config["Global"]["epochs"]
+    use_dali = config["Global"].get("use_dali", False)
+    num_workers = dataloader_config['loader']["num_workers"]
+    use_shared_memory = dataloader_config['loader']["use_shared_memory"]
     # build dataset
     if use_dali:
         from ppcls.data.dataloader.dali import dali_dataloader
         return dali_dataloader(
-            config,
+            config["DataLoader"],
             mode,
             paddle.device.get_device(),
-            num_threads=config[mode]['loader']["num_workers"],
+            num_threads=num_workers,
             seed=seed,
             enable_fuse=True)
-    class_num = config.get("class_num", None)
-    epochs = config.get("epochs", None)
-    config_dataset = config[mode]['dataset']
+    config_dataset = dataloader_config['dataset']
     config_dataset = copy.deepcopy(config_dataset)
     dataset_name = config_dataset.pop('name')
     if 'batch_transform_ops' in config_dataset:
@@ -119,7 +126,7 @@ def build_dataloader(config, mode, seed=None):
     logger.debug("build dataset({}) success...".format(dataset))
 
     # build sampler
-    config_sampler = config[mode]['sampler']
+    config_sampler = dataloader_config['sampler']
     if config_sampler and "name" not in config_sampler:
         batch_sampler = None
         batch_size = config_sampler["batch_size"]
@@ -153,11 +160,6 @@ def build_dataloader(config, mode, seed=None):
     else:
         batch_collate_fn = None
 
-    # build dataloader
-    config_loader = config[mode]['loader']
-    num_workers = config_loader["num_workers"]
-    use_shared_memory = config_loader["use_shared_memory"]
-
     init_fn = partial(
         worker_init_fn,
         num_workers=num_workers,
@@ -194,78 +196,36 @@ def build_dataloader(config, mode, seed=None):
     data_loader.max_iter = max_iter
     data_loader.total_samples = total_samples
 
+    # TODO(gaotingquan): mv to build_sampler
+    if mode == "train":
+        if dataloader_config["Train"].get("max_iter", None):
+            # set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
+            max_iter = config["Train"].get("max_iter")
+        update_freq = config["Global"].get("update_freq", 1)
+        max_iter = data_loader.max_iter // update_freq * update_freq
+        data_loader.max_iter = max_iter
+
     logger.debug("build data_loader({}) success...".format(data_loader))
     return data_loader
 
 
-# TODO(gaotingquan): perf
-class DataIterator(object):
-    def __init__(self, dataloader, use_dali=False):
-        self.dataloader = dataloader
-        self.use_dali = use_dali
-        self.iterator = iter(dataloader)
-        self.max_iter = dataloader.max_iter
-        self.total_samples = dataloader.total_samples
-
-    def get_batch(self):
-        # fetch data batch from dataloader
-        try:
-            batch = next(self.iterator)
-        except Exception:
-            # NOTE: reset DALI dataloader manually
-            if self.use_dali:
-                self.dataloader.reset()
-            self.iterator = iter(self.dataloader)
-            batch = next(self.iterator)
-        return batch
-
-
-def build_dataloader(config, mode):
-    class_num = config["Arch"].get("class_num", None)
-    config["DataLoader"].update({"class_num": class_num})
-    config["DataLoader"].update({"epochs": config["Global"]["epochs"]})
-    use_dali = config["Global"].get("use_dali", False)
-    dataloader_dict = {
-        "Train": None,
-        "UnLabelTrain": None,
-        "Eval": None,
-        "Query": None,
-        "Gallery": None,
-        "GalleryQuery": None
-    }
-    if mode == 'train':
-        train_dataloader = build(config["DataLoader"], "Train", use_dali, seed=None)
-        if config["DataLoader"]["Train"].get("max_iter", None):
-            # set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
-            max_iter = config["Train"].get("max_iter")
-            update_freq = config["Global"].get("update_freq", 1)
-            max_iter = train_dataloader.max_iter // update_freq * update_freq
-            train_dataloader.max_iter = max_iter
-        if config["DataLoader"]["Train"].get("convert_iterator", True):
-            train_dataloader = DataIterator(train_dataloader, use_dali)
-        dataloader_dict["Train"] = train_dataloader
-    if config["DataLoader"].get('UnLabelTrain', None) is not None:
-        dataloader_dict["UnLabelTrain"] = build(
-            config["DataLoader"], "UnLabelTrain", use_dali, seed=None)
-    if mode == "eval" or (mode == "train" and config["Global"]["eval_during_train"]):
-        task = config["Global"].get("task", "classification")
-        if task in ["classification", "adaface"]:
-            dataloader_dict["Eval"] = build(
-                config["DataLoader"], "Eval", use_dali, seed=None)
-        elif task == "retrieval":
-            if len(config["DataLoader"]["Eval"].keys()) == 1:
-                key = list(config["DataLoader"]["Eval"].keys())[0]
-                dataloader_dict["GalleryQuery"] = build(
-                    config["DataLoader"]["Eval"], key, use_dali)
-            else:
-                dataloader_dict["Gallery"] = build(
-                    config["DataLoader"]["Eval"], "Gallery", use_dali)
-                dataloader_dict["Query"] = build(
-                    config["DataLoader"]["Eval"], "Query", use_dali)
-    return dataloader_dict
+# # TODO(gaotingquan): the length of dataloader should be determined by sampler
+# class DataIterator(object):
+#     def __init__(self, dataloader, use_dali=False):
+#         self.dataloader = dataloader
+#         self.use_dali = use_dali
+#         self.iterator = iter(dataloader)
+#         self.max_iter = dataloader.max_iter
+#         self.total_samples = dataloader.total_samples
 
+#     def get_batch(self):
+#         # fetch data batch from dataloader
+#         try:
+#             batch = next(self.iterator)
+#         except Exception:
+#             # NOTE: reset DALI dataloader manually
+#             if self.use_dali:
+#                 self.dataloader.reset()
+#             self.iterator = iter(self.dataloader)
+#             batch = next(self.iterator)
+#         return batch
```
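For context, the class this commit removes (and keeps only as a comment) wraps a finite dataloader so that get_batch() transparently restarts iteration once the underlying iterator is exhausted; under DALI it also resets the pipeline first. A framework-free sketch of the same pattern, using a plain list as a hypothetical stand-in for a real dataloader:

```python
class DataIterator(object):
    """Minimal rendition of the commented-out class above.

    get_batch() never raises StopIteration: when the wrapped loader is
    exhausted it re-creates the iterator (a real DALI loader would also
    need dataloader.reset() first, as the NOTE in the original says)
    and fetches again.
    """

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = iter(dataloader)

    def get_batch(self):
        try:
            return next(self.iterator)
        except StopIteration:
            # restart the epoch and retry
            self.iterator = iter(self.dataloader)
            return next(self.iterator)


loader = [[0, 1], [2, 3]]  # a plain list standing in for a dataloader
it = DataIterator(loader)
print([it.get_batch() for _ in range(5)])
# -> [[0, 1], [2, 3], [0, 1], [2, 3], [0, 1]]
```

The refactor drops this wrapper in favor of iterating the dataloader directly with enumerate(), which is why train_epoch in ppcls/engine/train/classification.py changes below.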
ppcls/engine/engine.py

```diff
@@ -60,6 +60,8 @@ class Engine(object):
         # load_pretrain
         self._init_pretrained()
 
         self._init_amp()
 
+        # init train_func and eval_func
+        self.eval = build_eval_func(self.config, mode=self.mode, model=self.model)
@@ -69,7 +71,7 @@ class Engine(object):
         # for distributed
         self._init_dist()
 
-        print_config(config)
+        print_config(self.config)
 
     @paddle.no_grad()
     def infer(self):
```
ppcls/engine/evaluation/classification.py

```diff
@@ -29,10 +29,11 @@ class ClassEval(object):
     def __init__(self, config, mode, model):
         self.config = config
         self.model = model
+        self.print_batch_step = self.config["Global"]["print_batch_step"]
         self.use_dali = self.config["Global"].get("use_dali", False)
 
-        self.eval_metric_func = build_metrics(config, "eval")
-        self.eval_dataloader = build_dataloader(config, "eval")
-        self.eval_loss_func = build_loss(config, "eval")
+        self.eval_metric_func = build_metrics(self.config, "eval")
+        self.eval_dataloader = build_dataloader(self.config, "Eval")
+        self.eval_loss_func = build_loss(self.config, "Eval")
         self.output_info = dict()
 
     @paddle.no_grad()
@@ -48,13 +49,12 @@ class ClassEval(object):
             "reader_cost": AverageMeter("reader_cost", ".5f", postfix=" s,"),
         }
-        print_batch_step = self.config["Global"]["print_batch_step"]
         tic = time.time()
-        total_samples = self.eval_dataloader["Eval"].total_samples
+        total_samples = self.eval_dataloader.total_samples
         accum_samples = 0
-        max_iter = self.eval_dataloader["Eval"].max_iter
-        for iter_id, batch in enumerate(self.eval_dataloader["Eval"]):
+        max_iter = self.eval_dataloader.max_iter
+        for iter_id, batch in enumerate(self.eval_dataloader):
             if iter_id >= max_iter:
                 break
             if iter_id == 5:
@@ -130,7 +130,7 @@ class ClassEval(object):
                 self.eval_metric_func(preds, labels)
             time_info["batch_cost"].update(time.time() - tic)
-            if iter_id % print_batch_step == 0:
+            if iter_id % self.print_batch_step == 0:
                 time_msg = "s, ".join([
                     "{}: {:.5f}".format(key, time_info[key].avg)
                     for key in time_info
@@ -153,7 +153,7 @@ class ClassEval(object):
             tic = time.time()
         if self.use_dali:
-            self.eval_dataloader["Eval"].reset()
+            self.eval_dataloader.reset()
 
         if "ATTRMetric" in self.config["Metric"]["Eval"][0]:
             metric_msg = ", ".join([
```
ppcls/engine/train/__init__.py

```diff
@@ -25,7 +25,7 @@ def build_train_func(config, mode, model, eval_func):
     train_mode = config["Global"].get("task", None)
     if train_mode is None:
         config["Global"]["task"] = "classification"
-        return ClassTrainer(config, mode, model, eval_func)
+        return ClassTrainer(config, model, eval_func)
     else:
         return getattr(sys.modules[__name__],
                        "train_epoch_" + train_mode)(config, mode, model, eval_func)
```
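The else branch resolves a per-task entry point by name: it looks up `train_epoch_<task>` on the current module via sys.modules[__name__]. A self-contained sketch of that dispatch pattern; the `demo` task and return strings are hypothetical stand-ins, not PaddleClas code:

```python
import sys


def train_epoch_demo(config, mode, model, eval_func):
    # hypothetical per-task entry point, analogous to train_epoch_<task>
    return "dispatched to train_epoch_demo"


def build_train_func(config, mode, model, eval_func):
    train_mode = config["Global"].get("task", None)
    if train_mode is None:
        return "ClassTrainer(...)"  # stand-in for the default trainer
    # name-based lookup on the current module, as in the diff above
    return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(
        config, mode, model, eval_func)


print(build_train_func({"Global": {"task": "demo"}}, "train", None, None))
# -> dispatched to train_epoch_demo
```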
ppcls/engine/train/classification.py

```diff
@@ -28,7 +28,7 @@ from ...utils.save_load import init_model, ModelSaver
 class ClassTrainer(object):
-    def __init__(self, config, mode, model, eval_func):
+    def __init__(self, config, model, eval_func):
         self.config = config
         self.model = model
         self.eval = eval_func
@@ -41,32 +41,32 @@ class ClassTrainer(object):
         # gradient accumulation
         self.update_freq = self.config["Global"].get("update_freq", 1)
 
-        # AMP training and evaluating
-        # self._init_amp()
+        # TODO(gaotingquan): mv to build_model
+        # build EMA model
+        self.model_ema = self._build_ema_model()
 
         # build dataloader
-        self.use_dali = self.config["Global"].get("use_dali", False)
-        self.dataloader_dict = build_dataloader(self.config, mode)
+        self.dataloader = build_dataloader(self.config, "Train")
 
         # build loss
-        self.train_loss_func, self.unlabel_train_loss_func = build_loss(
-            self.config, mode)
+        self.loss_func = build_loss(config, "Train")
 
         # build metric
         self.train_metric_func = build_metrics(config, "train")
 
         # build optimizer
         self.optimizer, self.lr_sch = build_optimizer(
-            self.config, self.dataloader_dict["Train"].max_iter,
-            [self.model, self.train_loss_func], self.update_freq)
+            self.config, self.dataloader.max_iter,
+            [self.model, self.loss_func], self.update_freq)
 
         # build model saver
         self.model_saver = ModelSaver(
-            self,
-            net_name="model",
-            loss_name="train_loss_func",
-            opt_name="optimizer",
-            model_ema_name="model_ema")
+            config=self.config,
+            net=self.model,
+            loss=self.loss_func,
+            opt=self.optimizer,
+            model_ema=self.model_ema if self.model_ema else None)
 
         # build best metric
         self.best_metric = {
@@ -84,8 +84,6 @@ class ClassTrainer(object):
             "reader_cost": AverageMeter("reader_cost", ".5f", postfix=" s,"),
         }
 
-        # build EMA model
-        self.model_ema = self._build_ema_model()
-
         self._init_checkpoints()
 
         # for visualdl
@@ -173,11 +171,10 @@ class ClassTrainer(object):
             },
             prefix="latest")
 
     def train_epoch(self, epoch_id):
         self.model.train()
         tic = time.time()
-        for iter_id in range(self.dataloader_dict["Train"].max_iter):
-            batch = self.dataloader_dict["Train"].get_batch()
+        for iter_id, batch in enumerate(self.dataloader):
             profiler.add_profiler_step(self.config["profiler_options"])
             if iter_id == 5:
                 for key in self.time_info:
@@ -190,7 +187,7 @@ class ClassTrainer(object):
             self.global_step += 1
 
             out = self.model(batch)
-            loss_dict = self.train_loss_func(out, batch[1])
+            loss_dict = self.loss_func(out, batch[1])
 
             # TODO(gaotingquan): mv update_freq to loss and optimizer
             loss = loss_dict["loss"] / self.update_freq
             loss.backward()
```
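Since loss_dict["loss"] is divided by update_freq before backward(), one optimizer step accumulates gradients over update_freq batches; correspondingly, build_dataloader rounds max_iter down to a multiple of update_freq so no partial accumulation window is left over. A tiny sketch of that arithmetic, with made-up numbers:

```python
# With gradient accumulation, one optimizer step consumes `update_freq`
# batches, so the loop length is rounded down to a whole number of steps.
max_iter, update_freq = 317, 4

effective_max_iter = max_iter // update_freq * update_freq
print(effective_max_iter)                # 316 -> 79 full optimizer steps
print(effective_max_iter % update_freq)  # 0  -> no partial window remains

batch_loss = 2.0
print(batch_loss / update_freq)          # 0.5 contributed per backward()
```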
ppcls/engine/train/utils.py

```diff
@@ -55,13 +55,13 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
                                            batch_size / trainer.time_info["batch_cost"].avg)
     eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1) *
-               trainer.dataloader_dict["Train"].max_iter - iter_id
+               len(trainer.dataloader) - iter_id
                ) * trainer.time_info["batch_cost"].avg
     eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
     logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
-        epoch_id, trainer.config["Global"]["epochs"], iter_id,
-        trainer.dataloader_dict["Train"].max_iter, lr_msg, metric_msg,
-        time_msg, ips_msg, eta_msg))
+        epoch_id, trainer.config["Global"]["epochs"], iter_id,
+        len(trainer.dataloader), lr_msg, metric_msg, time_msg, ips_msg,
+        eta_msg))
     for i, lr in enumerate(trainer.lr_sch):
         logger.scaler(
```
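After this change, log_info estimates remaining time from len(trainer.dataloader) instead of dataloader_dict["Train"].max_iter; the formula itself is unchanged: iterations left in the whole run times the running average batch cost. A self-contained sketch with made-up values:

```python
import datetime

# All numbers below are hypothetical, chosen only to exercise the formula.
epochs, epoch_id, iter_id = 100, 3, 50
iters_per_epoch = 200   # len(trainer.dataloader) after this commit
batch_cost_avg = 0.25   # trainer.time_info["batch_cost"].avg, in seconds

eta_sec = ((epochs - epoch_id + 1) * iters_per_epoch - iter_id) * batch_cost_avg
print("eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec)))))
# -> eta: 1:21:27
```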
ppcls/loss/__init__.py

```diff
@@ -50,8 +50,9 @@ from .metabinloss import IntraDomainScatterLoss
 class CombinedLoss(nn.Layer):
-    def __init__(self, config_list, amp_config=None):
+    def __init__(self, config_list, mode, amp_config=None):
         super().__init__()
+        self.mode = mode
         loss_func = []
         self.loss_weight = []
         assert isinstance(config_list, list), (
@@ -68,11 +69,13 @@ class CombinedLoss(nn.Layer):
         self.loss_func = nn.LayerList(loss_func)
         logger.debug("build loss {} success.".format(loss_func))
 
         self.scaler = None
         if amp_config:
-            self.scaler = paddle.amp.GradScaler(
-                init_loss_scaling=config["AMP"].get("scale_loss", 1.0),
-                use_dynamic_loss_scaling=config["AMP"].get("use_dynamic_loss_scaling", False))
+            if self.mode == "Train" or AMPForwardDecorator.amp_eval:
+                self.scaler = paddle.amp.GradScaler(
+                    init_loss_scaling=amp_config.get("scale_loss", 1.0),
+                    use_dynamic_loss_scaling=amp_config.get(
+                        "use_dynamic_loss_scaling", False))
 
     @AMP_forward_decorator
     def __call__(self, input, batch):
@@ -89,49 +92,26 @@ class CombinedLoss(nn.Layer):
             loss = {key: loss[key] * weight for key in loss}
             loss_dict.update(loss)
         loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
-        # TODO(gaotingquan): if amp_eval & eval_loss ?
-        if AMPForwardDecorator.amp_level:
+        if self.scaler:
             self.scaler(loss_dict["loss"])
         return loss_dict
 
 
-def build_loss(config, mode="train"):
-    if mode == "train":
-        label_loss_info = config["Loss"]["Train"]
-        if label_loss_info:
-            train_loss_func = CombinedLoss(
-                copy.deepcopy(label_loss_info), config.get("AMP", None))
-        unlabel_loss_info = config.get("UnLabelLoss", {}).get("Train", None)
-        if unlabel_loss_info:
-            unlabel_train_loss_func = CombinedLoss(
-                copy.deepcopy(unlabel_loss_info), config.get("AMP", None))
-        else:
-            unlabel_train_loss_func = None
-        if AMPForwardDecorator.amp_level is not None:
-            train_loss_func = paddle.amp.decorate(
-                models=train_loss_func,
-                level=AMPForwardDecorator.amp_level,
-                save_dtype='float32')
-            # TODO(gaotingquan): unlabel_loss_info may be None
-            unlabel_train_loss_func = paddle.amp.decorate(
-                models=unlabel_train_loss_func,
-                level=AMPForwardDecorator.amp_level,
-                save_dtype='float32')
-        return train_loss_func, unlabel_train_loss_func
-    if mode == "eval" or (mode == "train" and
-                          config["Global"]["eval_during_train"]):
-        loss_config = config.get("Loss", None)
-        if loss_config is not None:
-            loss_config = loss_config.get("Eval")
-            if loss_config is not None:
-                eval_loss_func = CombinedLoss(
-                    copy.deepcopy(loss_config), config.get("AMP", None))
-                if AMPForwardDecorator.amp_level is not None and \
-                        AMPForwardDecorator.amp_eval:
-                    eval_loss_func = paddle.amp.decorate(
-                        models=eval_loss_func,
-                        level=AMPForwardDecorator.amp_level,
-                        save_dtype='float32')
-                return eval_loss_func
+def build_loss(config, mode):
+    if config["Loss"][mode] is None:
+        return None
+    module_class = CombinedLoss(
+        copy.deepcopy(config["Loss"][mode]), mode, amp_config=config.get("AMP", None))
+
+    if AMPForwardDecorator.amp_level is not None:
+        if mode == "Train" or AMPForwardDecorator.amp_eval:
+            module_class = paddle.amp.decorate(
+                models=module_class,
+                level=AMPForwardDecorator.amp_level,
+                save_dtype='float32')
+    logger.debug("build loss {} success.".format(module_class))
+    return module_class
```
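CombinedLoss still aggregates sub-losses the same way after the refactor: each sub-loss returns a dict, its entries are scaled by the configured weight, and the scalar "loss" entry is the sum of all scaled values (paddle.add_n in the real code). A plain-Python sketch of that merge; the loss names and numbers are hypothetical:

```python
def combined_loss(sub_losses, weights):
    """Merge per-loss dicts as CombinedLoss.__call__ does above."""
    loss_dict = {}
    for sub, weight in zip(sub_losses, weights):
        # scale every entry of this sub-loss by its weight
        loss_dict.update({key: value * weight for key, value in sub.items()})
    # the total is the sum of all weighted entries (paddle.add_n analogue)
    loss_dict["loss"] = sum(loss_dict.values())
    return loss_dict


print(combined_loss([{"CELoss": 2.0}, {"TripletLoss": 0.5}], [1.0, 0.1]))
# -> {'CELoss': 2.0, 'TripletLoss': 0.05, 'loss': 2.05}
```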
ppcls/utils/save_load.py

```diff
@@ -150,21 +150,16 @@ def _extract_student_weights(all_params, student_prefix="Student."):
 class ModelSaver(object):
-    def __init__(self,
-                 trainer,
-                 net_name="model",
-                 loss_name="train_loss_func",
-                 opt_name="optimizer",
-                 model_ema_name="model_ema"):
+    def __init__(self, config, net, loss, opt, model_ema):
         # net, loss, opt, model_ema, output_dir,
-        self.trainer = trainer
-        self.net_name = net_name
-        self.loss_name = loss_name
-        self.opt_name = opt_name
-        self.model_ema_name = model_ema_name
-
-        arch_name = trainer.config["Arch"]["name"]
-        self.output_dir = os.path.join(trainer.output_dir, arch_name)
+        self.net = net
+        self.loss = loss
+        self.opt = opt
+        self.model_ema = model_ema
+
+        arch_name = config["Arch"]["name"]
+        self.output_dir = os.path.join(config["Global"]["output_dir"], arch_name)
         _mkdir_if_not_exist(self.output_dir)
 
     def save(self, metric_info, prefix='ppcls', save_student_model=False):
@@ -174,8 +169,8 @@ class ModelSaver(object):
         save_dir = os.path.join(self.output_dir, prefix)
 
-        params_state_dict = getattr(self.trainer, self.net_name).state_dict()
-        loss = getattr(self.trainer, self.loss_name)
+        params_state_dict = self.net.state_dict()
+        loss = self.loss
         if loss is not None:
             loss_state_dict = loss.state_dict()
             keys_inter = set(params_state_dict.keys()) & set(
@@ -190,11 +185,11 @@ class ModelSaver(object):
             paddle.save(s_params, save_dir + "_student.pdparams")
         paddle.save(params_state_dict, save_dir + ".pdparams")
 
-        model_ema = getattr(self.trainer, self.model_ema_name)
+        model_ema = self.model_ema
         if model_ema is not None:
             paddle.save(model_ema.module.state_dict(),
                         save_dir + ".ema.pdparams")
 
-        optimizer = getattr(self.trainer, self.opt_name)
+        optimizer = self.opt
         paddle.save([opt.state_dict() for opt in optimizer],
                     save_dir + ".pdopt")
         paddle.save(metric_info, save_dir + ".pdstates")
```
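The refactored ModelSaver no longer reaches into the trainer by attribute name (getattr(self.trainer, self.net_name)); it holds direct references and derives output_dir from the config. A minimal sketch of the new construction, with the paddle.save calls replaced by prints so it runs without Paddle installed:

```python
import os


class ModelSaverSketch(object):
    """Mirrors the refactored constructor: direct net/loss/opt/model_ema
    references instead of getattr lookups on a trainer object."""

    def __init__(self, config, net, loss, opt, model_ema):
        self.net = net
        self.loss = loss
        self.opt = opt
        self.model_ema = model_ema
        arch_name = config["Arch"]["name"]
        self.output_dir = os.path.join(config["Global"]["output_dir"], arch_name)

    def save(self, metric_info, prefix="ppcls"):
        save_dir = os.path.join(self.output_dir, prefix)
        print("would save params to", save_dir + ".pdparams")
        if self.model_ema is not None:
            print("would save EMA params to", save_dir + ".ema.pdparams")


config = {"Arch": {"name": "ResNet50"}, "Global": {"output_dir": "./output"}}
ModelSaverSketch(config, net=None, loss=None, opt=None, model_ema=None).save({})
# -> would save params to ./output/ResNet50/ppcls.pdparams
```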