From 96551765ebcfc060c471d0c1abe2479126c7f631 Mon Sep 17 00:00:00 2001
From: Birdylx <29754889+Birdylx@users.noreply.github.com>
Date: Fri, 21 Oct 2022 17:50:18 +0800
Subject: [PATCH] Support amp for esrgan (#712)

---
 ppgan/engine/trainer.py                       | 17 ++--
 ppgan/models/edvr_model.py                    |  6 +-
 ppgan/models/esrgan_model.py                  | 83 ++++++++++++++++++
 ppgan/models/msvsr_model.py                   |  6 +-
 ppgan/models/sr_model.py                      | 18 ++++
 test_tipc/configs/edvr/train_infer_python.txt |  2 +-
 .../configs/esrgan/train_infer_python.txt     |  2 +-
 .../configs/msvsr/train_infer_python.txt      |  2 +-
 test_tipc/prepare.sh                          |  1 +
 9 files changed, 122 insertions(+), 15 deletions(-)

diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py
index cb1b2d3..6bff8a7 100755
--- a/ppgan/engine/trainer.py
+++ b/ppgan/engine/trainer.py
@@ -133,7 +133,7 @@ class Trainer:
                                                 cfg.optimizer)
 
         # setup amp train
-        self.scaler = self.setup_amp_train() if self.cfg.amp else None
+        self.scalers = self.setup_amp_train() if self.cfg.amp else None
 
         # multiple gpus prepare
         if ParallelEnv().nranks > 1:
@@ -164,11 +164,10 @@ class Trainer:
             self.profiler_options = cfg.profiler_options
 
     def setup_amp_train(self):
-        """ decerate model, optimizer and return a GradScaler """
-
+        """ decorate model, optimizer and return a list of GradScaler """
         self.logger.info('use AMP to train. AMP level = {}'.format(
             self.cfg.amp_level))
-        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+        # need to decorate model and optim if amp_level == 'O2'
         if self.cfg.amp_level == 'O2':
             nets, optimizers = list(self.model.nets.values()), list(
                 self.optimizers.values())
@@ -181,7 +180,13 @@ class Trainer:
                 self.model.nets[k] = nets[i]
             for i, (k, _) in enumerate(self.optimizers.items()):
                 self.optimizers[k] = optimizers[i]
-        return scaler
+
+        scalers = [
+            paddle.amp.GradScaler(init_loss_scaling=1024)
+            for i in range(len(self.optimizers))
+        ]
+
+        return scalers
 
     def distributed_data_parallel(self):
         paddle.distributed.init_parallel_env()
@@ -223,7 +228,7 @@ class Trainer:
             self.model.setup_input(data)
 
             if self.cfg.amp:
-                self.model.train_iter_amp(self.optimizers, self.scaler,
+                self.model.train_iter_amp(self.optimizers, self.scalers,
                                           self.cfg.amp_level)  # amp train
             else:
                 self.model.train_iter(self.optimizers)  # norm train
diff --git a/ppgan/models/edvr_model.py b/ppgan/models/edvr_model.py
index eb579b2..3b5c50a 100644
--- a/ppgan/models/edvr_model.py
+++ b/ppgan/models/edvr_model.py
@@ -76,7 +76,7 @@ class EDVRModel(BaseSRModel):
         self.current_iter += 1
 
     # amp train with brute force implementation
-    def train_iter_amp(self, optims=None, scaler=None, amp_level='O1'):
+    def train_iter_amp(self, optims=None, scalers=None, amp_level='O1'):
         optims['optim'].clear_grad()
         if self.tsa_iter:
             if self.current_iter == 1:
@@ -97,9 +97,9 @@ class EDVRModel(BaseSRModel):
             loss_pixel = self.pixel_criterion(self.output, self.gt)
             self.losses['loss_pixel'] = loss_pixel
 
-            scaled_loss = scaler.scale(loss_pixel)
+            scaled_loss = scalers[0].scale(loss_pixel)
             scaled_loss.backward()
-            scaler.minimize(optims['optim'], scaled_loss)
+            scalers[0].minimize(optims['optim'], scaled_loss)
 
         self.current_iter += 1
 
diff --git a/ppgan/models/esrgan_model.py b/ppgan/models/esrgan_model.py
index fe67cff..08c7b67 100644
--- a/ppgan/models/esrgan_model.py
+++ b/ppgan/models/esrgan_model.py
@@ -29,6 +29,7 @@ class ESRGAN(BaseSRModel):
 
     ESRGAN paper: https://arxiv.org/pdf/1809.00219.pdf
     """
+
     def __init__(self,
                  generator,
                  discriminator=None,
@@ -127,3 +128,85 @@ class ESRGAN(BaseSRModel):
         else:
             l_total.backward()
             optimizers['optimG'].step()
+
+    # amp training
+    def train_iter_amp(self, optimizers=None, scalers=None, amp_level='O1'):
+        optimizers['optimG'].clear_grad()
+        l_total = 0
+
+        # put loss computation in amp context
+        with paddle.amp.auto_cast(enable=True, level=amp_level):
+            self.output = self.nets['generator'](self.lq)
+            self.visual_items['output'] = self.output
+            # pixel loss
+            if self.pixel_criterion:
+                l_pix = self.pixel_criterion(self.output, self.gt)
+                l_total += l_pix
+                self.losses['loss_pix'] = l_pix
+            if self.perceptual_criterion:
+                l_g_percep, l_g_style = self.perceptual_criterion(
+                    self.output, self.gt)
+                # l_total += l_pix
+                if l_g_percep is not None:
+                    l_total += l_g_percep
+                    self.losses['loss_percep'] = l_g_percep
+                if l_g_style is not None:
+                    l_total += l_g_style
+                    self.losses['loss_style'] = l_g_style
+
+        # gan loss (relativistic gan)
+        if hasattr(self, 'gan_criterion'):
+            self.set_requires_grad(self.nets['discriminator'], False)
+
+            # put fwd and loss computation in amp context
+            with paddle.amp.auto_cast(enable=True, level=amp_level):
+                real_d_pred = self.nets['discriminator'](self.gt).detach()
+                fake_g_pred = self.nets['discriminator'](self.output)
+                l_g_real = self.gan_criterion(real_d_pred -
+                                              paddle.mean(fake_g_pred),
+                                              False,
+                                              is_disc=False)
+                l_g_fake = self.gan_criterion(fake_g_pred -
+                                              paddle.mean(real_d_pred),
+                                              True,
+                                              is_disc=False)
+                l_g_gan = (l_g_real + l_g_fake) / 2
+
+                l_total += l_g_gan
+                self.losses['l_g_gan'] = l_g_gan
+
+            scaled_l_total = scalers[0].scale(l_total)
+            scaled_l_total.backward()
+            scalers[0].minimize(optimizers['optimG'], scaled_l_total)
+
+            self.set_requires_grad(self.nets['discriminator'], True)
+            optimizers['optimD'].clear_grad()
+
+            with paddle.amp.auto_cast(enable=True, level=amp_level):
+                # real
+                fake_d_pred = self.nets['discriminator'](self.output).detach()
+                real_d_pred = self.nets['discriminator'](self.gt)
+                l_d_real = self.gan_criterion(
+                    real_d_pred - paddle.mean(fake_d_pred), True,
+                    is_disc=True) * 0.5
+
+                # fake
+                fake_d_pred = self.nets['discriminator'](self.output.detach())
+                l_d_fake = self.gan_criterion(
+                    fake_d_pred - paddle.mean(real_d_pred.detach()),
+                    False,
+                    is_disc=True) * 0.5
+
+            l_temp = l_d_real + l_d_fake
+            scaled_l_temp = scalers[1].scale(l_temp)
+            scaled_l_temp.backward()
+            scalers[1].minimize(optimizers['optimD'], scaled_l_temp)
+
+            self.losses['l_d_real'] = l_d_real
+            self.losses['l_d_fake'] = l_d_fake
+            self.losses['out_d_real'] = paddle.mean(real_d_pred.detach())
+            self.losses['out_d_fake'] = paddle.mean(fake_d_pred.detach())
+        else:
+            scaled_l_total = scalers[0].scale(l_total)
+            scaled_l_total.backward()
+            scalers[0].minimize(optimizers['optimG'], scaled_l_total)
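The ESRGAN path above is the reason `setup_amp_train` now returns a list of scalers: the generator and the discriminator each get their own `GradScaler`, so their dynamic loss scales can adapt independently. The standalone sketch below shows the same two-optimizer pattern with toy stand-in networks and losses; `gen`, `disc`, and the loss terms are illustrative only, not PaddleGAN code.

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

# toy stand-ins for the ESRGAN generator and discriminator
gen = nn.Conv2D(3, 3, 3, padding=1)
disc = nn.Conv2D(3, 1, 3, padding=1)
optim_g = paddle.optimizer.Adam(parameters=gen.parameters())
optim_d = paddle.optimizer.Adam(parameters=disc.parameters())

# one scaler per optimizer, as setup_amp_train() now builds them
scalers = [paddle.amp.GradScaler(init_loss_scaling=1024) for _ in range(2)]

lq = paddle.rand([2, 3, 32, 32])
gt = paddle.rand([2, 3, 32, 32])

# generator step: forward and loss under auto_cast, then scale/backward/minimize
optim_g.clear_grad()
with paddle.amp.auto_cast(enable=True, level='O1'):
    fake = gen(lq)
    loss_g = F.l1_loss(fake, gt) - disc(fake).mean()  # toy pixel + adversarial term
scaled_g = scalers[0].scale(loss_g)
scaled_g.backward()
scalers[0].minimize(optim_g, scaled_g)  # unscales grads, then updates the generator

# discriminator step uses its own scaler, so its dynamic loss scale
# can grow or shrink independently of the generator's
optim_d.clear_grad()
with paddle.amp.auto_cast(enable=True, level='O1'):
    loss_d = disc(fake.detach()).mean() - disc(gt).mean()  # toy critic loss
scaled_d = scalers[1].scale(loss_d)
scaled_d.backward()
scalers[1].minimize(optim_d, scaled_d)
```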
diff --git a/ppgan/models/msvsr_model.py b/ppgan/models/msvsr_model.py
index 1774e01..1c394ef 100644
--- a/ppgan/models/msvsr_model.py
+++ b/ppgan/models/msvsr_model.py
@@ -98,7 +98,7 @@ class MultiStageVSRModel(BaseSRModel):
         self.current_iter += 1
 
     # amp train with brute force implementation
-    def train_iter_amp(self, optims=None, scaler=None, amp_level='O1'):
+    def train_iter_amp(self, optims=None, scalers=None, amp_level='O1'):
         optims['optim'].clear_grad()
         if self.fix_iter:
             if self.current_iter == 1:
@@ -133,9 +133,9 @@ class MultiStageVSRModel(BaseSRModel):
                                 if 'loss_pix' in _key)
             self.losses['loss'] = self.loss
 
-            scaled_loss = scaler.scale(self.loss)
+            scaled_loss = scalers[0].scale(self.loss)
             scaled_loss.backward()
-            scaler.minimize(optims['optim'], scaled_loss)
+            scalers[0].minimize(optims['optim'], scaled_loss)
 
         self.current_iter += 1
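EDVR, MS-VSR, and `BaseSRModel` share the simpler single-optimizer version of this recipe: run the forward pass and loss under `paddle.amp.auto_cast`, then scale, backward, minimize. `GradScaler.minimize` unscales the gradients, skips the update when they contain inf/nan, adjusts the loss scale, and applies the optimizer step, which is why no separate `optimizer.step()` appears in these methods. A minimal self-contained sketch, with a toy `net` and tensors standing in for the SR model:

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

net = nn.Conv2D(3, 3, 3, padding=1)  # toy stand-in for the SR generator
optim = paddle.optimizer.Adam(parameters=net.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

lq = paddle.rand([2, 3, 16, 16])
gt = paddle.rand([2, 3, 16, 16])

optim.clear_grad()
# O1 runs whitelisted ops (e.g. conv) in fp16 and keeps the rest in fp32
with paddle.amp.auto_cast(enable=True, level='O1'):
    out = net(lq)
    loss = F.l1_loss(out, gt)

scaled = scaler.scale(loss)  # loss multiplied by the current loss scale
scaled.backward()            # gradients carry the same scale factor
# minimize() unscales the grads, skips the step if they overflowed,
# updates the loss scale, and applies the optimizer update
scaler.minimize(optim, scaled)
```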
diff --git a/ppgan/models/sr_model.py b/ppgan/models/sr_model.py
index 767bf27..e81e1f3 100644
--- a/ppgan/models/sr_model.py
+++ b/ppgan/models/sr_model.py
@@ -27,6 +27,7 @@ from ..modules.init import reset_parameters
 
 class BaseSRModel(BaseModel):
     """Base SR model for single image super-resolution.
     """
+
     def __init__(self, generator, pixel_criterion=None, use_init_weight=False):
         """
         Args:
@@ -65,6 +66,22 @@ class BaseSRModel(BaseModel):
             loss_pixel.backward()
             optims['optim'].step()
 
+    # amp training
+    def train_iter_amp(self, optims=None, scalers=None, amp_level='O1'):
+        optims['optim'].clear_grad()
+
+        # put fwd and loss computation in amp context
+        with paddle.amp.auto_cast(enable=True, level=amp_level):
+            self.output = self.nets['generator'](self.lq)
+            self.visual_items['output'] = self.output
+            # pixel loss
+            loss_pixel = self.pixel_criterion(self.output, self.gt)
+            self.losses['loss_pixel'] = loss_pixel
+
+        scaled_loss_pixel = scalers[0].scale(loss_pixel)
+        scaled_loss_pixel.backward()
+        scalers[0].minimize(optims['optim'], scaled_loss_pixel)
+
     def test_iter(self, metrics=None):
         self.nets['generator'].eval()
         with paddle.no_grad():
@@ -84,6 +101,7 @@ class BaseSRModel(BaseModel):
 
 
 def init_sr_weight(net):
+
     def reset_func(m):
         if hasattr(m, 'weight') and (not isinstance(
             m, (nn.BatchNorm, nn.BatchNorm2D))):
diff --git a/test_tipc/configs/edvr/train_infer_python.txt b/test_tipc/configs/edvr/train_infer_python.txt
index 46624f0..acc1875 100644
--- a/test_tipc/configs/edvr/train_infer_python.txt
+++ b/test_tipc/configs/edvr/train_infer_python.txt
@@ -51,7 +51,7 @@ null:null
 null:null
 ===========================train_benchmark_params==========================
 batch_size:64
-fp_items:fp32
+fp_items:fp32|fp16
 total_iters:100
 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
 flags:FLAGS_cudnn_exhaustive_search=1
diff --git a/test_tipc/configs/esrgan/train_infer_python.txt b/test_tipc/configs/esrgan/train_infer_python.txt
index dfbb98d..275cf0a 100644
--- a/test_tipc/configs/esrgan/train_infer_python.txt
+++ b/test_tipc/configs/esrgan/train_infer_python.txt
@@ -51,7 +51,7 @@ null:null
 null:null
 ===========================train_benchmark_params==========================
 batch_size:32|64
-fp_items:fp32
+fp_items:fp32|fp16
 total_iters:500
 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
 flags:FLAGS_cudnn_exhaustive_search=1
diff --git a/test_tipc/configs/msvsr/train_infer_python.txt b/test_tipc/configs/msvsr/train_infer_python.txt
index 3becf9b..7a0b1b8 100644
--- a/test_tipc/configs/msvsr/train_infer_python.txt
+++ b/test_tipc/configs/msvsr/train_infer_python.txt
@@ -51,7 +51,7 @@ null:null
 null:null
 ===========================train_benchmark_params==========================
 batch_size:2|4
-fp_items:fp32
+fp_items:fp32|fp16
 total_iters:60
 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
 flags:FLAGS_cudnn_exhaustive_search=1
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index c172b96..c43542c 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -197,5 +197,6 @@ elif [ ${MODE} = "cpp_infer" ]; then
         rm -rf ./inference/msvsr*
         wget -nc -P ./inference https://paddlegan.bj.bcebos.com/static_model/msvsr.tar --no-check-certificate
         cd ./inference && tar xf msvsr.tar && cd ../
+        wget -nc -P ./data https://paddlegan.bj.bcebos.com/datasets/low_res.mp4 --no-check-certificate
     fi
 fi
-- 
GitLab
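For `amp_level == 'O2'`, the trainer must also decorate the networks and optimizers before building the scalers. The `paddle.amp.decorate` call itself falls in the elided region between the `setup_amp_train` hunks above, so the sketch below reconstructs that setup from Paddle's public AMP API under that assumption; the toy nets and the `optimG`/`optimD` keys are stand-ins for `self.model.nets` and `self.optimizers`.

```python
import paddle
import paddle.nn as nn

# toy stand-ins for self.model.nets and self.optimizers
nets = {
    'generator': nn.Conv2D(3, 3, 3, padding=1),
    'discriminator': nn.Conv2D(3, 1, 3, padding=1),
}
optimizers = {
    'optimG': paddle.optimizer.Adam(parameters=nets['generator'].parameters()),
    'optimD': paddle.optimizer.Adam(parameters=nets['discriminator'].parameters()),
}

amp_level = 'O2'
if amp_level == 'O2':
    # decorate() casts the networks for pure-fp16 training and prepares
    # the optimizers, keeping fp32 master weights and save dtype
    dec_nets, dec_optims = paddle.amp.decorate(
        models=list(nets.values()),
        optimizers=list(optimizers.values()),
        level='O2',
        save_dtype='float32')
    # write the decorated objects back under their original keys
    for k, net in zip(nets, dec_nets):
        nets[k] = net
    for k, opt in zip(optimizers, dec_optims):
        optimizers[k] = opt

# one GradScaler per optimizer, as in setup_amp_train()
scalers = [
    paddle.amp.GradScaler(init_loss_scaling=1024)
    for _ in range(len(optimizers))
]
```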