From f10e7b60d41bb449084c2ff97983481ffd7e037f Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Wed, 30 Jun 2021 10:42:53 +0800 Subject: [PATCH] support distributed evaluate (#351) --- ppgan/engine/trainer.py | 11 ++++++----- ppgan/metrics/lpips.py | 6 ++++++ ppgan/metrics/psnr_ssim.py | 6 ++++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index 30d2ccd..756cbfc 100755 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -79,6 +79,7 @@ class Trainer: self.max_eval_steps = cfg.model.get('max_eval_steps', None) self.local_rank = ParallelEnv().local_rank + self.world_size = ParallelEnv().nranks self.log_interval = cfg.log_config.interval self.visual_interval = cfg.log_config.visiual_interval self.weight_interval = cfg.snapshot_config.interval @@ -217,8 +218,7 @@ class Trainer: def test(self): if not hasattr(self, 'test_dataloader'): self.test_dataloader = build_dataloader(self.cfg.dataset.test, - is_train=False, - distributed=False) + is_train=False) iter_loader = IterLoader(self.test_dataloader) if self.max_eval_steps is None: self.max_eval_steps = len(self.test_dataloader) @@ -231,6 +231,10 @@ class Trainer: self.model.setup_train_mode(is_train=False) for i in range(self.max_eval_steps): + if self.max_eval_steps < self.log_interval or i % self.log_interval == 0: + self.logger.info('Test iter: [%d/%d]' % + (i * self.world_size, self.max_eval_steps * self.world_size)) + data = next(iter_loader) self.model.setup_input(data) self.model.test_iter(metrics=self.metrics) @@ -264,9 +268,6 @@ class Trainer: step=self.batch_id, is_save_image=True) - if i % self.log_interval == 0: - self.logger.info('Test iter: [%d/%d]' % - (i, self.max_eval_steps)) if self.metrics: for metric_name, metric in self.metrics.items(): diff --git a/ppgan/metrics/lpips.py b/ppgan/metrics/lpips.py index aadeabf..2a77eaf 100644 --- a/ppgan/metrics/lpips.py +++ b/ppgan/metrics/lpips.py @@ -89,6 +89,12 @@ class LPIPSMetric(paddle.metric.Metric): self.results.append(value.item()) def accumulate(self): + if paddle.distributed.get_world_size() > 1: + results = paddle.to_tensor(self.results) + results_list = [] + paddle.distributed.all_gather(results_list, results) + self.results = paddle.concat(results_list).numpy() + if len(self.results) <= 0: return 0. return np.mean(self.results) diff --git a/ppgan/metrics/psnr_ssim.py b/ppgan/metrics/psnr_ssim.py index 5bd9dc8..72702de 100644 --- a/ppgan/metrics/psnr_ssim.py +++ b/ppgan/metrics/psnr_ssim.py @@ -43,6 +43,12 @@ class PSNR(paddle.metric.Metric): self.results.append(value) def accumulate(self): + if paddle.distributed.get_world_size() > 1: + results = paddle.to_tensor(self.results) + results_list = [] + paddle.distributed.all_gather(results_list, results) + self.results = paddle.concat(results_list).numpy() + if len(self.results) <= 0: return 0. return np.mean(self.results) -- GitLab