From 978157e78246327a851f8d277b4d55c86f1d8760 Mon Sep 17 00:00:00 2001
From: gaotingquan
Date: Fri, 22 Apr 2022 07:19:19 +0000
Subject: [PATCH] fix: fix the bug that DistributedBatchSampler may sample
 repeatedly

---
 ppcls/engine/evaluation/retrieval.py | 51 ++++++++++++++++------------
 1 file changed, 29 insertions(+), 22 deletions(-)

diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py
index 8471a42c..3dfe6337 100644
--- a/ppcls/engine/evaluation/retrieval.py
+++ b/ppcls/engine/evaluation/retrieval.py
@@ -89,9 +89,6 @@ def retrieval_eval(engine, epoch_id=0):
 
 
 def cal_feature(engine, name='gallery'):
-    all_feas = None
-    all_image_id = None
-    all_unique_id = None
     has_unique_id = False
 
     if name == 'gallery':
@@ -103,6 +100,9 @@
     else:
         raise RuntimeError("Only support gallery or query dataset")
 
+    batch_feas_list = []
+    img_id_list = []
+    unique_id_list = []
     max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len(
         dataloader)
     for idx, batch in enumerate(dataloader):  # load is very time-consuming
@@ -140,32 +140,39 @@
         if engine.config["Global"].get("feature_binarize") == "sign":
             batch_feas = paddle.sign(batch_feas).astype("float32")
 
-        if all_feas is None:
-            all_feas = batch_feas
+        if paddle.distributed.get_world_size() > 1:
+            batch_feas_gather = []
+            img_id_gather = []
+            unique_id_gather = []
+            paddle.distributed.all_gather(batch_feas_gather, batch_feas)
+            paddle.distributed.all_gather(img_id_gather, batch[1])
+            batch_feas_list.append(paddle.concat(batch_feas_gather))
+            img_id_list.append(paddle.concat(img_id_gather))
             if has_unique_id:
-                all_unique_id = batch[2]
-            all_image_id = batch[1]
+                paddle.distributed.all_gather(unique_id_gather, batch[2])
+                unique_id_list.append(paddle.concat(unique_id_gather))
         else:
-            all_feas = paddle.concat([all_feas, batch_feas])
-            all_image_id = paddle.concat([all_image_id, batch[1]])
+            batch_feas_list.append(batch_feas)
+            img_id_list.append(batch[1])
             if has_unique_id:
-                all_unique_id = paddle.concat([all_unique_id, batch[2]])
+                unique_id_list.append(batch[2])
 
     if engine.use_dali:
         dataloader.reset()
 
-    if paddle.distributed.get_world_size() > 1:
-        feat_list = []
-        img_id_list = []
-        unique_id_list = []
-        paddle.distributed.all_gather(feat_list, all_feas)
-        paddle.distributed.all_gather(img_id_list, all_image_id)
-        all_feas = paddle.concat(feat_list, axis=0)
-        all_image_id = paddle.concat(img_id_list, axis=0)
-        if has_unique_id:
-            paddle.distributed.all_gather(unique_id_list, all_unique_id)
-            all_unique_id = paddle.concat(unique_id_list, axis=0)
+    all_feas = paddle.concat(batch_feas_list)
+    all_img_id = paddle.concat(img_id_list)
+    if has_unique_id:
+        all_unique_id = paddle.concat(unique_id_list)
+
+    # just for DistributedBatchSampler issue: repeat sampling
+    total_samples = len(
+        dataloader.dataset) if not engine.use_dali else dataloader.size
+    all_feas = all_feas[:total_samples]
+    all_img_id = all_img_id[:total_samples]
+    if has_unique_id:
+        all_unique_id = all_unique_id[:total_samples]
 
     logger.info("Build {} done, all feat shape: {}, begin to eval..".format(
         name, all_feas.shape))
-    return all_feas, all_image_id, all_unique_id
+    return all_feas, all_img_id, all_unique_id
-- 
GitLab
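
Note on why the tail truncation is sufficient: when len(dataset) is not
divisible by the number of ranks, DistributedBatchSampler pads the index
list by repeating leading indices so every rank receives the same number
of batches. Because the patch now calls all_gather batch by batch (with
ranks concatenated in order) instead of once per rank at the end, the
gathered features come out in the sampler's original sample order, which
leaves the padded duplicates at the tail. Below is a minimal pure-Python
sketch of that padding; the pad_indices helper and the sizes are
illustrative assumptions, not Paddle's actual implementation.

    # Illustrative sketch (not Paddle source): how a distributed batch
    # sampler pads the sample indices so the dataset splits evenly
    # across ranks.
    def pad_indices(num_samples, nranks):
        """Repeat leading indices until the count is divisible by nranks."""
        indices = list(range(num_samples))
        remainder = (-num_samples) % nranks
        indices += indices[:remainder]  # <- the repeated (duplicate) samples
        return indices

    indices = pad_indices(num_samples=10, nranks=4)
    print(indices)       # [0, 1, ..., 9, 0, 1] -> samples 0 and 1 repeat
    print(len(indices))  # 12, i.e. 3 samples per rank on 4 ranks

    # Because the per-batch all_gather keeps results in the sampler's
    # original order, the duplicates occupy the tail, and the patch's
    # slice removes exactly them:
    total_samples = 10
    assert indices[:total_samples] == list(range(10))

By contrast, the old code gathered each rank's full feature matrix once
after the loop, so the concatenated order was rank-major and any padded
duplicates would sit mid-tensor; a plain tail slice could not have
removed them, which is why the gather was moved inside the loop.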