From 8b218b01acff7230079113d4db5e0ba29145cfa6 Mon Sep 17 00:00:00 2001
From: gaotingquan
Date: Tue, 23 May 2023 09:16:12 +0000
Subject: [PATCH] refactor amp auto_cast context manager & loss scaler

---
 ppcls/engine/engine.py                    | 59 ++++++++++++-----------
 ppcls/engine/evaluation/classification.py | 10 +---
 ppcls/engine/evaluation/retrieval.py      |  5 +-
 ppcls/engine/train/train.py               | 24 +++------
 ppcls/engine/train/train_fixmatch.py      | 23 +++------
 ppcls/engine/train/train_metabin.py       | 19 +++-----
 ppcls/utils/amp.py                        | 50 +++++++++++++++++++
 7 files changed, 103 insertions(+), 87 deletions(-)
 create mode 100644 ppcls/utils/amp.py

diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index 741dd36d..e423b318 100755
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -33,6 +33,7 @@ from ppcls.arch import apply_to_static
 from ppcls.loss import build_loss
 from ppcls.metric import build_metrics
 from ppcls.optimizer import build_optimizer
+from ppcls.utils.amp import AutoCast, build_scaler
 from ppcls.utils.ema import ExponentialMovingAverage
 from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
 from ppcls.utils.save_load import init_model
@@ -459,12 +460,7 @@ class Engine(object):
             if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                 batch_tensor = paddle.to_tensor(batch_data)

-                if self.amp and self.amp_eval:
-                    with paddle.amp.auto_cast(
-                            level=self.amp_level,
-                            use_promote=self.use_promote):
-                        out = self.model(batch_tensor)
-                else:
+                with self.auto_cast(is_eval=True):
                     out = self.model(batch_tensor)

                 if isinstance(out, list):
@@ -528,10 +524,13 @@ class Engine(object):
         )

     def _init_amp(self):
-        self.amp = "AMP" in self.config and self.config["AMP"] is not None
-        self.amp_eval = False
+        amp_config = self.config.get("AMP", None)
+        use_amp = True if amp_config else False

-        if self.amp:
+        if not use_amp:
+            self.auto_cast = AutoCast(use_amp)
+            self.scaler = build_scaler(use_amp)
+        else:
             AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
             if paddle.is_compiled_with_cuda():
                 AMP_RELATED_FLAGS_SETTING.update({
@@ -539,42 +538,46 @@ class Engine(object):
                 })
             paddle.set_flags(AMP_RELATED_FLAGS_SETTING)

-            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
-            self.use_dynamic_loss_scaling = self.config["AMP"].get(
-                "use_dynamic_loss_scaling", False)
-            self.scaler = paddle.amp.GradScaler(
-                init_loss_scaling=self.scale_loss,
-                use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
-
-            self.use_promote = self.config['AMP'].get("use_promote", False)
-            self.amp_level = self.config['AMP'].get("level", "O1")
-            if self.amp_level not in ["O1", "O2"]:
+            use_promote = amp_config.get("use_promote", False)
+            amp_level = amp_config.get("level", "O1")
+            if amp_level not in ["O1", "O2"]:
                 msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
                 logger.warning(msg)
-                self.config['AMP']["level"] = "O1"
-                self.amp_level = "O1"
+                amp_level = amp_config["level"] = "O1"

-            self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
+            amp_eval = self.config["AMP"].get("use_fp16_test", False)
             # TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
             if self.mode == "train" and self.config["Global"].get(
                     "eval_during_train",
-                    True) and self.amp_level == "O2" and self.amp_eval == False:
+                    True) and amp_level == "O2" and amp_eval == False:
                 msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
                 logger.warning(msg)
                 self.config["AMP"]["use_fp16_test"] = True
-                self.amp_eval = True
+                amp_eval = True
+
+            self.auto_cast = AutoCast(
+                use_amp,
+                amp_level=amp_level,
+                use_promote=use_promote,
+                amp_eval=amp_eval)
+
+            scale_loss = amp_config.get("scale_loss", 1.0)
+            use_dynamic_loss_scaling = amp_config.get(
+                "use_dynamic_loss_scaling", False)
+            self.scaler = build_scaler(
+                use_amp,
+                scale_loss=scale_loss,
+                use_dynamic_loss_scaling=use_dynamic_loss_scaling)

             if self.mode == "train":
                 self.model, self.optimizer = paddle.amp.decorate(
                     models=self.model,
                     optimizers=self.optimizer,
-                    level=self.amp_level,
+                    level=amp_level,
                     save_dtype='float32')
             elif self.amp_eval:
                 self.model = paddle.amp.decorate(
-                    models=self.model,
-                    level=self.amp_level,
-                    save_dtype='float32')
+                    models=self.model, level=amp_level, save_dtype='float32')

             if self.mode == "train" and len(self.train_loss_func.parameters(
             )) > 0:
diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index b9f3b870..4719c4a1 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -55,10 +55,7 @@ def classification_eval(engine, epoch_id=0):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")

         # image input
-        if engine.amp and engine.amp_eval:
-            with paddle.amp.auto_cast(level=engine.amp_level):
-                out = engine.model(batch[0])
-        else:
+        with engine.auto_cast(is_eval=True):
             out = engine.model(batch[0])

         # just for DistributedBatchSampler issue: repeat sampling
@@ -109,10 +106,7 @@ def classification_eval(engine, epoch_id=0):

         # calc loss
         if engine.eval_loss_func is not None:
-            if engine.amp and engine.amp_eval:
-                with paddle.amp.auto_cast(level=engine.amp_level):
-                    loss_dict = engine.eval_loss_func(preds, labels)
-            else:
+            with engine.auto_cast(is_eval=True):
                 loss_dict = engine.eval_loss_func(preds, labels)

             for key in loss_dict:
diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py
index 6644d57d..12987662 100644
--- a/ppcls/engine/evaluation/retrieval.py
+++ b/ppcls/engine/evaluation/retrieval.py
@@ -136,10 +136,7 @@ def compute_feature(engine, name="gallery"):
         if len(batch) >= 3:
             has_camera = True
             batch[2] = batch[2].reshape([-1, 1]).astype("int64")
-        if engine.amp and engine.amp_eval:
-            with paddle.amp.auto_cast(level=engine.amp_level):
-                out = engine.model(batch[0])
-        else:
+        with engine.auto_cast(is_eval=True):
             out = engine.model(batch[0])
         if "Student" in out:
             out = out["Student"]
diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py
index 3c8035fc..639e2d47 100644
--- a/ppcls/engine/train/train.py
+++ b/ppcls/engine/train/train.py
@@ -48,12 +48,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
         engine.global_step += 1

         # image input
-        if engine.amp:
-            amp_level = engine.config["AMP"].get("level", "O1").upper()
-            with paddle.amp.auto_cast(level=amp_level):
-                out = forward(engine, batch)
-                loss_dict = engine.train_loss_func(out, batch[1])
-        else:
+        with engine.auto_cast(is_eval=False):
             out = forward(engine, batch)
             loss_dict = engine.train_loss_func(out, batch[1])

@@ -61,17 +56,12 @@ def train_epoch(engine, epoch_id, print_batch_step):
         loss = loss_dict["loss"] / engine.update_freq

         # backward & step opt
-        if engine.amp:
-            scaled = engine.scaler.scale(loss)
-            scaled.backward()
-            if (iter_id + 1) % engine.update_freq == 0:
-                for i in range(len(engine.optimizer)):
-                    engine.scaler.minimize(engine.optimizer[i], scaled)
-        else:
-            loss.backward()
-            if (iter_id + 1) % engine.update_freq == 0:
-                for i in range(len(engine.optimizer)):
-                    engine.optimizer[i].step()
+        scaled = engine.scaler.scale(loss)
+        scaled.backward()
+        if (iter_id + 1) % engine.update_freq == 0:
+            for i in range(len(engine.optimizer)):
+                # optimizer.step() with auto amp
+                engine.scaler.minimize(engine.optimizer[i], scaled)

         if (iter_id + 1) % engine.update_freq == 0:
             # clear grad
diff --git a/ppcls/engine/train/train_fixmatch.py b/ppcls/engine/train/train_fixmatch.py
index 93839c89..6404af16 100644
--- a/ppcls/engine/train/train_fixmatch.py
+++ b/ppcls/engine/train/train_fixmatch.py
@@ -62,13 +62,7 @@ def train_epoch_fixmatch(engine, epoch_id, print_batch_step):
         inputs = paddle.concat([inputs_x, inputs_u_w, inputs_u_s], axis=0)

         # image input
-        if engine.amp:
-            amp_level = engine.config['AMP'].get("level", "O1").upper()
-            with paddle.amp.auto_cast(level=amp_level):
-                loss_dict, logits_label = get_loss(
-                    engine, inputs, batch_size_label, temperture, threshold,
-                    targets_x)
-        else:
+        with engine.auto_cast(is_eval=False):
             loss_dict, logits_label = get_loss(engine, inputs,
                                                batch_size_label, temperture,
                                                threshold, targets_x)
@@ -77,16 +71,11 @@ def train_epoch_fixmatch(engine, epoch_id, print_batch_step):
         loss = loss_dict["loss"]

         # backward & step opt
-        if engine.amp:
-            scaled = engine.scaler.scale(loss)
-            scaled.backward()
-
-            for i in range(len(engine.optimizer)):
-                engine.scaler.minimize(engine.optimizer[i], scaled)
-        else:
-            loss.backward()
-            for i in range(len(engine.optimizer)):
-                engine.optimizer[i].step()
+        scaled = engine.scaler.scale(loss)
+        scaled.backward()
+
+        for i in range(len(engine.optimizer)):
+            engine.scaler.minimize(engine.optimizer[i], scaled)

         # step lr(by step)
         for i in range(len(engine.lr_sch)):
diff --git a/ppcls/engine/train/train_metabin.py b/ppcls/engine/train/train_metabin.py
index eed4c4d9..f0413fb4 100644
--- a/ppcls/engine/train/train_metabin.py
+++ b/ppcls/engine/train/train_metabin.py
@@ -189,26 +189,19 @@ def get_meta_data(meta_dataloader_iter, num_domain):
 def forward(engine, batch, loss_func):
     batch_info = defaultdict()
     batch_info = {"label": batch[1], "domain": batch[2]}
-    if engine.amp:
-        amp_level = engine.config["AMP"].get("level", "O1").upper()
-        with paddle.amp.auto_cast(level=amp_level):
-            out = engine.model(batch[0], batch[1])
-            loss_dict = loss_func(out, batch_info)
-    else:
+
+    with engine.auto_cast(is_eval=False):
         out = engine.model(batch[0], batch[1])
         loss_dict = loss_func(out, batch_info)
+
     return out, loss_dict


 def backward(engine, loss, optimizer):
     optimizer.clear_grad()
-    if engine.amp:
-        scaled = engine.scaler.scale(loss)
-        scaled.backward()
-        engine.scaler.minimize(optimizer, scaled)
-    else:
-        loss.backward()
-        optimizer.step()
+    scaled = engine.scaler.scale(loss)
+    scaled.backward()
+    engine.scaler.minimize(optimizer, scaled)
     for name, layer in engine.model.backbone.named_sublayers():
         if "gate" == name.split('.')[-1]:
             layer.clip_gate()
diff --git a/ppcls/utils/amp.py b/ppcls/utils/amp.py
new file mode 100644
index 00000000..18c8bcfa
--- /dev/null
+++ b/ppcls/utils/amp.py
@@ -0,0 +1,50 @@
+from functools import partial
+import contextlib
+import paddle
+
+
+class AutoCast:
+    def __init__(self,
+                 use_amp=False,
+                 amp_level="O1",
+                 use_promote=False,
+                 amp_eval=False):
+        self.use_amp = use_amp
+        self.amp_eval = amp_eval
+
+        if self.use_amp:
+            self.cast_context = partial(paddle.amp.auto_cast, level=amp_level)
+
+    def __call__(self, is_eval=False):
+        if self.use_amp:
+            # not is_eval: cast for all training
+            # is_eval and self.amp_eval: cast for evaluation only when amp_eval is True
+            if not is_eval or (is_eval and self.amp_eval):
+                return self.cast_context()
+
+        return contextlib.nullcontext()
+
+
+def build_scaler(use_amp=False, scale_loss=1.0,
+                 use_dynamic_loss_scaling=False):
+    class Foo:
+        def __init__(self):
+            pass
+
+        def scale(self, loss):
+            return loss
+
+        def step(self, optimizer):
+            optimizer.step()
+
+        def update(self):
+            return
+
+        def minimize(self, optimizer, loss):
+            optimizer.step()
+
+    if use_amp:
+        return paddle.amp.GradScaler(
+            init_loss_scaling=scale_loss,
+            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+    return Foo()
--
GitLab
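
Usage sketch: the snippet below illustrates how the helpers introduced in ppcls/utils/amp.py are meant to be driven by the engine code in this patch. AutoCast is built once from the AMP config and then used as a context manager on both the training and evaluation paths, while build_scaler returns either a real paddle.amp.GradScaler or a no-op stand-in with the same scale/minimize interface, so the training loop no longer branches on engine.amp. The toy model, optimizer, and config values here are illustrative assumptions only and are not taken from the patch.

    import paddle
    import paddle.nn.functional as F

    from ppcls.utils.amp import AutoCast, build_scaler

    # Illustrative stand-ins for the values the engine reads from the "AMP"
    # section of the config; real values would come from amp_config.get(...).
    use_amp = True
    amp_level = "O1"
    use_promote = False
    amp_eval = False

    # One cast-context factory and one scaler cover both the AMP and FP32 paths.
    auto_cast = AutoCast(
        use_amp, amp_level=amp_level, use_promote=use_promote, amp_eval=amp_eval)
    scaler = build_scaler(
        use_amp, scale_loss=128.0, use_dynamic_loss_scaling=True)

    # Toy model and optimizer, for demonstration only.
    model = paddle.nn.Linear(8, 2)
    optimizer = paddle.optimizer.SGD(
        learning_rate=0.01, parameters=model.parameters())

    data = paddle.rand([4, 8])
    label = paddle.randint(0, 2, [4, 1])

    # Training step: identical code whether AMP is enabled or not.
    with auto_cast(is_eval=False):
        logits = model(data)
        loss = F.cross_entropy(logits, label)

    scaled = scaler.scale(loss)         # returns loss unchanged when AMP is off
    scaled.backward()
    scaler.minimize(optimizer, scaled)  # optimizer.step() under the hood
    optimizer.clear_grad()

    # Evaluation: the cast is applied only when amp_eval (use_fp16_test) is True.
    with auto_cast(is_eval=True):
        preds = model(data)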