diff --git a/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d_amp_O2.yaml b/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d_amp_O2.yaml
index 1824cc1fdf99295d077fe382ba5ed7cd101fc108..da005d32911fdb942aa600b39937b660642f27f7 100644
--- a/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d_amp_O2.yaml
+++ b/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d_amp_O2.yaml
@@ -20,6 +20,7 @@ Arch:
   name: SE_ResNeXt101_32x4d
   class_num: 1000
   input_image_channel: *image_channel
+  data_format: "NHWC"
 
 # loss function config for traing/eval process
 Loss:
diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index 0a78fdd1a9c84c3ddd1da919b9019a6bff53a793..a58ac1e936c5a0d01366d09fc44ee08a1eebf346 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -97,7 +97,7 @@ class Engine(object):
             paddle.__version__, self.device))
 
         # AMP training
-        self.amp = True if "AMP" in self.config else False
+        self.amp = True if "AMP" in self.config and self.mode == "train" else False
         if self.amp and self.config["AMP"] is not None:
             self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
             self.use_dynamic_loss_scaling = self.config["AMP"].get(
@@ -223,8 +223,11 @@ class Engine(object):
                 logger.warning(msg)
                 self.config['AMP']["level"] = "O1"
                 amp_level = "O1"
-            self.model = paddle.amp.decorate(
-                models=self.model, level=amp_level, save_dtype='float32')
+            self.model, self.optimizer = paddle.amp.decorate(
+                models=self.model,
+                optimizers=self.optimizer,
+                level=amp_level,
+                save_dtype='float32')
 
         # for distributed
         self.config["Global"][
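
A note on the `paddle.amp.decorate` change above: under AMP level O2, Paddle casts the model parameters to float16, so the optimizer must be decorated together with the model or it would keep updating the original float32 parameter copies. The sketch below illustrates the decorated training step; it is a minimal, self-contained example, not the actual engine code — the `resnet18` stand-in model, the `Momentum` hyperparameters, and the loss-scaling value are assumptions for illustration, and it presumes the Paddle 2.2+ dygraph AMP API on a CUDA device.

```python
import paddle

# Stand-in model/optimizer (assumptions for illustration; the real engine
# builds SE_ResNeXt101_32x4d and the optimizer from the YAML config).
model = paddle.vision.models.resnet18()
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.1, parameters=model.parameters())

# As in the patched engine.py: decorate BOTH model and optimizer. Under O2
# the parameters are cast to float16, and passing the optimizer rebinds it
# to the cast parameters; save_dtype='float32' keeps checkpoints in fp32.
model, optimizer = paddle.amp.decorate(
    models=model,
    optimizers=optimizer,
    level='O2',
    save_dtype='float32')

scaler = paddle.amp.GradScaler(init_loss_scaling=1024.0)

x = paddle.rand([2, 3, 224, 224])
with paddle.amp.auto_cast(level='O2'):
    loss = model(x).mean()

scaled = scaler.scale(loss)       # scale loss to avoid fp16 gradient underflow
scaled.backward()
scaler.minimize(optimizer, scaled)  # unscale gradients, then optimizer step
optimizer.clear_grad()
```

The `self.mode == "train"` guard in the first hunk has a related intent: AMP decoration and loss scaling only matter for the backward pass, so evaluation and export runs now skip AMP setup entirely. The `data_format: "NHWC"` addition in the YAML config pairs with this, since channels-last layout is the one Tensor Cores prefer for fp16 convolutions.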