From 0efda2c75e1cc9c1d1a6168cc7bcf0506236ae93 Mon Sep 17 00:00:00 2001 From: Tingquan Gao <35441050@qq.com> Date: Tue, 14 Mar 2023 16:16:40 +0800 Subject: [PATCH] Revert "refactor: simpfy engine.train()" This reverts commit fad5c8e348f748f047508226c3c23335874ef9dc. --- ppcls/engine/engine.py | 134 +++++++++++----------- ppcls/engine/train/regular_train_epoch.py | 2 +- ppcls/utils/__init__.py | 2 +- ppcls/utils/logger.py | 9 +- ppcls/utils/save_load.py | 70 ++++++++++- 5 files changed, 140 insertions(+), 77 deletions(-) diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index f287b726..307cc968 100755 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -34,7 +34,7 @@ from ppcls.optimizer import build_optimizer from ppcls.utils.ema import ExponentialMovingAverage from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url from ppcls.utils.save_load import init_model -from ppcls.utils.model_saver import ModelSaver +from ppcls.utils import save_load from ppcls.data.utils.get_image_list import get_image_list from ppcls.data.postprocess import build_postprocess @@ -56,10 +56,7 @@ class Engine(object): self._init_seed() # init logger - self.output_dir = self.config['Global']['output_dir'] - log_file = os.path.join(self.output_dir, self.config["Arch"]["name"], - f"{mode}.log") - init_logger(log_file=log_file) + init_logger(self.config, mode=mode) # for visualdl self.vdl_writer = self._init_vdl() @@ -106,14 +103,21 @@ class Engine(object): assert self.mode == "train" print_batch_step = self.config['Global']['print_batch_step'] save_interval = self.config["Global"]["save_interval"] - start_eval_epoch = self.config["Global"].get("start_eval_epoch", 0) - 1 - epochs = self.config["Global"]["epochs"] - best_metric = { "metric": -1.0, "epoch": 0, } + # build EMA model + self.ema = "EMA" in self.config and self.mode == "train" + if self.ema: + self.model_ema = ExponentialMovingAverage( + self.model, self.config['EMA'].get("decay", 0.9999)) + best_metric_ema = 0.0 + ema_module = self.model_ema.module + else: + ema_module = None + # key: # val: metrics list word self.output_info = dict() @@ -123,35 +127,31 @@ class Engine(object): "reader_cost": AverageMeter( "reader_cost", ".5f", postfix=" s,"), } - - # build EMA model - self.model_ema = self._build_ema_model() - # TODO: mv best_metric_ema to best_metric dict - best_metric_ema = 0 - - # build model saver - model_saver = ModelSaver( - self, - net_name="model", - loss_name="train_loss_func", - opt_name="optimizer", - model_ema_name="model_ema") - - self._init_checkpoints(best_metric) - # global iter counter self.global_step = 0 - for epoch_id in range(best_metric["epoch"] + 1, epochs + 1): + + if self.config.Global.checkpoints is not None: + metric_info = init_model(self.config.Global, self.model, + self.optimizer, self.train_loss_func, + ema_module) + if metric_info is not None: + best_metric.update(metric_info) + + for epoch_id in range(best_metric["epoch"] + 1, + self.config["Global"]["epochs"] + 1): + acc = 0.0 # for one epoch train self.train_epoch_func(self, epoch_id, print_batch_step) metric_msg = ", ".join( [self.output_info[key].avg_info for key in self.output_info]) - logger.info("[Train][Epoch {}/{}][Avg]{}".format(epoch_id, epochs, - metric_msg)) + logger.info("[Train][Epoch {}/{}][Avg]{}".format( + epoch_id, self.config["Global"]["epochs"], metric_msg)) self.output_info.clear() - acc = 0.0 + # eval model and save model if possible + start_eval_epoch = self.config["Global"].get("start_eval_epoch", + 
0) - 1 if self.config["Global"][ "eval_during_train"] and epoch_id % self.config["Global"][ "eval_interval"] == 0 and epoch_id > start_eval_epoch: @@ -166,11 +166,16 @@ class Engine(object): if acc > best_metric["metric"]: best_metric["metric"] = acc best_metric["epoch"] = epoch_id - model_saver.save( + save_load.save_model( + self.model, + self.optimizer, best_metric, + self.output_dir, + ema=ema_module, + model_name=self.config["Arch"]["name"], prefix="best_model", + loss=self.train_loss_func, save_student_model=True) - logger.info("[Eval][Epoch {}][best metric: {}]".format( epoch_id, best_metric["metric"])) logger.scaler( @@ -181,20 +186,24 @@ class Engine(object): self.model.train() - if self.model_ema: - ori_model, self.model = self.model, self.model_ema.module + if self.ema: + ori_model, self.model = self.model, ema_module acc_ema = self.eval(epoch_id) self.model = ori_model - self.model_ema.module.eval() + ema_module.eval() if acc_ema > best_metric_ema: best_metric_ema = acc_ema - model_saver.save( - { - "metric": acc_ema, - "epoch": epoch_id - }, - prefix="best_model_ema") + save_load.save_model( + self.model, + self.optimizer, + {"metric": acc_ema, + "epoch": epoch_id}, + self.output_dir, + ema=ema_module, + model_name=self.config["Arch"]["name"], + prefix="best_model_ema", + loss=self.train_loss_func) logger.info("[Eval][Epoch {}][best metric ema: {}]".format( epoch_id, best_metric_ema)) logger.scaler( @@ -205,19 +214,25 @@ class Engine(object): # save model if save_interval > 0 and epoch_id % save_interval == 0: - model_saver.save( - { - "metric": acc, - "epoch": epoch_id - }, - prefix=f"epoch_{epoch_id}") - + save_load.save_model( + self.model, + self.optimizer, {"metric": acc, + "epoch": epoch_id}, + self.output_dir, + ema=ema_module, + model_name=self.config["Arch"]["name"], + prefix="epoch_{}".format(epoch_id), + loss=self.train_loss_func) # save the latest model - model_saver.save( - { - "metric": acc, - "epoch": epoch_id - }, prefix="latest") + save_load.save_model( + self.model, + self.optimizer, {"metric": acc, + "epoch": epoch_id}, + self.output_dir, + ema=ema_module, + model_name=self.config["Arch"]["name"], + prefix="latest", + loss=self.train_loss_func) if self.vdl_writer is not None: self.vdl_writer.close() @@ -468,23 +483,6 @@ class Engine(object): self.train_loss_func = paddle.DataParallel( self.train_loss_func) - def _build_ema_model(self): - if "EMA" in self.config and self.mode == "train": - model_ema = ExponentialMovingAverage( - self.model, self.config['EMA'].get("decay", 0.9999)) - return model_ema - else: - return None - - def _init_checkpoints(self, best_metric): - if self.config["Global"].get("checkpoints", None) is not None: - metric_info = init_model(self.config.Global, self.model, - self.optimizer, self.train_loss_func, - self.model_ema) - if metric_info is not None: - best_metric.update(metric_info) - return best_metric - class ExportModel(TheseusLayer): """ diff --git a/ppcls/engine/train/regular_train_epoch.py b/ppcls/engine/train/regular_train_epoch.py index d43a969c..78629396 100644 --- a/ppcls/engine/train/regular_train_epoch.py +++ b/ppcls/engine/train/regular_train_epoch.py @@ -75,7 +75,7 @@ def regular_train_epoch(engine, epoch_id, print_batch_step): if not getattr(engine.lr_sch[i], "by_epoch", False): engine.lr_sch[i].step() # update ema - if engine.model_ema: + if engine.ema: engine.model_ema.update(engine.model) # below code just for logging diff --git a/ppcls/utils/__init__.py b/ppcls/utils/__init__.py index 59c0d050..f9307ffd 100644 --- 
a/ppcls/utils/__init__.py +++ b/ppcls/utils/__init__.py @@ -25,4 +25,4 @@ from .metrics import mean_average_precision from .metrics import multi_hot_encode from .metrics import precision_recall_fscore from .misc import AverageMeter -from .save_load import init_model +from .save_load import init_model, save_model diff --git a/ppcls/utils/logger.py b/ppcls/utils/logger.py index 5edca7a1..c05aa02b 100644 --- a/ppcls/utils/logger.py +++ b/ppcls/utils/logger.py @@ -22,16 +22,15 @@ import paddle.distributed as dist _logger = None -def init_logger(name='ppcls', log_file=None, log_level=logging.INFO): +def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO): """Initialize and get a logger by name. If the logger has not been initialized, this method will initialize the logger by adding one or two handlers, otherwise the initialized logger will be directly returned. During initialization, a StreamHandler will always be - added. If `log_file` is specified a FileHandler will also be added. + added. Args: + config(dict): Training config. name (str): Logger name. - log_file (str | None): The log filename. If specified, a FileHandler - will be added to the logger. log_level (int): The logger level. Note that only the process of rank 0 is affected, and other processes will set the level to "Error" thus be silent most of the time. @@ -63,6 +62,8 @@ def init_logger(name='ppcls', log_file=None, log_level=logging.INFO): if init_flag: _logger.addHandler(stream_handler) + log_file = os.path.join(config['Global']['output_dir'], + config["Arch"]["name"], f"{mode}.log") if log_file is not None and dist.get_rank() == 0: log_file_folder = os.path.split(log_file)[0] os.makedirs(log_file_folder, exist_ok=True) diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py index cb3c3edb..deab91ed 100644 --- a/ppcls/utils/save_load.py +++ b/ppcls/utils/save_load.py @@ -26,6 +26,30 @@ from .download import get_weights_path_from_url __all__ = ['init_model', 'save_model', 'load_dygraph_pretrain'] +def _mkdir_if_not_exist(path): + """ + mkdir if not exists, ignore the exception when multiprocess mkdir together + """ + if not os.path.exists(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno == errno.EEXIST and os.path.isdir(path): + logger.warning( + 'be happy if some process has already created {}'.format( + path)) + else: + raise OSError('Failed to mkdir {}'.format(path)) + + +def _extract_student_weights(all_params, student_prefix="Student."): + s_params = { + key[len(student_prefix):]: all_params[key] + for key in all_params if student_prefix in key + } + return s_params + + def load_dygraph_pretrain(model, path=None): if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): raise ValueError("Model pretrain path {}.pdparams does not " @@ -86,7 +110,7 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None, - model_ema=None): + ema=None): """ load model from checkpoint or pretrained_model """ @@ -106,11 +130,11 @@ def init_model(config, for i in range(len(optimizer)): optimizer[i].set_state_dict(opti_dict[i] if isinstance( opti_dict, list) else opti_dict) - if model_ema is not None: + if ema is not None: assert os.path.exists(checkpoints + ".ema.pdparams"), \ "Given dir {}.ema.pdparams not exist.".format(checkpoints) para_ema_dict = paddle.load(checkpoints + ".ema.pdparams") - model_ema.module.set_state_dict(para_ema_dict) + ema.set_state_dict(para_ema_dict) logger.info("Finish load checkpoints from {}".format(checkpoints)) return metric_dict @@ -123,3 
+147,43 @@ def init_model(config, load_dygraph_pretrain(net, path=pretrained_model) logger.info("Finish load pretrained model from {}".format( pretrained_model)) + + +def save_model(net, + optimizer, + metric_info, + model_path, + ema=None, + model_name="", + prefix='ppcls', + loss: paddle.nn.Layer=None, + save_student_model=False): + """ + save model to the target path + """ + if paddle.distributed.get_rank() != 0: + return + model_path = os.path.join(model_path, model_name) + _mkdir_if_not_exist(model_path) + model_path = os.path.join(model_path, prefix) + + params_state_dict = net.state_dict() + if loss is not None: + loss_state_dict = loss.state_dict() + keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys( + )) + assert len(keys_inter) == 0, \ + f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}" + params_state_dict.update(loss_state_dict) + + if save_student_model: + s_params = _extract_student_weights(params_state_dict) + if len(s_params) > 0: + paddle.save(s_params, model_path + "_student.pdparams") + + paddle.save(params_state_dict, model_path + ".pdparams") + if ema is not None: + paddle.save(ema.state_dict(), model_path + ".ema.pdparams") + paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt") + paddle.save(metric_info, model_path + ".pdstates") + logger.info("Already save model in {}".format(model_path)) -- GitLab
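
Usage sketch for the restored module-level ppcls.utils.save_load.save_model() API that the hunks above bring back in place of the ModelSaver wrapper. The toy Linear network, the "DemoModel" name and the "./output" directory are illustrative assumptions, not taken from the patch; the sketch only assumes paddle and ppcls are importable, and it mirrors how Engine.train() invokes the function in this revert.

    # Minimal sketch of the restored save_load.save_model() signature.
    import paddle
    from ppcls.utils import save_load

    net = paddle.nn.Linear(8, 2)
    optimizer = [paddle.optimizer.SGD(learning_rate=0.1,
                                      parameters=net.parameters())]
    metric_info = {"metric": 0.0, "epoch": 1}

    # Writes ./output/DemoModel/latest.pdparams, .pdopt and .pdstates,
    # the same artifacts Engine.train() produces for the "latest" prefix.
    save_load.save_model(
        net,                       # network whose state_dict is saved
        optimizer,                 # list of optimizers -> latest.pdopt
        metric_info,               # metric dict -> latest.pdstates
        "./output",                # model_path: Global.output_dir in the engine
        ema=None,                  # EMA module, only set when "EMA" is configured
        model_name="DemoModel",    # subdirectory created under model_path
        prefix="latest",           # file prefix shared by all saved artifacts
        loss=None,                 # optional trainable loss layer merged into .pdparams
        save_student_model=False)  # True also dumps "Student." weights for distillation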