diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index d7b5c47620bcd03b9ef8ddd44deeea0621ca041d..d0c9edbb7801b69e57db1f751e709773a4efc486 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -66,66 +66,70 @@ def classification_eval(engine, epoch_id=0): }, level=amp_level): out = engine.model(batch[0]) - # calc loss - if engine.eval_loss_func is not None: - loss_dict = engine.eval_loss_func(out, batch[1]) - for key in loss_dict: - if key not in output_info: - output_info[key] = AverageMeter(key, '7.5f') - output_info[key].update(loss_dict[key].numpy()[0], - batch_size) else: out = engine.model(batch[0]) - # calc loss - if engine.eval_loss_func is not None: - loss_dict = engine.eval_loss_func(out, batch[1]) - for key in loss_dict: - if key not in output_info: - output_info[key] = AverageMeter(key, '7.5f') - output_info[key].update(loss_dict[key].numpy()[0], - batch_size) # just for DistributedBatchSampler issue: repeat sampling current_samples = batch_size * paddle.distributed.get_world_size() accum_samples += current_samples - # calc metric - if engine.eval_metric_func is not None: - if paddle.distributed.get_world_size() > 1: - label_list = [] - paddle.distributed.all_gather(label_list, batch[1]) - labels = paddle.concat(label_list, 0) - - if isinstance(out, dict): - if "Student" in out: - out = out["Student"] - elif "logits" in out: + # gather Tensor when distributed + if paddle.distributed.get_world_size() > 1: + label_list = [] + paddle.distributed.all_gather(label_list, batch[1]) + labels = paddle.concat(label_list, 0) + + if isinstance(out, dict): + if "Student" in out: + out = out["Student"] + if isinstance(out, dict): out = out["logits"] - else: - msg = "Error: Wrong key in out!" - raise Exception(msg) - if isinstance(out, list): - pred = [] - for x in out: - pred_list = [] - paddle.distributed.all_gather(pred_list, x) - pred_x = paddle.concat(pred_list, 0) - pred.append(pred_x) + elif "logits" in out: + out = out["logits"] else: + msg = "Error: Wrong key in out!" + raise Exception(msg) + if isinstance(out, list): + preds = [] + for x in out: pred_list = [] - paddle.distributed.all_gather(pred_list, out) - pred = paddle.concat(pred_list, 0) + paddle.distributed.all_gather(pred_list, x) + pred_x = paddle.concat(pred_list, 0) + preds.append(pred_x) + else: + pred_list = [] + paddle.distributed.all_gather(pred_list, out) + preds = paddle.concat(pred_list, 0) - if accum_samples > total_samples and not engine.use_dali: - pred = pred[:total_samples + current_samples - + if accum_samples > total_samples and not engine.use_dali: + preds = preds[:total_samples + current_samples - accum_samples] + labels = labels[:total_samples + current_samples - accum_samples] - labels = labels[:total_samples + current_samples - - accum_samples] - current_samples = total_samples + current_samples - accum_samples - metric_dict = engine.eval_metric_func(pred, labels) + current_samples = total_samples + current_samples - accum_samples + else: + labels = batch[1] + preds = out + + # calc loss + if engine.eval_loss_func is not None: + if engine.amp and engine.config["AMP"].get("use_fp16_test", False): + amp_level = engine.config['AMP'].get("level", "O1").upper() + with paddle.amp.auto_cast( + custom_black_list={ + "flatten_contiguous_range", "greater_than" + }, + level=amp_level): + loss_dict = engine.eval_loss_func(preds, labels) else: - metric_dict = engine.eval_metric_func(out, batch[1]) + loss_dict = engine.eval_loss_func(preds, labels) + for key in loss_dict: + if key not in output_info: + output_info[key] = AverageMeter(key, '7.5f') + output_info[key].update(loss_dict[key].numpy()[0], batch_size) + # calc metric + if engine.eval_metric_func is not None: + metric_dict = engine.eval_metric_func(preds, labels) for key in metric_dict: if metric_key is None: metric_key = key