Unverified commit e8be106a, authored by cuicheng01, committed via GitHub

Merge pull request #1872 from TingquanGao/2.3/fix_dist_loss

[cherry-pick] fix metric calculation and loss calculation errors in distributed evaluation
...@@ -66,66 +66,71 @@ def classification_eval(engine, epoch_id=0): ...@@ -66,66 +66,71 @@ def classification_eval(engine, epoch_id=0):
}, },
level=amp_level): level=amp_level):
out = engine.model(batch[0]) out = engine.model(batch[0])
# calc loss
if engine.eval_loss_func is not None:
loss_dict = engine.eval_loss_func(out, batch[1])
for key in loss_dict:
if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0],
batch_size)
else: else:
out = engine.model(batch[0]) out = engine.model(batch[0])
# calc loss
if engine.eval_loss_func is not None:
loss_dict = engine.eval_loss_func(out, batch[1])
for key in loss_dict:
if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0],
batch_size)
# just for DistributedBatchSampler issue: repeat sampling # just for DistributedBatchSampler issue: repeat sampling
current_samples = batch_size * paddle.distributed.get_world_size() current_samples = batch_size * paddle.distributed.get_world_size()
accum_samples += current_samples accum_samples += current_samples
# calc metric # gather Tensor when distributed
if engine.eval_metric_func is not None: if paddle.distributed.get_world_size() > 1:
if paddle.distributed.get_world_size() > 1: label_list = []
label_list = [] paddle.distributed.all_gather(label_list, batch[1])
paddle.distributed.all_gather(label_list, batch[1]) labels = paddle.concat(label_list, 0)
labels = paddle.concat(label_list, 0)
if isinstance(out, dict):
if isinstance(out, dict): if "Student" in out:
if "Student" in out: out = out["Student"]
out = out["Student"] if isinstance(out, dict):
elif "logits" in out:
out = out["logits"] out = out["logits"]
else: elif "logits" in out:
msg = "Error: Wrong key in out!" out = out["logits"]
raise Exception(msg)
if isinstance(out, list):
pred = []
for x in out:
pred_list = []
paddle.distributed.all_gather(pred_list, x)
pred_x = paddle.concat(pred_list, 0)
pred.append(pred_x)
else: else:
msg = "Error: Wrong key in out!"
raise Exception(msg)
if isinstance(out, list):
preds = []
for x in out:
pred_list = [] pred_list = []
paddle.distributed.all_gather(pred_list, out) paddle.distributed.all_gather(pred_list, x)
pred = paddle.concat(pred_list, 0) pred_x = paddle.concat(pred_list, 0)
preds.append(pred_x)
else:
pred_list = []
paddle.distributed.all_gather(pred_list, out)
preds = paddle.concat(pred_list, 0)
if accum_samples > total_samples and not engine.use_dali: if accum_samples > total_samples and not engine.use_dali:
pred = pred[:total_samples + current_samples - preds = preds[:total_samples + current_samples - accum_samples]
labels = labels[:total_samples + current_samples -
accum_samples] accum_samples]
labels = labels[:total_samples + current_samples - current_samples = total_samples + current_samples - accum_samples
accum_samples] else:
current_samples = total_samples + current_samples - accum_samples labels = batch[1]
metric_dict = engine.eval_metric_func(pred, labels) preds = out
# calc loss
if engine.eval_loss_func is not None:
if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
amp_level = engine.config['AMP'].get("level", "O1").upper()
with paddle.amp.auto_cast(
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=amp_level):
loss_dict = engine.eval_loss_func(preds, labels)
else: else:
metric_dict = engine.eval_metric_func(out, batch[1]) loss_dict = engine.eval_loss_func(preds, labels)
for key in loss_dict:
if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0],
current_samples)
# calc metric
if engine.eval_metric_func is not None:
metric_dict = engine.eval_metric_func(preds, labels)
for key in metric_dict: for key in metric_dict:
if metric_key is None: if metric_key is None:
metric_key = key metric_key = key
......
...@@ -89,9 +89,6 @@ def retrieval_eval(engine, epoch_id=0): ...@@ -89,9 +89,6 @@ def retrieval_eval(engine, epoch_id=0):
def cal_feature(engine, name='gallery'): def cal_feature(engine, name='gallery'):
all_feas = None
all_image_id = None
all_unique_id = None
has_unique_id = False has_unique_id = False
if name == 'gallery': if name == 'gallery':
...@@ -103,6 +100,9 @@ def cal_feature(engine, name='gallery'): ...@@ -103,6 +100,9 @@ def cal_feature(engine, name='gallery'):
else: else:
raise RuntimeError("Only support gallery or query dataset") raise RuntimeError("Only support gallery or query dataset")
batch_feas_list = []
img_id_list = []
unique_id_list = []
max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len( max_iter = len(dataloader) - 1 if platform.system() == "Windows" else len(
dataloader) dataloader)
for idx, batch in enumerate(dataloader): # load is very time-consuming for idx, batch in enumerate(dataloader): # load is very time-consuming
...@@ -140,32 +140,39 @@ def cal_feature(engine, name='gallery'): ...@@ -140,32 +140,39 @@ def cal_feature(engine, name='gallery'):
if engine.config["Global"].get("feature_binarize") == "sign": if engine.config["Global"].get("feature_binarize") == "sign":
batch_feas = paddle.sign(batch_feas).astype("float32") batch_feas = paddle.sign(batch_feas).astype("float32")
if all_feas is None: if paddle.distributed.get_world_size() > 1:
all_feas = batch_feas batch_feas_gather = []
img_id_gather = []
unique_id_gather = []
paddle.distributed.all_gather(batch_feas_gather, batch_feas)
paddle.distributed.all_gather(img_id_gather, batch[1])
batch_feas_list.append(paddle.concat(batch_feas_gather))
img_id_list.append(paddle.concat(img_id_gather))
if has_unique_id: if has_unique_id:
all_unique_id = batch[2] paddle.distributed.all_gather(unique_id_gather, batch[2])
all_image_id = batch[1] unique_id_list.append(paddle.concat(unique_id_gather))
else: else:
all_feas = paddle.concat([all_feas, batch_feas]) batch_feas_list.append(batch_feas)
all_image_id = paddle.concat([all_image_id, batch[1]]) img_id_list.append(batch[1])
if has_unique_id: if has_unique_id:
all_unique_id = paddle.concat([all_unique_id, batch[2]]) unique_id_list.append(batch[2])
if engine.use_dali: if engine.use_dali:
dataloader.reset() dataloader.reset()
if paddle.distributed.get_world_size() > 1: all_feas = paddle.concat(batch_feas_list)
feat_list = [] all_img_id = paddle.concat(img_id_list)
img_id_list = [] if has_unique_id:
unique_id_list = [] all_unique_id = paddle.concat(unique_id_list)
paddle.distributed.all_gather(feat_list, all_feas)
paddle.distributed.all_gather(img_id_list, all_image_id) # just for DistributedBatchSampler issue: repeat sampling
all_feas = paddle.concat(feat_list, axis=0) total_samples = len(
all_image_id = paddle.concat(img_id_list, axis=0) dataloader.dataset) if not engine.use_dali else dataloader.size
if has_unique_id: all_feas = all_feas[:total_samples]
paddle.distributed.all_gather(unique_id_list, all_unique_id) all_img_id = all_img_id[:total_samples]
all_unique_id = paddle.concat(unique_id_list, axis=0) if has_unique_id:
all_unique_id = all_unique_id[:total_samples]
logger.info("Build {} done, all feat shape: {}, begin to eval..".format( logger.info("Build {} done, all feat shape: {}, begin to eval..".format(
name, all_feas.shape)) name, all_feas.shape))
return all_feas, all_image_id, all_unique_id return all_feas, all_img_id, all_unique_id
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册