diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 7a7bbde6b4bd431275f8e5def2190d3061c5ec6f..b36aeb70cf5ceb1917e50a7c51d4abcc9c8d1a65 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -98,8 +98,8 @@ class Engine(object): logger.info('train with paddle {} and device {}'.format( paddle.__version__, self.device)) - # AMP training - self.amp = True if "AMP" in self.config and self.mode == "train" else False + # AMP training and evaluating + self.amp = "AMP" in self.config if self.amp and self.config["AMP"] is not None: self.scale_loss = self.config["AMP"].get("scale_loss", 1.0) self.use_dynamic_loss_scaling = self.config["AMP"].get(