diff --git a/README.md b/README.md index a8d8d11d6fabf2bab38b5e6328449b65ea8244e8..3cdeb3e42092cb0893ec5a4f546d0d8954d09e44 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional * More applications, please refer to [ppgan.apps apis](./docs/en_US/apis/apps.md) * More tutorials: - [Data preparation](./docs/en_US/data_prepare.md) - - [Training/Evaluating/Testing basic usage](./docs/zh_CN/get_started.md) + - [Training/Evaluating/Testing basic usage](./docs/en_US/get_started.md) ## Model Tutorial diff --git a/ppgan/datasets/edvr_dataset.py b/ppgan/datasets/edvr_dataset.py index 97a1a2f8faf13852c6e11d59233b7140d27c628d..c9b587c63abef32a6041bae772658ddcb565fa72 100644 --- a/ppgan/datasets/edvr_dataset.py +++ b/ppgan/datasets/edvr_dataset.py @@ -304,7 +304,6 @@ class REDSDataset(Dataset): else: add_idx = i return_l.append(add_idx) - print(return_l) name_b = '{:08d}'.format(crt_i) return return_l, name_b diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index e7cf0cbc103601ac4e7ed4624ca98d602b31e7c7..274ff2847b774054507c101f1fcb0c9af2828ac0 100755 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -101,6 +101,8 @@ class Trainer: validate_cfg = cfg.get('validate', None) if validate_cfg and 'metrics' in validate_cfg: self.metrics = self.model.setup_metrics(validate_cfg['metrics']) + if validate_cfg and 'save_img' in validate_cfg: + self.is_save_img = validate_cfg['save_img'] self.enable_visualdl = cfg.get('enable_visualdl', False) if self.enable_visualdl: @@ -232,33 +234,34 @@ class Trainer: self.model.setup_input(data) self.model.test_iter(metrics=self.metrics) - visual_results = {} - current_paths = self.model.get_image_paths() - current_visuals = self.model.get_current_visuals() + if self.is_save_img: + visual_results = {} + current_paths = self.model.get_image_paths() + current_visuals = self.model.get_current_visuals() - if len(current_visuals) > 0 and list( - current_visuals.values())[0].shape == 4: - num_samples = list(current_visuals.values())[0].shape[0] - else: - num_samples = 1 - - for j in range(num_samples): - if j < len(current_paths): - short_path = os.path.basename(current_paths[j]) - basename = os.path.splitext(short_path)[0] + if len(current_visuals) > 0 and list( + current_visuals.values())[0].shape == 4: + num_samples = list(current_visuals.values())[0].shape[0] else: - basename = '{:04d}_{:04d}'.format(i, j) - for k, img_tensor in current_visuals.items(): - name = '%s_%s' % (basename, k) - if len(img_tensor.shape) == 4: - visual_results.update({name: img_tensor[j]}) - else: - visual_results.update({name: img_tensor}) + num_samples = 1 - self.visual('visual_test', - visual_results=visual_results, - step=self.batch_id, - is_save_image=True) + for j in range(num_samples): + if j < len(current_paths): + short_path = os.path.basename(current_paths[j]) + basename = os.path.splitext(short_path)[0] + else: + basename = '{:04d}_{:04d}'.format(i, j) + for k, img_tensor in current_visuals.items(): + name = '%s_%s' % (basename, k) + if len(img_tensor.shape) == 4: + visual_results.update({name: img_tensor[j]}) + else: + visual_results.update({name: img_tensor}) + + self.visual('visual_test', + visual_results=visual_results, + step=self.batch_id, + is_save_image=True) if i % self.log_interval == 0: self.logger.info('Test iter: [%d/%d]' %