diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 460856e3187285377b1b4ddd47d3618dab0e7dc1..5b5c4da8a6500ab90c31f33097075db5f8ee5f89 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -247,13 +247,10 @@ class Engine(object): self.config["AMP"]["use_fp16_test"] = True self.amp_eval = True - # TODO(gaotingquan): to compatible with Paddle 2.2, 2.3, develop and so on. - paddle_version = sum([ - int(x) * 10**(2 - i) - for i, x in enumerate(paddle.__version__.split(".")[:3]) - ]) + # TODO(gaotingquan): to compatible with different versions of Paddle + paddle_version = paddle.__version__[:3] # paddle version < 2.3.0 and not develop - if paddle_version < 230 and paddle_version != 0: + if paddle_version not in ["2.3", "0.0"]: if self.mode == "train": self.model, self.optimizer = paddle.amp.decorate( models=self.model,