diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index b36aeb70cf5ceb1917e50a7c51d4abcc9c8d1a65..f3b30c9d682907375971005fb4d88cb93fec4d06 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -99,8 +99,8 @@ class Engine(object): paddle.__version__, self.device)) # AMP training and evaluating - self.amp = "AMP" in self.config - if self.amp and self.config["AMP"] is not None: + self.amp = "AMP" in self.config and self.config["AMP"] is not None + if self.amp: self.scale_loss = self.config["AMP"].get("scale_loss", 1.0) self.use_dynamic_loss_scaling = self.config["AMP"].get( "use_dynamic_loss_scaling", False) @@ -228,7 +228,7 @@ class Engine(object): len(self.train_dataloader), [self.model, self.train_loss_func]) - # for amp training + # for amp if self.amp: self.scaler = paddle.amp.GradScaler( init_loss_scaling=self.scale_loss, @@ -239,12 +239,13 @@ class Engine(object): logger.warning(msg) self.config['AMP']["level"] = "O1" amp_level = "O1" - self.model, self.optimizer = paddle.amp.decorate( - models=self.model, - optimizers=self.optimizer, - level=amp_level, - save_dtype='float32') - if len(self.train_loss_func.parameters()) > 0: + self.model = paddle.amp.decorate( + models=self.model, level=amp_level, save_dtype='float32') + # TODO(gaotingquan): to compatible with Paddle develop and 2.2 + if isinstance(self.model, tuple): + self.model = self.model[0] + if self.mode == "train" and len(self.train_loss_func.parameters( + )) > 0: self.train_loss_func = paddle.amp.decorate( models=self.train_loss_func, level=amp_level, diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index 6e7fc1a76fe8c3bc4402d9428d372b9c2b50a17b..5f8407e5d90d4ebc72d595781a2fef9b4b81271f 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -32,6 +32,15 @@ def classification_eval(engine, epoch_id=0): } print_batch_step = engine.config["Global"]["print_batch_step"] + if engine.amp: + amp_level = engine.config['AMP'].get("level", "O1").upper() + if amp_level == "O2" and engine.config["AMP"].get("use_fp16_test", + False): + engine.config["AMP"]["use_fp16_test"] = True + msg = "Only support FP16 evaluation when AMP O2 is enabled." + logger.warning(msg) + amp_eval = engine.config["AMP"].get("use_fp16_test", False) + metric_key = None tic = time.time() accum_samples = 0 @@ -58,15 +67,7 @@ def classification_eval(engine, epoch_id=0): batch[1] = batch[1].reshape([-1, 1]).astype("int64") # image input - if engine.amp and ( - engine.config['AMP'].get("level", "O1").upper() == "O2" or - engine.config["AMP"].get("use_fp16_test", False)): - amp_level = engine.config['AMP'].get("level", "O1").upper() - - if amp_level == "O2": - msg = "Only support FP16 evaluation when AMP O2 is enabled." - logger.warning(msg) - + if engine.amp and amp_eval: with paddle.amp.auto_cast( custom_black_list={ "flatten_contiguous_range", "greater_than" @@ -119,8 +120,7 @@ def classification_eval(engine, epoch_id=0): # calc loss if engine.eval_loss_func is not None: - if engine.amp and engine.config["AMP"].get("use_fp16_test", False): - amp_level = engine.config['AMP'].get("level", "O1").upper() + if engine.amp and amp_eval: with paddle.amp.auto_cast( custom_black_list={ "flatten_contiguous_range", "greater_than"