diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index cbd70a499d3ee6e63a0959f4712d9640e56f0859..e10be2f29fe794b2a9e425fe830f78e5dcfad0d8 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -211,6 +211,14 @@ class Engine(object): self.optimizer, self.lr_sch = build_optimizer( self.config["Optimizer"], self.config["Global"]["epochs"], len(self.train_dataloader), [self.model]) + + # for amp training + if self.amp: + self.scaler = paddle.amp.GradScaler( + init_loss_scaling=self.scale_loss, + use_dynamic_loss_scaling=self.use_dynamic_loss_scaling) + if self.config['AMP']['use_pure_fp16'] is True: + self.model = paddle.amp.decorate(models=self.model, level='O2') # for distributed self.config["Global"][