diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index aacde2f76187e2a9680df4409ca04ec5b0303165..c5164002bf519120880652fbad5dcf10b5e6f33e 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -98,23 +98,6 @@ class Engine(object):
         logger.info('train with paddle {} and device {}'.format(
             paddle.__version__, self.device))
 
-        # AMP training and evaluating
-        self.amp = "AMP" in self.config
-        if self.amp and self.config["AMP"] is not None:
-            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
-            self.use_dynamic_loss_scaling = self.config["AMP"].get(
-                "use_dynamic_loss_scaling", False)
-        else:
-            self.scale_loss = 1.0
-            self.use_dynamic_loss_scaling = False
-        if self.amp:
-            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
-            if paddle.is_compiled_with_cuda():
-                AMP_RELATED_FLAGS_SETTING.update({
-                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
-                })
-            paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
-
         if "class_num" in config["Global"]:
             global_class_num = config["Global"]["class_num"]
             if "class_num" not in config["Arch"]:
@@ -228,26 +211,77 @@ class Engine(object):
                 len(self.train_dataloader),
                 [self.model, self.train_loss_func])
 
-        # for amp training
+        # AMP training and evaluating
+        self.amp = "AMP" in self.config and self.config["AMP"] is not None
+        self.amp_eval = False
+        # for amp
         if self.amp:
+            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
+            if paddle.is_compiled_with_cuda():
+                AMP_RELATED_FLAGS_SETTING.update({
+                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
+                })
+            paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
+
+            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
+            self.use_dynamic_loss_scaling = self.config["AMP"].get(
+                "use_dynamic_loss_scaling", False)
             self.scaler = paddle.amp.GradScaler(
                 init_loss_scaling=self.scale_loss,
                 use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
-            amp_level = self.config['AMP'].get("level", "O1")
-            if amp_level not in ["O1", "O2"]:
+
+            self.amp_level = self.config['AMP'].get("level", "O1")
+            if self.amp_level not in ["O1", "O2"]:
                 msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
                 logger.warning(msg)
                 self.config['AMP']["level"] = "O1"
-                amp_level = "O1"
-            self.model, self.optimizer = paddle.amp.decorate(
-                models=self.model,
-                optimizers=self.optimizer,
-                level=amp_level,
-                save_dtype='float32')
-            if len(self.train_loss_func.parameters()) > 0:
+                self.amp_level = "O1"
+
+            self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
+            # TODO(gaotingquan): Paddle does not yet support FP32 evaluation when training with AMP O2
+            if self.config["Global"].get(
+                    "eval_during_train",
+                    True) and self.amp_level == "O2" and self.amp_eval == False:
+                msg = "PaddlePaddle only supports FP16 evaluation when training with AMP O2 now."
+                logger.warning(msg)
+                self.config["AMP"]["use_fp16_test"] = True
+                self.amp_eval = True
+
+            # TODO(gaotingquan): to be compatible with Paddle 2.2, 2.3, develop and so on.
+            paddle_version = sum([
+                int(x) * 10**(2 - i)
+                for i, x in enumerate(paddle.__version__.split(".")[:3])
+            ])
+            # paddle version < 2.3.0 and not develop
+            if paddle_version < 230 and paddle_version != 0:
+                if self.mode == "train":
+                    self.model, self.optimizer = paddle.amp.decorate(
+                        models=self.model,
+                        optimizers=self.optimizer,
+                        level=self.amp_level,
+                        save_dtype='float32')
+                elif self.amp_eval:
+                    if self.amp_level == "O2":
+                        msg = "The installed PaddlePaddle does not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. FP32 evaluation will be used instead; please note that output_fp16 of the Eval Dataset should be 'False'."
+                        logger.warning(msg)
+                        self.amp_eval = False
+                    else:
+                        self.model, self.optimizer = paddle.amp.decorate(
+                            models=self.model,
+                            level=self.amp_level,
+                            save_dtype='float32')
+            # paddle version >= 2.3.0 or develop
+            else:
+                self.model = paddle.amp.decorate(
+                    models=self.model,
+                    level=self.amp_level,
+                    save_dtype='float32')
+
+            if self.mode == "train" and len(self.train_loss_func.parameters(
+            )) > 0:
                 self.train_loss_func = paddle.amp.decorate(
                     models=self.train_loss_func,
-                    level=amp_level,
+                    level=self.amp_level,
                     save_dtype='float32')
 
         # check the gpu num
diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index f4c90a393f5043575c5e49f16fd5b220c881e0fc..60595e6a9014b4003ab8008b8144d92d628a2acd 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -58,20 +58,12 @@ def classification_eval(engine, epoch_id=0):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")
 
         # image input
-        if engine.amp and (
-                engine.config['AMP'].get("level", "O1").upper() == "O2" or
-                engine.config["AMP"].get("use_fp16_test", False)):
-            amp_level = engine.config['AMP'].get("level", "O1").upper()
-
-            if amp_level == "O2":
-                msg = "Only support FP16 evaluation when AMP O2 is enabled."
-                logger.warning(msg)
-
+        if engine.amp and engine.amp_eval:
             with paddle.amp.auto_cast(
                     custom_black_list={
                         "flatten_contiguous_range", "greater_than"
                     },
-                    level=amp_level):
+                    level=engine.amp_level):
                 out = engine.model(batch[0])
         else:
             out = engine.model(batch[0])
@@ -114,13 +106,12 @@ def classification_eval(engine, epoch_id=0):
 
         # calc loss
         if engine.eval_loss_func is not None:
-            if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
-                amp_level = engine.config['AMP'].get("level", "O1").upper()
+            if engine.amp and engine.amp_eval:
                 with paddle.amp.auto_cast(
                         custom_black_list={
                             "flatten_contiguous_range", "greater_than"
                         },
-                        level=amp_level):
+                        level=engine.amp_level):
                     loss_dict = engine.eval_loss_func(preds, labels)
             else:
                 loss_dict = engine.eval_loss_func(preds, labels)
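Note on the version gate introduced in engine.py above (an illustrative sketch, not part of the patch; the helper name parse_paddle_version is made up):

    # sum(int(x) * 10**(2 - i)) maps "2.2.2" -> 222 and "2.3.0" -> 230,
    # while develop/nightly wheels report "0.0.0" -> 0, which is why the
    # branch checks `paddle_version < 230 and paddle_version != 0`.
    def parse_paddle_version(version: str) -> int:
        return sum(
            int(x) * 10**(2 - i) for i, x in enumerate(version.split(".")[:3]))

    assert parse_paddle_version("2.2.2") == 222  # pre-2.3 release wheel
    assert parse_paddle_version("2.3.0") == 230  # 2.3 release wheel
    assert parse_paddle_version("0.0.0") == 0    # develop build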