diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index a80b54118f242b5d0f4679589bd8f50a13eb8e0f..6e7fc1a76fe8c3bc4402d9428d372b9c2b50a17b 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -73,68 +73,71 @@ def classification_eval(engine, epoch_id=0):
                     },
                     level=amp_level):
                 out = engine.model(batch[0])
-                # calc loss
-                if engine.eval_loss_func is not None:
-                    loss_dict = engine.eval_loss_func(out, batch[1])
-                    for key in loss_dict:
-                        if key not in output_info:
-                            output_info[key] = AverageMeter(key, '7.5f')
-                        output_info[key].update(loss_dict[key].numpy()[0],
-                                                batch_size)
         else:
             out = engine.model(batch[0])
-            # calc loss
-            if engine.eval_loss_func is not None:
-                loss_dict = engine.eval_loss_func(out, batch[1])
-                for key in loss_dict:
-                    if key not in output_info:
-                        output_info[key] = AverageMeter(key, '7.5f')
-                    output_info[key].update(loss_dict[key].numpy()[0],
-                                            batch_size)
 
         # just for DistributedBatchSampler issue: repeat sampling
         current_samples = batch_size * paddle.distributed.get_world_size()
         accum_samples += current_samples
 
-        # calc metric
-        if engine.eval_metric_func is not None:
-            if paddle.distributed.get_world_size() > 1:
-                label_list = []
-                paddle.distributed.all_gather(label_list, batch[1])
-                labels = paddle.concat(label_list, 0)
-
-                if isinstance(out, dict):
-                    if "Student" in out:
-                        out = out["Student"]
-                        if isinstance(out, dict):
-                            out = out["logits"]
-                    elif "logits" in out:
-                        out = out["logits"]
-                    else:
-                        msg = "Error: Wrong key in out!"
-                        raise Exception(msg)
-                if isinstance(out, list):
-                    pred = []
-                    for x in out:
-                        pred_list = []
-                        paddle.distributed.all_gather(pred_list, x)
-                        pred_x = paddle.concat(pred_list, 0)
-                        pred.append(pred_x)
-                else:
-                    pred_list = []
-                    paddle.distributed.all_gather(pred_list, out)
-                    pred = paddle.concat(pred_list, 0)
-
-                if accum_samples > total_samples and not engine.use_dali:
-                    pred = pred[:total_samples + current_samples -
-                                accum_samples]
-                    labels = labels[:total_samples + current_samples -
-                                    accum_samples]
-                    current_samples = total_samples + current_samples - accum_samples
-                metric_dict = engine.eval_metric_func(pred, labels)
-            else:
-                metric_dict = engine.eval_metric_func(out, batch[1])
-
+        # gather Tensor when distributed
+        if paddle.distributed.get_world_size() > 1:
+            label_list = []
+            paddle.distributed.all_gather(label_list, batch[1])
+            labels = paddle.concat(label_list, 0)
+
+            if isinstance(out, dict):
+                if "Student" in out:
+                    out = out["Student"]
+                    if isinstance(out, dict):
+                        out = out["logits"]
+                elif "logits" in out:
+                    out = out["logits"]
+                else:
+                    msg = "Error: Wrong key in out!"
+                    raise Exception(msg)
+            if isinstance(out, list):
+                preds = []
+                for x in out:
+                    pred_list = []
+                    paddle.distributed.all_gather(pred_list, x)
+                    pred_x = paddle.concat(pred_list, 0)
+                    preds.append(pred_x)
+            else:
+                pred_list = []
+                paddle.distributed.all_gather(pred_list, out)
+                preds = paddle.concat(pred_list, 0)
+
+            if accum_samples > total_samples and not engine.use_dali:
+                preds = preds[:total_samples + current_samples -
+                              accum_samples]
+                labels = labels[:total_samples + current_samples -
+                                accum_samples]
+                current_samples = total_samples + current_samples - accum_samples
+        else:
+            labels = batch[1]
+            preds = out
+
+        # calc loss
+        if engine.eval_loss_func is not None:
+            if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
+                amp_level = engine.config['AMP'].get("level", "O1").upper()
+                with paddle.amp.auto_cast(
+                        custom_black_list={
+                            "flatten_contiguous_range", "greater_than"
+                        },
+                        level=amp_level):
+                    loss_dict = engine.eval_loss_func(preds, labels)
+            else:
+                loss_dict = engine.eval_loss_func(preds, labels)
+            for key in loss_dict:
+                if key not in output_info:
+                    output_info[key] = AverageMeter(key, '7.5f')
+                output_info[key].update(loss_dict[key].numpy()[0],
+                                        current_samples)
+        # calc metric
+        if engine.eval_metric_func is not None:
+            metric_dict = engine.eval_metric_func(preds, labels)
         for key in metric_dict:
             if metric_key is None:
                 metric_key = key
diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py
index 8471a42c7421648b75a56e426bf3e1ab9c14a5fd..3dfe6337c39751acc56c49e2d8369a4116403594 100644
--- a/ppcls/engine/evaluation/retrieval.py
+++ b/ppcls/engine/evaluation/retrieval.py
@@ -89,9 +89,6 @@ def retrieval_eval(engine, epoch_id=0):
 
 
 def cal_feature(engine, name='gallery'):
-    all_feas = None
-    all_image_id = None
-    all_unique_id = None
     has_unique_id = False
 
     if name == 'gallery':
@@ -103,6 +100,9 @@ def cal_feature(engine, name='gallery'):
     else:
         raise RuntimeError("Only support gallery or query dataset")
 
+    batch_feas_list = []
+    img_id_list = []
+    unique_id_list = []
     max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len(
         dataloader)
     for idx, batch in enumerate(dataloader):  # load is very time-consuming
@@ -140,32 +140,39 @@ def cal_feature(engine, name='gallery'):
         if engine.config["Global"].get("feature_binarize") == "sign":
             batch_feas = paddle.sign(batch_feas).astype("float32")
 
-        if all_feas is None:
-            all_feas = batch_feas
+        if paddle.distributed.get_world_size() > 1:
+            batch_feas_gather = []
+            img_id_gather = []
+            unique_id_gather = []
+            paddle.distributed.all_gather(batch_feas_gather, batch_feas)
+            paddle.distributed.all_gather(img_id_gather, batch[1])
+            batch_feas_list.append(paddle.concat(batch_feas_gather))
+            img_id_list.append(paddle.concat(img_id_gather))
             if has_unique_id:
-                all_unique_id = batch[2]
-            all_image_id = batch[1]
+                paddle.distributed.all_gather(unique_id_gather, batch[2])
+                unique_id_list.append(paddle.concat(unique_id_gather))
         else:
-            all_feas = paddle.concat([all_feas, batch_feas])
-            all_image_id = paddle.concat([all_image_id, batch[1]])
+            batch_feas_list.append(batch_feas)
+            img_id_list.append(batch[1])
             if has_unique_id:
-                all_unique_id = paddle.concat([all_unique_id, batch[2]])
+                unique_id_list.append(batch[2])
 
     if engine.use_dali:
         dataloader.reset()
 
-    if paddle.distributed.get_world_size() > 1:
-        feat_list = []
-        img_id_list = []
-        unique_id_list = []
-        paddle.distributed.all_gather(feat_list, all_feas)
-        paddle.distributed.all_gather(img_id_list, all_image_id)
-        all_feas = paddle.concat(feat_list, axis=0)
-        all_image_id = paddle.concat(img_id_list, axis=0)
-        if has_unique_id:
-            paddle.distributed.all_gather(unique_id_list, all_unique_id)
-            all_unique_id = paddle.concat(unique_id_list, axis=0)
+    all_feas = paddle.concat(batch_feas_list)
+    all_img_id = paddle.concat(img_id_list)
+    if has_unique_id:
+        all_unique_id = paddle.concat(unique_id_list)
+
+    # just for DistributedBatchSampler issue: repeat sampling
+    total_samples = len(
+        dataloader.dataset) if not engine.use_dali else dataloader.size
+    all_feas = all_feas[:total_samples]
+    all_img_id = all_img_id[:total_samples]
+    if has_unique_id:
+        all_unique_id = all_unique_id[:total_samples]
 
     logger.info("Build {} done, all feat shape: {}, begin to eval..".format(
         name, all_feas.shape))
-    return all_feas, all_image_id, all_unique_id
+    return all_feas, all_img_id, all_unique_id
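
Note: both files now apply the same pattern: gather the per-rank tensors first, drop the tail that DistributedBatchSampler repeats when it pads the last batch, and only then compute loss, metrics, or the feature bank. A minimal, illustrative sketch of that pattern follows; the helper name gather_and_trim is hypothetical and is not code from this patch.

# Illustrative only: the gather-then-trim pattern used above, assuming a
# paddle.distributed setup and a DistributedBatchSampler-backed dataloader.
import paddle
import paddle.distributed as dist


def gather_and_trim(preds, labels, accum_samples, total_samples):
    # accum_samples: gathered sample count including the current batch
    # total_samples: len(dataloader.dataset)
    if dist.get_world_size() > 1:
        pred_list, label_list = [], []
        dist.all_gather(pred_list, preds)
        dist.all_gather(label_list, labels)
        preds = paddle.concat(pred_list, 0)
        labels = paddle.concat(label_list, 0)
    if accum_samples > total_samples:
        # drop the repeated samples that padded the last batch
        keep = preds.shape[0] - (accum_samples - total_samples)
        preds, labels = preds[:keep], labels[:keep]
    return preds, labels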