提交 10c93c55 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: enable amp only in training

上级 7040ce83
......@@ -20,6 +20,7 @@ Arch:
name: SE_ResNeXt101_32x4d
class_num: 1000
input_image_channel: *image_channel
data_format: "NHWC"
# loss function config for traing/eval process
Loss:
......
......@@ -97,7 +97,7 @@ class Engine(object):
paddle.__version__, self.device))
# AMP training
self.amp = True if "AMP" in self.config else False
self.amp = True if "AMP" in self.config and self.mode == "train" else False
if self.amp and self.config["AMP"] is not None:
self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
self.use_dynamic_loss_scaling = self.config["AMP"].get(
......@@ -223,8 +223,11 @@ class Engine(object):
logger.warning(msg)
self.config['AMP']["level"] = "O1"
amp_level = "O1"
self.model = paddle.amp.decorate(
models=self.model, level=amp_level, save_dtype='float32')
self.model, self.optimizer = paddle.amp.decorate(
models=self.model,
optimizers=self.optimizer,
level=amp_level,
save_dtype='float32')
# for distributed
self.config["Global"][
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册