未验证 提交 53ed9239 编写于 作者: C cuicheng01 提交者: GitHub

Merge pull request #1320 from RainFrost1/develop

fix clas distributed eval bug
@@ -34,6 +34,10 @@ def classification_eval(engine, epoch_id=0):
     metric_key = None
     tic = time.time()
+    accum_samples = 0
+    total_samples = len(
+        engine.eval_dataloader.
+        dataset) if not engine.use_dali else engine.eval_dataloader.size
     max_iter = len(engine.eval_dataloader) - 1 if platform.system(
     ) == "Windows" else len(engine.eval_dataloader)
     for iter_id, batch in enumerate(engine.eval_dataloader):
@@ -61,15 +65,31 @@ def classification_eval(engine, epoch_id=0):
             if key not in output_info:
                 output_info[key] = AverageMeter(key, '7.5f')
             output_info[key].update(loss_dict[key].numpy()[0], batch_size)
+
+        # just for DistributedBatchSampler issue: repeat sampling
+        current_samples = batch_size * paddle.distributed.get_world_size()
+        accum_samples += current_samples
+
         # calc metric
         if engine.eval_metric_func is not None:
-            metric_dict = engine.eval_metric_func(out, batch[1])
             if paddle.distributed.get_world_size() > 1:
-                for key in metric_dict:
-                    paddle.distributed.all_reduce(
-                        metric_dict[key], op=paddle.distributed.ReduceOp.SUM)
-                    metric_dict[key] = metric_dict[
-                        key] / paddle.distributed.get_world_size()
+                pred_list = []
+                label_list = []
+                if isinstance(out, dict):
+                    out = out["logits"]
+                paddle.distributed.all_gather(pred_list, out)
+                paddle.distributed.all_gather(label_list, batch[1])
+                pred = paddle.concat(pred_list, 0)
+                labels = paddle.concat(label_list, 0)
+                if accum_samples > total_samples:
+                    pred = pred[:total_samples + current_samples -
+                                accum_samples]
+                    labels = labels[:total_samples + current_samples -
+                                    accum_samples]
+                    current_samples = total_samples + current_samples - accum_samples
+                metric_dict = engine.eval_metric_func(pred, labels)
+            else:
+                metric_dict = engine.eval_metric_func(out, batch[1])
             for key in metric_dict:
                 if metric_key is None:
                     metric_key = key
@@ -77,7 +97,7 @@ def classification_eval(engine, epoch_id=0):
                 output_info[key] = AverageMeter(key, '7.5f')
-                output_info[key].update(metric_dict[key].numpy()[0],
-                                        batch_size)
+                output_info[key].update(metric_dict[key].numpy()[0],
+                                        current_samples)
         time_info["batch_cost"].update(time.time() - tic)
...
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册