From 205592a3e3453d8636feae0dd0658eaf85092c34 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 7 Jan 2022 02:10:46 +0000 Subject: [PATCH] fix amp with distribute bug --- ppcls/engine/engine.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index cbd70a49..e10be2f2 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -211,6 +211,14 @@ class Engine(object): self.optimizer, self.lr_sch = build_optimizer( self.config["Optimizer"], self.config["Global"]["epochs"], len(self.train_dataloader), [self.model]) + + # for amp training + if self.amp: + self.scaler = paddle.amp.GradScaler( + init_loss_scaling=self.scale_loss, + use_dynamic_loss_scaling=self.use_dynamic_loss_scaling) + if self.config['AMP']['use_pure_fp16'] is True: + self.model = paddle.amp.decorate(models=self.model, level='O2') # for distributed self.config["Global"][ -- GitLab