From 10c93c55d1a8605833e7ee828f6a2ebb27cca5af Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Tue, 11 Jan 2022 14:07:09 +0000 Subject: [PATCH] fix: enable amp only in training --- .../ImageNet/SENet/SE_ResNeXt101_32x4d_amp_O2.yaml | 1 + ppcls/engine/engine.py | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d_amp_O2.yaml b/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d_amp_O2.yaml index 1824cc1f..da005d32 100644 --- a/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d_amp_O2.yaml +++ b/ppcls/configs/ImageNet/SENet/SE_ResNeXt101_32x4d_amp_O2.yaml @@ -20,6 +20,7 @@ Arch: name: SE_ResNeXt101_32x4d class_num: 1000 input_image_channel: *image_channel + data_format: "NHWC" # loss function config for traing/eval process Loss: diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 0a78fdd1..a58ac1e9 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -97,7 +97,7 @@ class Engine(object): paddle.__version__, self.device)) # AMP training - self.amp = True if "AMP" in self.config else False + self.amp = True if "AMP" in self.config and self.mode == "train" else False if self.amp and self.config["AMP"] is not None: self.scale_loss = self.config["AMP"].get("scale_loss", 1.0) self.use_dynamic_loss_scaling = self.config["AMP"].get( @@ -223,8 +223,11 @@ class Engine(object): logger.warning(msg) self.config['AMP']["level"] = "O1" amp_level = "O1" - self.model = paddle.amp.decorate( - models=self.model, level=amp_level, save_dtype='float32') + self.model, self.optimizer = paddle.amp.decorate( + models=self.model, + optimizers=self.optimizer, + level=amp_level, + save_dtype='float32') # for distributed self.config["Global"][ -- GitLab