未验证 提交 2e62e2e2 编写于 作者: L littletomatodonkey 提交者: GitHub

fix loss reduce from dict to list (#679)

* fix loss reduce from dict to list

* remove note
上级 42d2962d
......@@ -143,8 +143,8 @@ def create_metric(out,
out = out[1]
softmax_out = F.softmax(out)
fetchs = OrderedDict()
metric_names = set()
fetch_list = []
metric_names = []
if not multilabel:
softmax_out = F.softmax(out)
......@@ -154,12 +154,11 @@ def create_metric(out,
k = min(topk, classes_num)
topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
metric_names.add("top1")
metric_names.add("top{}".format(k))
metric_names.append("top1")
metric_names.append("top{}".format(k))
fetchs['top1'] = top1
topk_name = "top{}".format(k)
fetchs[topk_name] = topk
fetch_list.append(top1)
fetch_list.append(topk)
else:
out = F.sigmoid(out)
preds = multi_hot_encode(out.numpy())
......@@ -169,19 +168,22 @@ def create_metric(out,
ham_dist_name = "hamming_distance"
accuracy_name = "multilabel_accuracy"
metric_names.add(ham_dist_name)
metric_names.add(accuracy_name)
metric_names.append(ham_dist_name)
metric_names.append(accuracy_name)
fetchs[accuracy_name] = accuracy
fetchs[ham_dist_name] = ham_dist
fetch_list.append(accuracy)
fetch_list.append(ham_dist)
# multi cards' eval
if mode != "train" and paddle.distributed.get_world_size() > 1:
for metric_name in metric_names:
fetchs[metric_name] = paddle.distributed.all_reduce(
fetchs[metric_name], op=paddle.distributed.ReduceOp.
for idx, fetch in enumerate(fetch_list):
fetch_list[idx] = paddle.distributed.all_reduce(
fetch, op=paddle.distributed.ReduceOp.
SUM) / paddle.distributed.get_world_size()
fetchs = OrderedDict()
for idx, name in enumerate(metric_names):
fetchs[name] = fetch_list[idx]
return fetchs
......@@ -282,7 +284,8 @@ def create_feeds(batch, use_mix, num_classes, multilabel=False):
if not multilabel:
label = to_tensor(batch[1].numpy().astype("int64").reshape(-1, 1))
else:
label = to_tensor(batch[1].numpy().astype('float32').reshape(-1, num_classes))
label = to_tensor(batch[1].numpy().astype('float32').reshape(
-1, num_classes))
feeds = {"image": image, "label": label}
return feeds
......@@ -336,9 +339,11 @@ def run(dataloader,
0, ("top1", AverageMeter(
"top1", '.5f', postfix=",")))
else:
metric_list.insert(0, ("multilabel_accuracy", AverageMeter(
metric_list.insert(
0, ("multilabel_accuracy", AverageMeter(
"multilabel_accuracy", '.5f', postfix=",")))
metric_list.insert(0, ("hamming_distance", AverageMeter(
metric_list.insert(
0, ("hamming_distance", AverageMeter(
"hamming_distance", '.5f', postfix=",")))
metric_list = OrderedDict(metric_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册