未验证 提交 f10e7b60 编写于 作者: L LielinJiang 提交者: GitHub

support distributed evaluate (#351)

上级 7131e67c
......@@ -79,6 +79,7 @@ class Trainer:
self.max_eval_steps = cfg.model.get('max_eval_steps', None)
self.local_rank = ParallelEnv().local_rank
self.world_size = ParallelEnv().nranks
self.log_interval = cfg.log_config.interval
self.visual_interval = cfg.log_config.visiual_interval
self.weight_interval = cfg.snapshot_config.interval
......@@ -217,8 +218,7 @@ class Trainer:
def test(self):
if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader(self.cfg.dataset.test,
is_train=False,
distributed=False)
is_train=False)
iter_loader = IterLoader(self.test_dataloader)
if self.max_eval_steps is None:
self.max_eval_steps = len(self.test_dataloader)
......@@ -231,6 +231,10 @@ class Trainer:
self.model.setup_train_mode(is_train=False)
for i in range(self.max_eval_steps):
if self.max_eval_steps < self.log_interval or i % self.log_interval == 0:
self.logger.info('Test iter: [%d/%d]' %
(i * self.world_size, self.max_eval_steps * self.world_size))
data = next(iter_loader)
self.model.setup_input(data)
self.model.test_iter(metrics=self.metrics)
......@@ -264,9 +268,6 @@ class Trainer:
step=self.batch_id,
is_save_image=True)
if i % self.log_interval == 0:
self.logger.info('Test iter: [%d/%d]' %
(i, self.max_eval_steps))
if self.metrics:
for metric_name, metric in self.metrics.items():
......
......@@ -89,6 +89,12 @@ class LPIPSMetric(paddle.metric.Metric):
self.results.append(value.item())
def accumulate(self):
    """Return the mean of the values collected in ``self.results``.

    When more than one process takes part in the run, the values held by
    every process are gathered first so that each process reports the mean
    over all of them.

    Returns:
        ``0.`` when there are no values, otherwise ``np.mean`` of the
        (possibly gathered) values.
    """
    # Work on a local reference: ``self.results`` must remain the list that
    # the rest of the class appends to. Overwriting it with the gathered
    # NumPy array would break any later ``append`` and would make a second
    # ``accumulate()`` call gather already-gathered data.
    values = self.results
    if paddle.distributed.get_world_size() > 1:
        # Every process must reach this collective call, so there is no
        # early return for a locally empty list before it.
        # NOTE(review): all_gather presumably expects each process to
        # contribute a tensor of the same shape - confirm that the sampler
        # hands every process the same number of samples.
        local_values = paddle.to_tensor(self.results)
        gathered = []
        paddle.distributed.all_gather(gathered, local_values)
        values = paddle.concat(gathered).numpy()
    if len(values) <= 0:
        return 0.
    return np.mean(values)
......
......@@ -43,6 +43,12 @@ class PSNR(paddle.metric.Metric):
self.results.append(value)
def accumulate(self):
    """Return the mean of the values collected in ``self.results``.

    When more than one process takes part in the run, the values held by
    every process are gathered first so that each process reports the mean
    over all of them.

    Returns:
        ``0.`` when there are no values, otherwise ``np.mean`` of the
        (possibly gathered) values.
    """
    # Keep ``self.results`` as the list the class appends to; replacing it
    # with the gathered NumPy array would make a later ``append`` fail and
    # would re-gather already-gathered data on a repeated call.
    collected = self.results
    if paddle.distributed.get_world_size() > 1:
        # All processes have to take part in the collective, so a locally
        # empty list is not short-circuited before this point.
        # NOTE(review): all_gather presumably needs equally shaped tensors
        # from every process - verify the per-process sample counts match.
        per_rank = []
        paddle.distributed.all_gather(per_rank,
                                      paddle.to_tensor(self.results))
        collected = paddle.concat(per_rank).numpy()
    if len(collected) <= 0:
        return 0.
    return np.mean(collected)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册