diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index 1df4ffed419c1edc6f17575ef0397e1408029345..a0bd0376c86685639ce931c45384f6ad4103d0be 100755
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -33,8 +33,7 @@ from ppcls.metric import build_metrics
 from ppcls.optimizer import build_optimizer
 from ppcls.utils.ema import ExponentialMovingAverage
 from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
-from ppcls.utils.save_load import init_model
-from ppcls.utils.model_saver import ModelSaver
+from ppcls.utils.save_load import init_model, ModelSaver
 from ppcls.data.utils.get_image_list import get_image_list
 from ppcls.data.postprocess import build_postprocess
@@ -100,6 +99,14 @@ class Engine(object):
         # for distributed
         self._init_dist()
 
+        # build model saver
+        self.model_saver = ModelSaver(
+            self,
+            net_name="model",
+            loss_name="train_loss_func",
+            opt_name="optimizer",
+            model_ema_name="model_ema")
+
         print_config(config)
 
     def train(self):
@@ -129,14 +136,6 @@ class Engine(object):
         # TODO: mv best_metric_ema to best_metric dict
         best_metric_ema = 0
 
-        # build model saver
-        model_saver = ModelSaver(
-            self,
-            net_name="model",
-            loss_name="train_loss_func",
-            opt_name="optimizer",
-            model_ema_name="model_ema")
-
         self._init_checkpoints(best_metric)
 
         # global iter counter
@@ -166,7 +165,7 @@ class Engine(object):
                 if acc > best_metric["metric"]:
                     best_metric["metric"] = acc
                     best_metric["epoch"] = epoch_id
-                    model_saver.save(
+                    self.model_saver.save(
                         best_metric,
                         prefix="best_model",
                         save_student_model=True)
@@ -189,7 +188,7 @@ class Engine(object):
 
                 if acc_ema > best_metric_ema:
                     best_metric_ema = acc_ema
-                    model_saver.save(
+                    self.model_saver.save(
                         {
                             "metric": acc_ema,
                             "epoch": epoch_id
@@ -205,7 +204,7 @@ class Engine(object):
 
             # save model
             if save_interval > 0 and epoch_id % save_interval == 0:
-                model_saver.save(
+                self.model_saver.save(
                     {
                         "metric": acc,
                         "epoch": epoch_id
@@ -213,7 +212,7 @@ class Engine(object):
                     prefix=f"epoch_{epoch_id}")
 
             # save the latest model
-            model_saver.save(
+            self.model_saver.save(
                 {
                     "metric": acc,
                     "epoch": epoch_id
diff --git a/ppcls/utils/model_saver.py b/ppcls/utils/model_saver.py
deleted file mode 100644
index eb15d1292dfc84892b8d397f87c9d4e02c0cfd21..0000000000000000000000000000000000000000
--- a/ppcls/utils/model_saver.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import os
-import paddle
-
-from . import logger
-
-
-def _mkdir_if_not_exist(path):
-    """
-    mkdir if not exists, ignore the exception when multiprocess mkdir together
-    """
-    if not os.path.exists(path):
-        try:
-            os.makedirs(path)
-        except OSError as e:
-            if e.errno == errno.EEXIST and os.path.isdir(path):
-                logger.warning(
-                    'be happy if some process has already created {}'.format(
-                        path))
-            else:
-                raise OSError('Failed to mkdir {}'.format(path))
-
-
-def _extract_student_weights(all_params, student_prefix="Student."):
-    s_params = {
-        key[len(student_prefix):]: all_params[key]
-        for key in all_params if student_prefix in key
-    }
-    return s_params
-
-
-class ModelSaver(object):
-    def __init__(self,
-                 engine,
-                 net_name="model",
-                 loss_name="train_loss_func",
-                 opt_name="optimizer",
-                 model_ema_name="model_ema"):
-        # net, loss, opt, model_ema, output_dir,
-        self.engine = engine
-        self.net_name = net_name
-        self.loss_name = loss_name
-        self.opt_name = opt_name
-        self.model_ema_name = model_ema_name
-
-        arch_name = engine.config["Arch"]["name"]
-        self.output_dir = os.path.join(engine.output_dir, arch_name)
-        _mkdir_if_not_exist(self.output_dir)
-
-    def save(self, metric_info, prefix='ppcls', save_student_model=False):
-
-        if paddle.distributed.get_rank() != 0:
-            return
-
-        save_dir = os.path.join(self.output_dir, prefix)
-
-        params_state_dict = getattr(self.engine, self.net_name).state_dict()
-        loss = getattr(self.engine, self.loss_name)
-        if loss is not None:
-            loss_state_dict = loss.state_dict()
-            keys_inter = set(params_state_dict.keys()) & set(
-                loss_state_dict.keys())
-            assert len(keys_inter) == 0, \
-                f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
-            params_state_dict.update(loss_state_dict)
-
-        if save_student_model:
-            s_params = _extract_student_weights(params_state_dict)
-            if len(s_params) > 0:
-                paddle.save(s_params, save_dir + "_student.pdparams")
-
-        paddle.save(params_state_dict, save_dir + ".pdparams")
-        model_ema = getattr(self.engine, self.model_ema_name)
-        if model_ema is not None:
-            paddle.save(model_ema.module.state_dict(),
-                        save_dir + ".ema.pdparams")
-        optimizer = getattr(self.engine, self.opt_name)
-        paddle.save([opt.state_dict() for opt in optimizer],
-                    save_dir + ".pdopt")
-        paddle.save(metric_info, save_dir + ".pdstates")
-        logger.info("Already save model in {}".format(save_dir))
diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py
index cb3c3edb8f305ead0658b129d857def3f6a012e7..2515be363ee5b268617c239428758e1e5a82ed04 100644
--- a/ppcls/utils/save_load.py
+++ b/ppcls/utils/save_load.py
@@ -123,3 +123,79 @@ def init_model(config,
         load_dygraph_pretrain(net, path=pretrained_model)
         logger.info("Finish load pretrained model from {}".format(
             pretrained_model))
+
+
+def _mkdir_if_not_exist(path):
+    """
+    mkdir if not exists, ignore the exception when multiprocess mkdir together
+    """
+    if not os.path.exists(path):
+        try:
+            os.makedirs(path)
+        except OSError as e:
+            if e.errno == errno.EEXIST and os.path.isdir(path):
+                logger.warning(
+                    'be happy if some process has already created {}'.format(
+                        path))
+            else:
+                raise OSError('Failed to mkdir {}'.format(path))
+
+
+def _extract_student_weights(all_params, student_prefix="Student."):
+    s_params = {
+        key[len(student_prefix):]: all_params[key]
+        for key in all_params if student_prefix in key
+    }
+    return s_params
+
+
+class ModelSaver(object):
+    def __init__(self,
+                 engine,
+                 net_name="model",
+                 loss_name="train_loss_func",
+                 opt_name="optimizer",
+                 model_ema_name="model_ema"):
+        # net, loss, opt, model_ema, output_dir,
+        self.engine = engine
+        self.net_name = net_name
+        self.loss_name = loss_name
+        self.opt_name = opt_name
+        self.model_ema_name = model_ema_name
+
+        arch_name = engine.config["Arch"]["name"]
+        self.output_dir = os.path.join(engine.output_dir, arch_name)
+        _mkdir_if_not_exist(self.output_dir)
+
+    def save(self, metric_info, prefix='ppcls', save_student_model=False):
+
+        if paddle.distributed.get_rank() != 0:
+            return
+
+        save_dir = os.path.join(self.output_dir, prefix)
+
+        params_state_dict = getattr(self.engine, self.net_name).state_dict()
+        loss = getattr(self.engine, self.loss_name)
+        if loss is not None:
+            loss_state_dict = loss.state_dict()
+            keys_inter = set(params_state_dict.keys()) & set(
+                loss_state_dict.keys())
+            assert len(keys_inter) == 0, \
+                f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
+            params_state_dict.update(loss_state_dict)
+
+        if save_student_model:
+            s_params = _extract_student_weights(params_state_dict)
+            if len(s_params) > 0:
+                paddle.save(s_params, save_dir + "_student.pdparams")
+
+        paddle.save(params_state_dict, save_dir + ".pdparams")
+        model_ema = getattr(self.engine, self.model_ema_name)
+        if model_ema is not None:
+            paddle.save(model_ema.module.state_dict(),
+                        save_dir + ".ema.pdparams")
+        optimizer = getattr(self.engine, self.opt_name)
+        paddle.save([opt.state_dict() for opt in optimizer],
+                    save_dir + ".pdopt")
+        paddle.save(metric_info, save_dir + ".pdstates")
+        logger.info("Already save model in {}".format(save_dir))
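
With this change, ModelSaver moves into ppcls/utils/save_load.py alongside init_model (the standalone ppcls/utils/model_saver.py is deleted), and Engine.__init__ builds a single self.model_saver that the best-model, EMA, periodic, and latest checkpoints all share, instead of re-creating the saver at the top of every train() call. The sketch below shows how the relocated class is driven; ToyEngine and its attribute values are hypothetical stand-ins for ppcls.engine.engine.Engine, while the constructor arguments and the .pdparams/.pdopt/.pdstates suffixes come from the patch itself.

# Minimal usage sketch, assuming the post-patch import path; ToyEngine is a
# hypothetical stand-in exposing the attributes ModelSaver reads via getattr.
import paddle
from ppcls.utils.save_load import ModelSaver


class ToyEngine(object):
    def __init__(self):
        self.config = {"Arch": {"name": "ToyArch"}}  # names the output subdir
        self.output_dir = "./output"
        self.model = paddle.nn.Linear(8, 2)   # net_name="model"
        self.train_loss_func = None           # loss_name="train_loss_func"
        self.model_ema = None                 # model_ema_name="model_ema"
        self.optimizer = [paddle.optimizer.SGD(  # opt_name="optimizer" (a list)
            learning_rate=0.1, parameters=self.model.parameters())]


saver = ModelSaver(
    ToyEngine(),
    net_name="model",
    loss_name="train_loss_func",
    opt_name="optimizer",
    model_ema_name="model_ema")

# Rank 0 writes ./output/ToyArch/best_model.pdparams, .pdopt and .pdstates;
# all other ranks return early from save().
saver.save({"metric": 0.93, "epoch": 12}, prefix="best_model")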
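One behavior worth noting when reading _extract_student_weights: the filter is "student_prefix in key", so it matches the prefix anywhere in the key rather than only at the start, and it always strips the first len(student_prefix) characters. For the usual "Student."-prefixed distillation state_dicts this does what you expect, as a small illustration shows (dummy tensors only):

# Illustration with dummy tensors; real checkpoints hold trained weights.
import paddle
from ppcls.utils.save_load import _extract_student_weights

all_params = {
    "Student.fc.weight": paddle.zeros([4, 2]),
    "Student.fc.bias": paddle.zeros([2]),
    "Teacher.fc.weight": paddle.zeros([4, 2]),
}

# Teacher entries are dropped and the "Student." prefix is stripped,
# which is what save(..., save_student_model=True) relies on.
s_params = _extract_student_weights(all_params)
print(sorted(s_params))  # ['fc.bias', 'fc.weight']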