From 5a4ee1aec449babf417edb6b66d1125c4e107895 Mon Sep 17 00:00:00 2001
From: Tingquan Gao <35441050@qq.com>
Date: Tue, 14 Mar 2023 16:16:40 +0800
Subject: [PATCH] Revert "refactor"

This reverts commit 0e28a39da3952d10efaba71b1329459bb4df9db2.
---
 ppcls/engine/engine.py                    |  2 -
 ppcls/engine/train/regular_train_epoch.py | 18 +++-
 ppcls/utils/model_saver.py                | 80 -----------------------
 3 files changed, 12 insertions(+), 88 deletions(-)
 delete mode 100644 ppcls/utils/model_saver.py

diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index d29eb969..f287b726 100755
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -447,8 +447,6 @@ class Engine(object):
                 level=self.amp_level,
                 save_dtype='float32')
 
-        self.amp_level = engine.config["AMP"].get("level", "O1").upper()
-
     def _init_dist(self):
         # check the gpu num
         world_size = dist.get_world_size()
diff --git a/ppcls/engine/train/regular_train_epoch.py b/ppcls/engine/train/regular_train_epoch.py
index f49e57e4..d43a969c 100644
--- a/ppcls/engine/train/regular_train_epoch.py
+++ b/ppcls/engine/train/regular_train_epoch.py
@@ -36,25 +36,31 @@ def regular_train_epoch(engine, epoch_id, print_batch_step):
             batch[1] = batch[1].reshape([batch_size, -1])
         engine.global_step += 1
 
-        # forward & backward & step opt
+        # image input
         if engine.amp:
+            amp_level = engine.config["AMP"].get("level", "O1").upper()
             with paddle.amp.auto_cast(
                     custom_black_list={
                         "flatten_contiguous_range", "greater_than"
                     },
-                    level=engine.amp_level):
+                    level=amp_level):
                 out = engine.model(batch)
                 loss_dict = engine.train_loss_func(out, batch[1])
-            loss = loss_dict["loss"] / engine.update_freq
+        else:
+            out = engine.model(batch)
+            loss_dict = engine.train_loss_func(out, batch[1])
+
+        # loss
+        loss = loss_dict["loss"] / engine.update_freq
+
+        # backward & step opt
+        if engine.amp:
             scaled = engine.scaler.scale(loss)
             scaled.backward()
             if (iter_id + 1) % engine.update_freq == 0:
                 for i in range(len(engine.optimizer)):
                     engine.scaler.minimize(engine.optimizer[i], scaled)
         else:
-            out = engine.model(batch)
-            loss_dict = engine.train_loss_func(out, batch[1])
-            loss = loss_dict["loss"] / engine.update_freq
             loss.backward()
             if (iter_id + 1) % engine.update_freq == 0:
                 for i in range(len(engine.optimizer)):
diff --git a/ppcls/utils/model_saver.py b/ppcls/utils/model_saver.py
deleted file mode 100644
index eb15d129..00000000
--- a/ppcls/utils/model_saver.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import os
-import paddle
-
-from . import logger
-
-
-def _mkdir_if_not_exist(path):
-    """
-    mkdir if not exists, ignore the exception when multiprocess mkdir together
-    """
-    if not os.path.exists(path):
-        try:
-            os.makedirs(path)
-        except OSError as e:
-            if e.errno == errno.EEXIST and os.path.isdir(path):
-                logger.warning(
-                    'be happy if some process has already created {}'.format(
-                        path))
-            else:
-                raise OSError('Failed to mkdir {}'.format(path))
-
-
-def _extract_student_weights(all_params, student_prefix="Student."):
-    s_params = {
-        key[len(student_prefix):]: all_params[key]
-        for key in all_params if student_prefix in key
-    }
-    return s_params
-
-
-class ModelSaver(object):
-    def __init__(self,
-                 engine,
-                 net_name="model",
-                 loss_name="train_loss_func",
-                 opt_name="optimizer",
-                 model_ema_name="model_ema"):
-        # net, loss, opt, model_ema, output_dir,
-        self.engine = engine
-        self.net_name = net_name
-        self.loss_name = loss_name
-        self.opt_name = opt_name
-        self.model_ema_name = model_ema_name
-
-        arch_name = engine.config["Arch"]["name"]
-        self.output_dir = os.path.join(engine.output_dir, arch_name)
-        _mkdir_if_not_exist(self.output_dir)
-
-    def save(self, metric_info, prefix='ppcls', save_student_model=False):
-
-        if paddle.distributed.get_rank() != 0:
-            return
-
-        save_dir = os.path.join(self.output_dir, prefix)
-
-        params_state_dict = getattr(self.engine, self.net_name).state_dict()
-        loss = getattr(self.engine, self.loss_name)
-        if loss is not None:
-            loss_state_dict = loss.state_dict()
-            keys_inter = set(params_state_dict.keys()) & set(
-                loss_state_dict.keys())
-            assert len(keys_inter) == 0, \
-                f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
-            params_state_dict.update(loss_state_dict)
-
-        if save_student_model:
-            s_params = _extract_student_weights(params_state_dict)
-            if len(s_params) > 0:
-                paddle.save(s_params, save_dir + "_student.pdparams")
-
-        paddle.save(params_state_dict, save_dir + ".pdparams")
-        model_ema = getattr(self.engine, self.model_ema_name)
-        if model_ema is not None:
-            paddle.save(model_ema.module.state_dict(),
-                        save_dir + ".ema.pdparams")
-        optimizer = getattr(self.engine, self.opt_name)
-        paddle.save([opt.state_dict() for opt in optimizer],
-                    save_dir + ".pdopt")
-        paddle.save(metric_info, save_dir + ".pdstates")
-        logger.info("Already save model in {}".format(save_dir))
-- 
GitLab
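
Note for reviewers: the restored loop in regular_train_epoch.py uses Paddle's standard AMP recipe: run the forward pass under paddle.amp.auto_cast, scale the loss with a GradScaler, and step the optimizer only every update_freq iterations so gradients accumulate in between. Below is a minimal, self-contained sketch of that pattern; the toy model, data, optimizer choice, and the clear_grad() call are invented for illustration and are not PaddleClas code.

    import paddle

    model = paddle.nn.Linear(16, 4)                # stand-in for engine.model
    loss_fn = paddle.nn.CrossEntropyLoss()         # stand-in for engine.train_loss_func
    optimizer = paddle.optimizer.Momentum(parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
    update_freq = 4                                # accumulate gradients over 4 iterations

    for iter_id in range(16):
        x = paddle.randn([8, 16])                  # toy batch
        y = paddle.randint(0, 4, [8, 1])

        # ops on the black list stay in fp32 even inside auto_cast
        with paddle.amp.auto_cast(
                custom_black_list={"flatten_contiguous_range", "greater_than"},
                level="O1"):
            out = model(x)
            loss = loss_fn(out, y)

        # dividing by update_freq makes the accumulated gradient an
        # average over the window rather than a sum
        scaled = scaler.scale(loss / update_freq)
        scaled.backward()

        # step and clear only once per accumulation window
        if (iter_id + 1) % update_freq == 0:
            scaler.minimize(optimizer, scaled)
            optimizer.clear_grad()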
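The deleted _extract_student_weights helper exists because distillation checkpoints store student parameters under a "Student." prefix; stripping the prefix yields a state dict that the bare student model can load. A pure-dict illustration of the same comprehension (parameter names and values are invented; real entries are tensors):

    student_prefix = "Student."
    all_params = {
        "Student.fc.weight": "w_s",  # strings keep the sketch small
        "Student.fc.bias": "b_s",
        "Teacher.fc.weight": "w_t",
    }
    # identical logic to the deleted helper
    s_params = {
        key[len(student_prefix):]: all_params[key]
        for key in all_params if student_prefix in key
    }
    assert s_params == {"fc.weight": "w_s", "fc.bias": "b_s"}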
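For the record, ModelSaver.save wrote sibling files per checkpoint: <prefix>.pdparams (model weights, merged with any loss parameters), optional <prefix>_student.pdparams and <prefix>.ema.pdparams, <prefix>.pdopt (a list with one state dict per optimizer), and <prefix>.pdstates (metric info). A hedged sketch of that layout with a toy model; the directory name is hypothetical, and os.makedirs(exist_ok=True) stands in for the deleted _mkdir_if_not_exist:

    import os
    import paddle

    model = paddle.nn.Linear(4, 2)                 # toy network
    optimizer = paddle.optimizer.SGD(parameters=model.parameters())
    metric_info = {"metric": 0.0, "epoch": 1}

    output_dir = "./output/DemoArch"               # hypothetical path
    os.makedirs(output_dir, exist_ok=True)
    save_dir = os.path.join(output_dir, "latest")

    paddle.save(model.state_dict(), save_dir + ".pdparams")     # weights
    paddle.save([optimizer.state_dict()], save_dir + ".pdopt")  # one entry per optimizer
    paddle.save(metric_info, save_dir + ".pdstates")            # metric bookkeeping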