未验证 提交 dcc93452 编写于 作者: P parap1uie-s 提交者: GitHub

Fatal bug fix for hapi.eval (#55884)

* Fatal bug fix for hapi.eval

When multiple metrics are specified in hapi for eval, the output samples of the model are incorrectly truncated, resulting in inaccurate metric calculations

* fix codestyle
上级 96402ded
...@@ -892,17 +892,14 @@ class DynamicGraphAdapter: ...@@ -892,17 +892,14 @@ class DynamicGraphAdapter:
if self._nranks > 1: if self._nranks > 1:
outputs = [_all_gather(o) for o in to_list(outputs)] outputs = [_all_gather(o) for o in to_list(outputs)]
labels = [_all_gather(l) for l in labels] labels = [_all_gather(l) for l in labels]
metrics = []
for metric in self.model._metrics: if self.model._test_dataloader is not None and isinstance(
# cut off padding value. self.model._test_dataloader, DataLoader
if (
self.model._test_dataloader is not None
and self._nranks > 1
and isinstance(self.model._test_dataloader, DataLoader)
): ):
total_size = len(self.model._test_dataloader.dataset) total_size = len(self.model._test_dataloader.dataset)
samples = outputs[0].shape[0] samples = outputs[0].shape[0]
current_count = self._merge_count.get(self.mode + '_total', 0) current_count = self._merge_count.get(self.mode + '_total', 0)
if current_count + samples >= total_size: if current_count + samples >= total_size:
outputs = [ outputs = [
o[: int(total_size - current_count)] for o in outputs o[: int(total_size - current_count)] for o in outputs
...@@ -918,6 +915,9 @@ class DynamicGraphAdapter: ...@@ -918,6 +915,9 @@ class DynamicGraphAdapter:
self._merge_count[self.mode + '_total'] += samples self._merge_count[self.mode + '_total'] += samples
self._merge_count[self.mode + '_batch'] = samples self._merge_count[self.mode + '_batch'] = samples
metrics = []
for metric in self.model._metrics:
# cut off padding value.
metric_outs = metric.compute(*(to_list(outputs) + labels)) metric_outs = metric.compute(*(to_list(outputs) + labels))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)]) m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m) metrics.append(m)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册