diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index f3b30c9d682907375971005fb4d88cb93fec4d06..c5ef7a35a90c60545c308943a7b3e1c5f6accb96 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -98,23 +98,6 @@ class Engine(object):
         logger.info('train with paddle {} and device {}'.format(
             paddle.__version__, self.device))
 
-        # AMP training and evaluating
-        self.amp = "AMP" in self.config and self.config["AMP"] is not None
-        if self.amp:
-            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
-            self.use_dynamic_loss_scaling = self.config["AMP"].get(
-                "use_dynamic_loss_scaling", False)
-        else:
-            self.scale_loss = 1.0
-            self.use_dynamic_loss_scaling = False
-        if self.amp:
-            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
-            if paddle.is_compiled_with_cuda():
-                AMP_RELATED_FLAGS_SETTING.update({
-                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
-                })
-            paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
-
         if "class_num" in config["Global"]:
             global_class_num = config["Global"]["class_num"]
             if "class_num" not in config["Arch"]:
@@ -228,27 +211,77 @@ class Engine(object):
                 len(self.train_dataloader),
                 [self.model, self.train_loss_func])
 
+        # AMP training and evaluating
+        self.amp = "AMP" in self.config and self.config["AMP"] is not None
+        self.amp_eval = False
         # for amp
         if self.amp:
+            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
+            if paddle.is_compiled_with_cuda():
+                AMP_RELATED_FLAGS_SETTING.update({
+                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
+                })
+            paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
+
+            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
+            self.use_dynamic_loss_scaling = self.config["AMP"].get(
+                "use_dynamic_loss_scaling", False)
             self.scaler = paddle.amp.GradScaler(
                 init_loss_scaling=self.scale_loss,
                 use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
-            amp_level = self.config['AMP'].get("level", "O1")
-            if amp_level not in ["O1", "O2"]:
+
+            self.amp_level = self.config['AMP'].get("level", "O1")
+            if self.amp_level not in ["O1", "O2"]:
                 msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
                 logger.warning(msg)
                 self.config['AMP']["level"] = "O1"
-                amp_level = "O1"
-            self.model = paddle.amp.decorate(
-                models=self.model, level=amp_level, save_dtype='float32')
-            # TODO(gaotingquan): to compatible with Paddle develop and 2.2
-            if isinstance(self.model, tuple):
-                self.model = self.model[0]
+                self.amp_level = "O1"
+
+            self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
+            # TODO(gaotingquan): Paddle does not yet support FP32 evaluation when training with AMP O2
+            if self.config["Global"].get(
+                    "eval_during_train",
+                    True) and self.amp_level == "O2" and self.amp_eval == False:
+                msg = "PaddlePaddle only supports FP16 evaluation when training with AMP O2 now."
+                logger.warning(msg)
+                self.config["AMP"]["use_fp16_test"] = True
+                self.amp_eval = True
+
+            # TODO(gaotingquan): to be compatible with Paddle 2.2, 2.3, develop and so on.
+            paddle_version = sum([
+                int(x) * 10**(2 - i)
+                for i, x in enumerate(paddle.__version__.split(".")[:3])
+            ])
+            # paddle version < 2.3.0 and not develop
+            if paddle_version < 230 and paddle_version != 0:
+                if self.mode == "train":
+                    self.model, self.optimizer = paddle.amp.decorate(
+                        models=self.model,
+                        optimizers=self.optimizer,
+                        level=self.amp_level,
+                        save_dtype='float32')
+                elif self.amp_eval:
+                    if self.amp_level == "O2":
+                        msg = "The installed PaddlePaddle does not support FP16 evaluation in AMP O2. Please use PaddlePaddle version >= 2.3.0. FP32 evaluation will be used instead; please make sure the Eval Dataset output_fp16 is 'False'."
+                        logger.warning(msg)
+                        self.amp_eval = False
+                    else:
+                        self.model, self.optimizer = paddle.amp.decorate(
+                            models=self.model,
+                            level=self.amp_level,
+                            save_dtype='float32')
+            # paddle version >= 2.3.0 or develop
+            else:
+                self.model = paddle.amp.decorate(
+                    models=self.model,
+                    level=self.amp_level,
+                    save_dtype='float32')
+
             if self.mode == "train" and len(self.train_loss_func.parameters(
             )) > 0:
                 self.train_loss_func = paddle.amp.decorate(
                     models=self.train_loss_func,
-                    level=amp_level,
+                    level=self.amp_level,
                     save_dtype='float32')
 
         # for distributed
diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index 5f8407e5d90d4ebc72d595781a2fef9b4b81271f..bbd7632d307a29bdba4e02e790099c2d82093d88 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -32,15 +32,6 @@ def classification_eval(engine, epoch_id=0):
     }
     print_batch_step = engine.config["Global"]["print_batch_step"]
 
-    if engine.amp:
-        amp_level = engine.config['AMP'].get("level", "O1").upper()
-        if amp_level == "O2" and engine.config["AMP"].get("use_fp16_test",
-                                                          False):
-            engine.config["AMP"]["use_fp16_test"] = True
-            msg = "Only support FP16 evaluation when AMP O2 is enabled."
-            logger.warning(msg)
-        amp_eval = engine.config["AMP"].get("use_fp16_test", False)
-
     metric_key = None
     tic = time.time()
     accum_samples = 0
@@ -67,12 +58,12 @@ def classification_eval(engine, epoch_id=0):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")
 
         # image input
-        if engine.amp and amp_eval:
+        if engine.amp and engine.amp_eval:
            with paddle.amp.auto_cast(
                    custom_black_list={
                        "flatten_contiguous_range", "greater_than"
                    },
-                    level=amp_level):
+                    level=engine.amp_level):
                out = engine.model(batch[0])
        else:
            out = engine.model(batch[0])
@@ -120,12 +111,12 @@ def classification_eval(engine, epoch_id=0):
 
        # calc loss
        if engine.eval_loss_func is not None:
-            if engine.amp and engine.amp_eval:
+            if engine.amp and engine.amp_eval:
                with paddle.amp.auto_cast(
                        custom_black_list={
                            "flatten_contiguous_range", "greater_than"
                        },
-                        level=amp_level):
+                        level=engine.amp_level):
                    loss_dict = engine.eval_loss_func(preds, labels)
            else:
                loss_dict = engine.eval_loss_func(preds, labels)
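
Note (reviewer sketch, not part of the patch): the version gate added in engine.py collapses the first three components of paddle.__version__ into a single integer, so "2.2.2" becomes 222, "2.3.0" becomes 230, and the "0.0.0" reported by develop builds becomes 0 (hence the extra paddle_version != 0 check). The helper below is hypothetical (parse_paddle_version is not a name used in the patch) and only reproduces that arithmetic for clarity:

    # Hypothetical helper mirroring the paddle_version arithmetic added in engine.py.
    def parse_paddle_version(version_str):
        # "2.2.2" -> 222, "2.3.0" -> 230, "0.0.0" (develop build) -> 0
        return sum(
            int(x) * 10**(2 - i)
            for i, x in enumerate(version_str.split(".")[:3]))

    assert parse_paddle_version("2.2.2") == 222  # pre-2.3.0: old decorate path
    assert parse_paddle_version("2.3.0") == 230  # 2.3.0 and later: new path
    assert parse_paddle_version("0.0.0") == 0    # develop build: also takes the new path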