diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index 7058ec495e004ddeb85240da272ec349542eeb24..e79d3bd11dba88e092e7ca094fc091f153f5974a 100755
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -274,33 +274,17 @@ class Engine(object):
                 self.config["AMP"]["use_fp16_test"] = True
                 self.amp_eval = True
 
-            # TODO(gaotingquan): to compatible with different versions of Paddle
-            paddle_version = paddle.__version__[:3]
-            # paddle version < 2.3.0 and not develop
-            if paddle_version not in ["2.3", "0.0"]:
-                if self.mode == "train":
-                    self.model, self.optimizer = paddle.amp.decorate(
-                        models=self.model,
-                        optimizers=self.optimizer,
-                        level=self.amp_level,
-                        save_dtype='float32')
-                elif self.amp_eval:
-                    if self.amp_level == "O2":
-                        msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
-                        logger.warning(msg)
-                        self.amp_eval = False
-                    else:
-                        self.model, self.optimizer = paddle.amp.decorate(
-                            models=self.model,
-                            level=self.amp_level,
-                            save_dtype='float32')
-            # paddle version >= 2.3.0 or develop
-            else:
-                if self.mode == "train" or self.amp_eval:
-                    self.model = paddle.amp.decorate(
-                        models=self.model,
-                        level=self.amp_level,
-                        save_dtype='float32')
+            if self.mode == "train":
+                self.model, self.optimizer = paddle.amp.decorate(
+                    models=self.model,
+                    optimizers=self.optimizer,
+                    level=self.amp_level,
+                    save_dtype='float32')
+            elif self.amp_eval:
+                self.model = paddle.amp.decorate(
+                    models=self.model,
+                    level=self.amp_level,
+                    save_dtype='float32')
 
         if self.mode == "train" and len(self.train_loss_func.parameters(
         )) > 0:
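
Note (not part of the patch): a minimal standalone sketch of the two paddle.amp.decorate call shapes the new branches keep, assuming Paddle >= 2.3. The toy Linear model, SGD optimizer, and "O2" level are illustrative placeholders, not values taken from this diff.

# Minimal sketch; assumes Paddle >= 2.3. Model/optimizer are toy stand-ins.
import paddle

model = paddle.nn.Linear(8, 2)
optimizer = paddle.optimizer.SGD(parameters=model.parameters())

# Train mode: decorate model and optimizer together (the "train" branch).
model, optimizer = paddle.amp.decorate(
    models=model,
    optimizers=optimizer,
    level="O2",
    save_dtype="float32")

# Eval-only mode: with no optimizers argument, decorate returns just the
# decorated model (the amp_eval branch).
eval_model = paddle.amp.decorate(
    models=paddle.nn.Linear(8, 2),
    level="O2",
    save_dtype="float32")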