未验证 提交 059d52e3 编写于 作者: W wwhio 提交者: GitHub

minor fix and support for `validate['save_img']` (#341)

* remove print

* add support for `validate['save_img']`

* `get_start.md` is linked to English doc

* remove spaces in empty line
上级 9e783d8d
...@@ -31,7 +31,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional ...@@ -31,7 +31,7 @@ GAN-Generative Adversarial Network, was praised by "the Father of Convolutional
* More applications, please refer to [ppgan.apps apis](./docs/en_US/apis/apps.md) * More applications, please refer to [ppgan.apps apis](./docs/en_US/apis/apps.md)
* More tutorials: * More tutorials:
- [Data preparation](./docs/en_US/data_prepare.md) - [Data preparation](./docs/en_US/data_prepare.md)
- [Training/Evaluating/Testing basic usage](./docs/zh_CN/get_started.md) - [Training/Evaluating/Testing basic usage](./docs/en_US/get_started.md)
## Model Tutorial ## Model Tutorial
......
...@@ -304,7 +304,6 @@ class REDSDataset(Dataset): ...@@ -304,7 +304,6 @@ class REDSDataset(Dataset):
else: else:
add_idx = i add_idx = i
return_l.append(add_idx) return_l.append(add_idx)
print(return_l)
name_b = '{:08d}'.format(crt_i) name_b = '{:08d}'.format(crt_i)
return return_l, name_b return return_l, name_b
......
...@@ -101,6 +101,8 @@ class Trainer: ...@@ -101,6 +101,8 @@ class Trainer:
validate_cfg = cfg.get('validate', None) validate_cfg = cfg.get('validate', None)
if validate_cfg and 'metrics' in validate_cfg: if validate_cfg and 'metrics' in validate_cfg:
self.metrics = self.model.setup_metrics(validate_cfg['metrics']) self.metrics = self.model.setup_metrics(validate_cfg['metrics'])
if validate_cfg and 'save_img' in validate_cfg:
self.is_save_img = validate_cfg['save_img']
self.enable_visualdl = cfg.get('enable_visualdl', False) self.enable_visualdl = cfg.get('enable_visualdl', False)
if self.enable_visualdl: if self.enable_visualdl:
...@@ -232,6 +234,7 @@ class Trainer: ...@@ -232,6 +234,7 @@ class Trainer:
self.model.setup_input(data) self.model.setup_input(data)
self.model.test_iter(metrics=self.metrics) self.model.test_iter(metrics=self.metrics)
if self.is_save_img:
visual_results = {} visual_results = {}
current_paths = self.model.get_image_paths() current_paths = self.model.get_image_paths()
current_visuals = self.model.get_current_visuals() current_visuals = self.model.get_current_visuals()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册