From 8b8a02d6079f1fb6b245ec6d334a559bbd2bc294 Mon Sep 17 00:00:00 2001
From: flytocc
Date: Thu, 28 Apr 2022 00:50:28 +0800
Subject: [PATCH] add update_freq option for gradient accumulation

---
 .../ImageNet/ConvNeXt/convnext_tiny.yaml |  4 +--
 ppcls/engine/engine.py                   |  6 +++-
 ppcls/engine/train/train.py              | 35 +++++++++++--------
 3 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/ppcls/configs/ImageNet/ConvNeXt/convnext_tiny.yaml b/ppcls/configs/ImageNet/ConvNeXt/convnext_tiny.yaml
index b8c865dd..6185edb4 100644
--- a/ppcls/configs/ImageNet/ConvNeXt/convnext_tiny.yaml
+++ b/ppcls/configs/ImageNet/ConvNeXt/convnext_tiny.yaml
@@ -15,7 +15,7 @@ Global:
   save_inference_dir: ./inference
   # training model under @to_static
   to_static: False
-
+  update_freq: 8

 # model ema
 EMA:
@@ -51,7 +51,7 @@ Optimizer:
   one_dim_param_no_weight_decay: True
   lr:
     name: Cosine
-    learning_rate: 5e-4
+    learning_rate: 4e-3 # lr 4e-3 for total_batch_size 4096
     eta_min: 1e-6
     warmup_epoch: 20
     warmup_start_lr: 0
diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index 772bf8ed..05151a1b 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -119,6 +119,9 @@ class Engine(object):
         # EMA model
         self.ema = "EMA" in self.config and self.mode == "train"

+        # gradient accumulation
+        self.update_freq = self.config["Global"].get("update_freq", 1)
+
         if "class_num" in config["Global"]:
             global_class_num = config["Global"]["class_num"]
             if "class_num" not in config["Arch"]:
@@ -229,7 +232,7 @@ class Engine(object):
         if self.mode == 'train':
             self.optimizer, self.lr_sch = build_optimizer(
                 self.config["Optimizer"], self.config["Global"]["epochs"],
-                len(self.train_dataloader),
+                len(self.train_dataloader) // self.update_freq,
                 [self.model, self.train_loss_func])

         # for amp training
@@ -312,6 +315,7 @@ class Engine(object):
         self.max_iter = len(self.train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(self.train_dataloader)

+        self.max_iter = self.max_iter // self.update_freq * self.update_freq
         for epoch_id in range(best_metric["epoch"] + 1,
                               self.config["Global"]["epochs"] + 1):
             acc = 0.0
diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py
index c46650a0..a04243e6 100644
--- a/ppcls/engine/train/train.py
+++ b/ppcls/engine/train/train.py
@@ -53,25 +53,32 @@ def train_epoch(engine, epoch_id, print_batch_step):
             out = forward(engine, batch)
             loss_dict = engine.train_loss_func(out, batch[1])

+        # loss
+        loss = loss_dict["loss"] / engine.update_freq
+
         # step opt
         if engine.amp:
-            scaled = engine.scaler.scale(loss_dict["loss"])
+            scaled = engine.scaler.scale(loss)
             scaled.backward()
-            for i in range(len(engine.optimizer)):
-                engine.scaler.minimize(engine.optimizer[i], scaled)
+            if (iter_id + 1) % engine.update_freq == 0:
+                for i in range(len(engine.optimizer)):
+                    engine.scaler.minimize(engine.optimizer[i], scaled)
         else:
-            loss_dict["loss"].backward()
+            loss.backward()
+            if (iter_id + 1) % engine.update_freq == 0:
+                for i in range(len(engine.optimizer)):
+                    engine.optimizer[i].step()
+
+        if (iter_id + 1) % engine.update_freq == 0:
+            # clear grad
             for i in range(len(engine.optimizer)):
-                engine.optimizer[i].step()
-        # clear grad
-        for i in range(len(engine.optimizer)):
-            engine.optimizer[i].clear_grad()
-        # step lr
-        for i in range(len(engine.lr_sch)):
-            engine.lr_sch[i].step()
-        # update ema
-        if engine.ema:
-            engine.model_ema.update(engine.model)
+                engine.optimizer[i].clear_grad()
+            # step lr
+            for i in range(len(engine.lr_sch)):
+                engine.lr_sch[i].step()
+            # update ema
+            if engine.ema:
+                engine.model_ema.update(engine.model)

         # below code just for logging
         # update metric_for_logger
--
GitLab
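
The pattern the patch wires into train_epoch is standard gradient accumulation: each micro-batch loss is scaled by 1 / update_freq, backward passes are allowed to accumulate gradients, and only every update_freq-th iteration applies the optimizer step, clears the gradients, and advances the LR schedule, so the effective batch size grows by update_freq (hence the 4e-3 learning rate noted for total_batch_size 4096). Below is a minimal standalone sketch of that loop, assuming plain PaddlePaddle dygraph; the toy model, random data, and hyperparameter values are illustrative placeholders, not PaddleClas code.

# Standalone sketch of the gradient-accumulation loop added by this patch.
# Assumes PaddlePaddle dygraph; model, data, and values are placeholders.
import paddle

update_freq = 8                       # accumulate gradients over 8 micro-batches
model = paddle.nn.Linear(16, 4)       # stand-in for the real backbone
optimizer = paddle.optimizer.AdamW(
    learning_rate=4e-3, parameters=model.parameters())
loss_fn = paddle.nn.CrossEntropyLoss()

for iter_id in range(32):
    x = paddle.randn([8, 16])                   # micro-batch of inputs
    y = paddle.randint(0, 4, [8])               # micro-batch of labels
    loss = loss_fn(model(x), y) / update_freq   # scale so summed grads match one large batch
    loss.backward()                             # grads keep accumulating until clear_grad()
    if (iter_id + 1) % update_freq == 0:
        optimizer.step()                        # one parameter update per effective batch
        optimizer.clear_grad()                  # reset the accumulated gradients

Dividing the loss by update_freq keeps the accumulated gradient comparable to that of one large batch, which is why the schedule and learning rate are tuned for the effective batch size rather than the per-iteration one, and why the patch also divides the per-epoch step count passed to build_optimizer by update_freq.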