From f313a6d87376778b4bbb45d17d417aff84fde2e7 Mon Sep 17 00:00:00 2001
From: tianyi1997 <93087391+tianyi1997@users.noreply.github.com>
Date: Fri, 17 Feb 2023 15:09:23 +0800
Subject: [PATCH] Support training without amp

---
 ppcls/engine/train/train_metabin.py | 26 +++++++++++++++++---------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/ppcls/engine/train/train_metabin.py b/ppcls/engine/train/train_metabin.py
index b4994b69..25186f25 100644
--- a/ppcls/engine/train/train_metabin.py
+++ b/ppcls/engine/train/train_metabin.py
@@ -98,7 +98,6 @@ def train_epoch_metabin(engine, epoch_id, print_batch_step):
                 for key, value in mtest_loss_dict.items()}
         }
         # step lr (by iter)
-        # the last lr_sch is cyclic_lr
         for i in range(len(engine.lr_sch)):
             if not getattr(engine.lr_sch[i], "by_epoch", False):
                 engine.lr_sch[i].step()
@@ -117,7 +116,6 @@ def train_epoch_metabin(engine, epoch_id, print_batch_step):
         tic = time.time()
 
     # step lr(by epoch)
-    # the last lr_sch is cyclic_lr
     for i in range(len(engine.lr_sch)):
         if getattr(engine.lr_sch[i], "by_epoch", False) and \
                 type_name(engine.lr_sch[i]) != "ReduceOnPlateau":
@@ -191,10 +189,16 @@ def get_meta_data(meta_dataloader_iter, num_domain):
 def forward(engine, batch, loss_func):
     batch_info = defaultdict()
     batch_info = {"label": batch[1], "domain": batch[2]}
-    amp_level = engine.config["AMP"].get("level", "O1").upper()
-    with paddle.amp.auto_cast(
-            custom_black_list={"flatten_contiguous_range", "greater_than"},
-            level=amp_level):
+    if engine.amp:
+        amp_level = engine.config["AMP"].get("level", "O1").upper()
+        with paddle.amp.auto_cast(
+                custom_black_list={
+                    "flatten_contiguous_range", "greater_than"
+                },
+                level=amp_level):
+            out = engine.model(batch[0], batch[1])
+            loss_dict = loss_func(out, batch_info)
+    else:
         out = engine.model(batch[0], batch[1])
         loss_dict = loss_func(out, batch_info)
     return out, loss_dict
@@ -202,9 +206,13 @@ def forward(engine, batch, loss_func):
 
 def backward(engine, loss, optimizer):
     optimizer.clear_grad()
-    scaled = engine.scaler.scale(loss)
-    scaled.backward()
-    engine.scaler.minimize(optimizer, scaled)
+    if engine.amp:
+        scaled = engine.scaler.scale(loss)
+        scaled.backward()
+        engine.scaler.minimize(optimizer, scaled)
+    else:
+        loss.backward()
+        optimizer.step()
     for name, layer in engine.model.backbone.named_sublayers():
         if "gate" == name.split('.')[-1]:
             layer.clip_gate()
--
GitLab
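
Note (not part of the patch): below is a minimal, self-contained sketch of the AMP-conditional forward/backward pattern this patch introduces, assuming a toy model, an SGD optimizer, and a use_amp flag standing in for engine.amp; the names here are illustrative and do not come from PaddleClas.

# Sketch only: toggling Paddle AMP on/off around forward and backward,
# mirroring the engine.amp branches added by the patch.
# Assumes an AMP-capable device (e.g. GPU) when use_amp is True.
import paddle

use_amp = True  # stand-in for engine.amp

model = paddle.nn.Linear(16, 4)
optimizer = paddle.optimizer.SGD(
    learning_rate=0.1, parameters=model.parameters())
scaler = paddle.amp.GradScaler() if use_amp else None
loss_fn = paddle.nn.CrossEntropyLoss()

x = paddle.randn([8, 16])
label = paddle.randint(0, 4, [8])

# forward: only enter auto_cast when AMP is enabled
if use_amp:
    with paddle.amp.auto_cast(level="O1"):
        out = model(x)
        loss = loss_fn(out, label)
else:
    out = model(x)
    loss = loss_fn(out, label)

# backward: scale the loss only when AMP is enabled,
# otherwise take a plain backward/step
optimizer.clear_grad()
if use_amp:
    scaled = scaler.scale(loss)
    scaled.backward()
    scaler.minimize(optimizer, scaled)
else:
    loss.backward()
    optimizer.step()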