未验证 提交 059d52e3 编写于 作者: W wwhio 提交者: GitHub

minor fix and support for `validate['save_img']` (#341)

* remove print

* add support for `validate['save_img']`

* `get_start.md` is linked to English doc

* remove spaces in empty line
上级 9e783d8d
...@@ -31,7 +31,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional ...@@ -31,7 +31,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
* More applications, please refer to [ppgan.apps apis](./docs/en_US/apis/apps.md) * More applications, please refer to [ppgan.apps apis](./docs/en_US/apis/apps.md)
* More tutorials: * More tutorials:
- [Data preparation](./docs/en_US/data_prepare.md) - [Data preparation](./docs/en_US/data_prepare.md)
- [Training/Evaluating/Testing basic usage](./docs/zh_CN/get_started.md) - [Training/Evaluating/Testing basic usage](./docs/en_US/get_started.md)
## Model Tutorial ## Model Tutorial
......
...@@ -304,7 +304,6 @@ class REDSDataset(Dataset): ...@@ -304,7 +304,6 @@ class REDSDataset(Dataset):
else: else:
add_idx = i add_idx = i
return_l.append(add_idx) return_l.append(add_idx)
print(return_l)
name_b = '{:08d}'.format(crt_i) name_b = '{:08d}'.format(crt_i)
return return_l, name_b return return_l, name_b
......
...@@ -101,6 +101,8 @@ class Trainer: ...@@ -101,6 +101,8 @@ class Trainer:
validate_cfg = cfg.get('validate', None) validate_cfg = cfg.get('validate', None)
if validate_cfg and 'metrics' in validate_cfg: if validate_cfg and 'metrics' in validate_cfg:
self.metrics = self.model.setup_metrics(validate_cfg['metrics']) self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
if validate_cfg and 'save_img' in validate_cfg:
self.is_save_img = validate_cfg['save_img']
self.enable_visualdl = cfg.get('enable_visualdl', False) self.enable_visualdl = cfg.get('enable_visualdl', False)
if self.enable_visualdl: if self.enable_visualdl:
...@@ -232,33 +234,34 @@ class Trainer: ...@@ -232,33 +234,34 @@ class Trainer:
self.model.setup_input(data) self.model.setup_input(data)
self.model.test_iter(metrics=self.metrics) self.model.test_iter(metrics=self.metrics)
visual_results = {} if self.is_save_img:
current_paths = self.model.get_image_paths() visual_results = {}
current_visuals = self.model.get_current_visuals() current_paths = self.model.get_image_paths()
current_visuals = self.model.get_current_visuals()
if len(current_visuals) > 0 and list( if len(current_visuals) > 0 and list(
current_visuals.values())[0].shape == 4: current_visuals.values())[0].shape == 4:
num_samples = list(current_visuals.values())[0].shape[0] num_samples = list(current_visuals.values())[0].shape[0]
else:
num_samples = 1
for j in range(num_samples):
if j < len(current_paths):
short_path = os.path.basename(current_paths[j])
basename = os.path.splitext(short_path)[0]
else: else:
basename = '{:04d}_{:04d}'.format(i, j) num_samples = 1
for k, img_tensor in current_visuals.items():
name = '%s_%s' % (basename, k)
if len(img_tensor.shape) == 4:
visual_results.update({name: img_tensor[j]})
else:
visual_results.update({name: img_tensor})
self.visual('visual_test', for j in range(num_samples):
visual_results=visual_results, if j < len(current_paths):
step=self.batch_id, short_path = os.path.basename(current_paths[j])
is_save_image=True) basename = os.path.splitext(short_path)[0]
else:
basename = '{:04d}_{:04d}'.format(i, j)
for k, img_tensor in current_visuals.items():
name = '%s_%s' % (basename, k)
if len(img_tensor.shape) == 4:
visual_results.update({name: img_tensor[j]})
else:
visual_results.update({name: img_tensor})
self.visual('visual_test',
visual_results=visual_results,
step=self.batch_id,
is_save_image=True)
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.logger.info('Test iter: [%d/%d]' % self.logger.info('Test iter: [%d/%d]' %
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册