From 2e62e2e25e3a1496611c6f09e92c979d71113939 Mon Sep 17 00:00:00 2001 From: littletomatodonkey <2120160898@bit.edu.cn> Date: Tue, 13 Apr 2021 23:19:36 +0800 Subject: [PATCH] fix loss reduce from dict to list (#679) * fix loss reduce from dict to list * remove note --- tools/program.py | 43 ++++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/tools/program.py b/tools/program.py index 7006c9c3..fd155ab8 100644 --- a/tools/program.py +++ b/tools/program.py @@ -143,8 +143,8 @@ def create_metric(out, out = out[1] softmax_out = F.softmax(out) - fetchs = OrderedDict() - metric_names = set() + fetch_list = [] + metric_names = [] if not multilabel: softmax_out = F.softmax(out) @@ -154,12 +154,11 @@ def create_metric(out, k = min(topk, classes_num) topk = paddle.metric.accuracy(softmax_out, label=label, k=k) - metric_names.add("top1") - metric_names.add("top{}".format(k)) + metric_names.append("top1") + metric_names.append("top{}".format(k)) - fetchs['top1'] = top1 - topk_name = "top{}".format(k) - fetchs[topk_name] = topk + fetch_list.append(top1) + fetch_list.append(topk) else: out = F.sigmoid(out) preds = multi_hot_encode(out.numpy()) @@ -169,19 +168,22 @@ def create_metric(out, ham_dist_name = "hamming_distance" accuracy_name = "multilabel_accuracy" - metric_names.add(ham_dist_name) - metric_names.add(accuracy_name) + metric_names.append(ham_dist_name) + metric_names.append(accuracy_name) - fetchs[accuracy_name] = accuracy - fetchs[ham_dist_name] = ham_dist + fetch_list.append(accuracy) + fetch_list.append(ham_dist) # multi cards' eval if mode != "train" and paddle.distributed.get_world_size() > 1: - for metric_name in metric_names: - fetchs[metric_name] = paddle.distributed.all_reduce( - fetchs[metric_name], op=paddle.distributed.ReduceOp. + for idx, fetch in enumerate(fetch_list): + fetch_list[idx] = paddle.distributed.all_reduce( + fetch, op=paddle.distributed.ReduceOp. SUM) / paddle.distributed.get_world_size() + fetchs = OrderedDict() + for idx, name in enumerate(metric_names): + fetchs[name] = fetch_list[idx] return fetchs @@ -282,7 +284,8 @@ def create_feeds(batch, use_mix, num_classes, multilabel=False): if not multilabel: label = to_tensor(batch[1].numpy().astype("int64").reshape(-1, 1)) else: - label = to_tensor(batch[1].numpy().astype('float32').reshape(-1, num_classes)) + label = to_tensor(batch[1].numpy().astype('float32').reshape( + -1, num_classes)) feeds = {"image": image, "label": label} return feeds @@ -336,10 +339,12 @@ def run(dataloader, 0, ("top1", AverageMeter( "top1", '.5f', postfix=","))) else: - metric_list.insert(0, ("multilabel_accuracy", AverageMeter( - "multilabel_accuracy", '.5f', postfix=","))) - metric_list.insert(0, ("hamming_distance", AverageMeter( - "hamming_distance", '.5f', postfix=","))) + metric_list.insert( + 0, ("multilabel_accuracy", AverageMeter( + "multilabel_accuracy", '.5f', postfix=","))) + metric_list.insert( + 0, ("hamming_distance", AverageMeter( + "hamming_distance", '.5f', postfix=","))) metric_list = OrderedDict(metric_list) -- GitLab