From dcc93452f151c9afebba5ec25e7e5ac7cfcc4a64 Mon Sep 17 00:00:00 2001 From: parap1uie-s Date: Tue, 8 Aug 2023 18:49:26 +0800 Subject: [PATCH] Fatal bug fix for hapi.eval (#55884) * Fatal bug fix for hapi.eval When multiple metrics are specified in hapi for eval, the output samples of the model are incorrectly truncated, resulting in inaccurate metric calculations * fix codestyle --- python/paddle/hapi/model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index daab2601335..b029c4d3a1c 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -892,17 +892,14 @@ class DynamicGraphAdapter: if self._nranks > 1: outputs = [_all_gather(o) for o in to_list(outputs)] labels = [_all_gather(l) for l in labels] - metrics = [] - for metric in self.model._metrics: - # cut off padding value. - if ( - self.model._test_dataloader is not None - and self._nranks > 1 - and isinstance(self.model._test_dataloader, DataLoader) + + if self.model._test_dataloader is not None and isinstance( + self.model._test_dataloader, DataLoader ): total_size = len(self.model._test_dataloader.dataset) samples = outputs[0].shape[0] current_count = self._merge_count.get(self.mode + '_total', 0) + if current_count + samples >= total_size: outputs = [ o[: int(total_size - current_count)] for o in outputs @@ -918,6 +915,9 @@ class DynamicGraphAdapter: self._merge_count[self.mode + '_total'] += samples self._merge_count[self.mode + '_batch'] = samples + metrics = [] + for metric in self.model._metrics: + # cut off padding value. metric_outs = metric.compute(*(to_list(outputs) + labels)) m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)]) metrics.append(m) -- GitLab