未验证 提交 81fa1823 编写于 作者: W whs 提交者: GitHub

Fix eval function for distributed training (#1228)

上级 fc90903f
@@ -68,8 +68,9 @@ def parse_args():
 def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
     nranks = paddle.distributed.ParallelEnv().local_rank
-    batch_sampler = paddle.io.DistributedBatchSampler(
+    if nranks > 1 and paddle.distributed.get_rank() != 0:
+        return
+    batch_sampler = paddle.io.BatchSampler(
         eval_dataset, batch_size=1, shuffle=False, drop_last=False)
     loader = paddle.io.DataLoader(
         eval_dataset,
@@ -116,30 +117,9 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
             paddle.to_tensor(label),
             eval_dataset.num_classes,
             ignore_index=eval_dataset.ignore_index)
-        if nranks > 1:
-            intersect_area_list = []
-            pred_area_list = []
-            label_area_list = []
-            paddle.distributed.all_gather(intersect_area_list, intersect_area)
-            paddle.distributed.all_gather(pred_area_list, pred_area)
-            paddle.distributed.all_gather(label_area_list, label_area)
-            # Some image has been evaluated and should be eliminated in last iter
-            if (iter + 1) * nranks > len(eval_dataset):
-                valid = len(eval_dataset) - iter * nranks
-                intersect_area_list = intersect_area_list[:valid]
-                pred_area_list = pred_area_list[:valid]
-                label_area_list = label_area_list[:valid]
-            for i in range(len(intersect_area_list)):
-                intersect_area_all = intersect_area_all + intersect_area_list[i]
-                pred_area_all = pred_area_all + pred_area_list[i]
-                label_area_all = label_area_all + label_area_list[i]
-        else:
-            intersect_area_all = intersect_area_all + intersect_area
-            pred_area_all = pred_area_all + pred_area
-            label_area_all = label_area_all + label_area
+        intersect_area_all = intersect_area_all + intersect_area
+        pred_area_all = pred_area_all + pred_area
+        label_area_all = label_area_all + label_area
     class_iou, miou = metrics.mean_iou(intersect_area_all, pred_area_all,
                                        label_area_all)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册