提交 10c93c55 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: enable amp only in training

上级 7040ce83
...@@ -20,6 +20,7 @@ Arch: ...@@ -20,6 +20,7 @@ Arch:
name: SE_ResNeXt101_32x4d name: SE_ResNeXt101_32x4d
class_num: 1000 class_num: 1000
input_image_channel: *image_channel input_image_channel: *image_channel
data_format: "NHWC"
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
......
...@@ -97,7 +97,7 @@ class Engine(object): ...@@ -97,7 +97,7 @@ class Engine(object):
paddle.__version__, self.device)) paddle.__version__, self.device))
# AMP training # AMP training
self.amp = True if "AMP" in self.config else False self.amp = True if "AMP" in self.config and self.mode == "train" else False
if self.amp and self.config["AMP"] is not None: if self.amp and self.config["AMP"] is not None:
self.scale_loss = self.config["AMP"].get("scale_loss", 1.0) self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
self.use_dynamic_loss_scaling = self.config["AMP"].get( self.use_dynamic_loss_scaling = self.config["AMP"].get(
...@@ -223,8 +223,11 @@ class Engine(object): ...@@ -223,8 +223,11 @@ class Engine(object):
logger.warning(msg) logger.warning(msg)
self.config['AMP']["level"] = "O1" self.config['AMP']["level"] = "O1"
amp_level = "O1" amp_level = "O1"
self.model = paddle.amp.decorate( self.model, self.optimizer = paddle.amp.decorate(
models=self.model, level=amp_level, save_dtype='float32') models=self.model,
optimizers=self.optimizer,
level=amp_level,
save_dtype='float32')
# for distributed # for distributed
self.config["Global"][ self.config["Global"][
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册