From b54ee04491d081d98efea0c30737b0563f433ad7 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Mon, 20 Dec 2021 06:36:56 +0000
Subject: [PATCH] Accelerate dynamic graph amp training

---
 ppcls/engine/engine.py       |  2 ++
 ppcls/engine/train/train.py  | 13 +++++++------
 ppcls/optimizer/optimizer.py |  5 ++++-
 3 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index fe069b1d..fca3a82b 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -250,6 +250,8 @@ class Engine(object):
             self.scaler = paddle.amp.GradScaler(
                 init_loss_scaling=self.scale_loss,
                 use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
+            if self.config['AMP']['use_pure_fp16'] is True:
+                self.model = paddle.amp.decorate(models=self.model, level='O2')
 
         self.max_iter = len(self.train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(self.train_dataloader)
diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py
index cbf868e4..d8f425dc 100644
--- a/ppcls/engine/train/train.py
+++ b/ppcls/engine/train/train.py
@@ -41,14 +41,15 @@ def train_epoch(engine, epoch_id, print_batch_step):
 
         # image input
         if engine.amp:
-            with paddle.amp.auto_cast(custom_black_list={
-                    "flatten_contiguous_range", "greater_than"
-            }):
+            amp_level = 'O1'
+            if engine.config['AMP']['use_pure_fp16'] is True:
+                amp_level = 'O2'
+            with paddle.amp.auto_cast(custom_black_list={"flatten_contiguous_range", "greater_than"}, level=amp_level):
                 out = forward(engine, batch)
+                loss_dict = engine.train_loss_func(out, batch[1])
         else:
             out = forward(engine, batch)
-
-        loss_dict = engine.train_loss_func(out, batch[1])
+            loss_dict = engine.train_loss_func(out, batch[1])
 
         # step opt and lr
         if engine.amp:
@@ -58,7 +59,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
         else:
             loss_dict["loss"].backward()
             engine.optimizer.step()
-        engine.optimizer.clear_grad()
+        engine.optimizer.clear_grad(set_to_zero=True)
         engine.lr_sch.step()
 
         # below code just for logging
diff --git a/ppcls/optimizer/optimizer.py b/ppcls/optimizer/optimizer.py
index f429755f..290632d0 100644
--- a/ppcls/optimizer/optimizer.py
+++ b/ppcls/optimizer/optimizer.py
@@ -36,13 +36,15 @@ class Momentum(object):
                  momentum,
                  weight_decay=None,
                  grad_clip=None,
-                 multi_precision=False):
+                 multi_precision=True,
+                 use_multi_tensor=True):
         super().__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
         self.multi_precision = multi_precision
+        self.use_multi_tensor = use_multi_tensor
 
     def __call__(self, model_list):
         # model_list is None in static graph
@@ -54,6 +56,7 @@ class Momentum(object):
             weight_decay=self.weight_decay,
             grad_clip=self.grad_clip,
             multi_precision=self.multi_precision,
+            use_multi_tensor=self.use_multi_tensor,
             parameters=parameters)
         return opt
 
--
GitLab
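
Note (not part of the patch): a minimal, self-contained sketch of how the pieces this patch wires up (paddle.amp.decorate, auto_cast with level='O2', GradScaler, a multi_precision Momentum optimizer, and clear_grad(set_to_zero=True)) fit together in a single pure-fp16 training step. The model, input shapes, and hyperparameters below are placeholders chosen for illustration and are not taken from PaddleClas.

    import paddle

    # placeholder model and optimizer; PaddleClas builds these from its config
    model = paddle.vision.models.resnet50()
    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.1,
        momentum=0.9,
        multi_precision=True,  # keep fp32 master weights alongside fp16 params
        parameters=model.parameters())

    # O2 ("pure fp16"): cast the model's parameters to fp16 once up front,
    # which is what the patch does via paddle.amp.decorate when
    # AMP.use_pure_fp16 is enabled in the config.
    model = paddle.amp.decorate(models=model, level='O2')
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    # placeholder batch
    x = paddle.rand([8, 3, 224, 224])
    label = paddle.randint(0, 1000, [8, 1])

    # forward + loss inside auto_cast, mirroring train_epoch after the patch
    with paddle.amp.auto_cast(
            custom_black_list={"flatten_contiguous_range", "greater_than"},
            level='O2'):
        out = model(x)
        loss = paddle.nn.functional.cross_entropy(out, label)

    scaled = scaler.scale(loss)          # scale the loss to avoid fp16 underflow
    scaled.backward()
    scaler.minimize(optimizer, scaled)   # unscales grads, skips the step on inf/nan
    optimizer.clear_grad(set_to_zero=True)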