From 059d52e354dbd0da806f02176454161fc5c0af28 Mon Sep 17 00:00:00 2001 From: wwhio Date: Wed, 16 Jun 2021 18:28:07 +0800 Subject: [PATCH] minor fix and support for `validate['save_img']` (#341) * remove print * add support for `validate['save_img']` * `get_start.md` is linked to English doc * remove spaces in empty line --- README.md | 2 +- ppgan/datasets/edvr_dataset.py | 1 - ppgan/engine/trainer.py | 51 ++++++++++++++++++---------------- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index a8d8d11..3cdeb3e 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional * More applications, please refer to [ppgan.apps apis](./docs/en_US/apis/apps.md) * More tutorials: - [Data preparation](./docs/en_US/data_prepare.md) - - [Training/Evaluating/Testing basic usage](./docs/zh_CN/get_started.md) + - [Training/Evaluating/Testing basic usage](./docs/en_US/get_started.md) ## Model Tutorial diff --git a/ppgan/datasets/edvr_dataset.py b/ppgan/datasets/edvr_dataset.py index 97a1a2f..c9b587c 100644 --- a/ppgan/datasets/edvr_dataset.py +++ b/ppgan/datasets/edvr_dataset.py @@ -304,7 +304,6 @@ class REDSDataset(Dataset): else: add_idx = i return_l.append(add_idx) - print(return_l) name_b = '{:08d}'.format(crt_i) return return_l, name_b diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index e7cf0cb..274ff28 100755 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -101,6 +101,8 @@ class Trainer: validate_cfg = cfg.get('validate', None) if validate_cfg and 'metrics' in validate_cfg: self.metrics = self.model.setup_metrics(validate_cfg['metrics']) + if validate_cfg and 'save_img' in validate_cfg: + self.is_save_img = validate_cfg['save_img'] self.enable_visualdl = cfg.get('enable_visualdl', False) if self.enable_visualdl: @@ -232,33 +234,34 @@ class Trainer: self.model.setup_input(data) self.model.test_iter(metrics=self.metrics) - visual_results = {} - current_paths = self.model.get_image_paths() - current_visuals = self.model.get_current_visuals() + if self.is_save_img: + visual_results = {} + current_paths = self.model.get_image_paths() + current_visuals = self.model.get_current_visuals() - if len(current_visuals) > 0 and list( - current_visuals.values())[0].shape == 4: - num_samples = list(current_visuals.values())[0].shape[0] - else: - num_samples = 1 - - for j in range(num_samples): - if j < len(current_paths): - short_path = os.path.basename(current_paths[j]) - basename = os.path.splitext(short_path)[0] + if len(current_visuals) > 0 and list( + current_visuals.values())[0].shape == 4: + num_samples = list(current_visuals.values())[0].shape[0] else: - basename = '{:04d}_{:04d}'.format(i, j) - for k, img_tensor in current_visuals.items(): - name = '%s_%s' % (basename, k) - if len(img_tensor.shape) == 4: - visual_results.update({name: img_tensor[j]}) - else: - visual_results.update({name: img_tensor}) + num_samples = 1 - self.visual('visual_test', - visual_results=visual_results, - step=self.batch_id, - is_save_image=True) + for j in range(num_samples): + if j < len(current_paths): + short_path = os.path.basename(current_paths[j]) + basename = os.path.splitext(short_path)[0] + else: + basename = '{:04d}_{:04d}'.format(i, j) + for k, img_tensor in current_visuals.items(): + name = '%s_%s' % (basename, k) + if len(img_tensor.shape) == 4: + visual_results.update({name: img_tensor[j]}) + else: + visual_results.update({name: img_tensor}) + + self.visual('visual_test', + visual_results=visual_results, + step=self.batch_id, + is_save_image=True) if i % self.log_interval == 0: self.logger.info('Test iter: [%d/%d]' % -- GitLab