Commit a7ba6eab authored by gaotingquan, committed by cuicheng01

optimizer must be decorated when training with AMP O2

Parent dc2c8528
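The diff below drops the Paddle-version branching and always calls paddle.amp.decorate, passing the optimizer together with the model whenever training under AMP. As background, here is a minimal standalone sketch of that pattern; the Linear model, SGD optimizer, and the mode/amp_level variables are illustrative stand-ins, not code from this repo. Under O2, decorate casts the model's parameters to float16, so the optimizer must be decorated alongside the model to keep float32 master weights consistent with the casted parameters.

import paddle

# Illustrative stand-ins; the real Engine uses its configured model/optimizer.
model = paddle.nn.Linear(10, 2)
optimizer = paddle.optimizer.SGD(
    learning_rate=0.01, parameters=model.parameters())

mode = "train"      # or "eval"
amp_level = "O2"    # pure fp16 compute with fp32 master weights

if mode == "train":
    # O2 casts model parameters to float16; decorating the optimizer
    # together with the model keeps float32 master weights in sync.
    model, optimizer = paddle.amp.decorate(
        models=model,
        optimizers=optimizer,
        level=amp_level,
        save_dtype='float32')
else:
    # Evaluation only: there is no optimizer to decorate.
    model = paddle.amp.decorate(
        models=model, level=amp_level, save_dtype='float32')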
@@ -274,33 +274,17 @@ class Engine(object):
                 self.config["AMP"]["use_fp16_test"] = True
                 self.amp_eval = True
-            # TODO(gaotingquan): to compatible with different versions of Paddle
-            paddle_version = paddle.__version__[:3]
-            # paddle version < 2.3.0 and not develop
-            if paddle_version not in ["2.3", "0.0"]:
-                if self.mode == "train":
-                    self.model, self.optimizer = paddle.amp.decorate(
-                        models=self.model,
-                        optimizers=self.optimizer,
-                        level=self.amp_level,
-                        save_dtype='float32')
-                elif self.amp_eval:
-                    if self.amp_level == "O2":
-                        msg = "The PaddlePaddle that installed not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. Use FP32 evaluation instead and please notice the Eval Dataset output_fp16 should be 'False'."
-                        logger.warning(msg)
-                        self.amp_eval = False
-                    else:
-                        self.model, self.optimizer = paddle.amp.decorate(
-                            models=self.model,
-                            level=self.amp_level,
-                            save_dtype='float32')
-            # paddle version >= 2.3.0 or develop
-            else:
-                if self.mode == "train" or self.amp_eval:
-                    self.model = paddle.amp.decorate(
-                        models=self.model,
-                        level=self.amp_level,
-                        save_dtype='float32')
+            if self.mode == "train":
+                self.model, self.optimizer = paddle.amp.decorate(
+                    models=self.model,
+                    optimizers=self.optimizer,
+                    level=self.amp_level,
+                    save_dtype='float32')
+            elif self.amp_eval:
+                self.model = paddle.amp.decorate(
+                    models=self.model,
+                    level=self.amp_level,
+                    save_dtype='float32')

             if self.mode == "train" and len(self.train_loss_func.parameters(
             )) > 0:
......
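For context, a typical training step after this setup runs the forward pass under paddle.amp.auto_cast at the same level and uses paddle.amp.GradScaler for loss scaling. This is a hedged sketch continuing the stand-ins above, not code from this commit, and it assumes a GPU device where float16 compute is supported.

# Continues the illustrative sketch above; assumes model/optimizer were
# decorated with level="O2" as in the diff.
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.randn([4, 10])
with paddle.amp.auto_cast(level=amp_level):
    loss = model(x).mean()

scaled = scaler.scale(loss)   # scale the loss to avoid fp16 gradient underflow
scaled.backward()
scaler.step(optimizer)        # unscales gradients, then applies the update
scaler.update()               # adjust the loss scale for the next step
optimizer.clear_grad()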