未验证 提交 2e62e2e2 编写于 作者: L littletomatodonkey 提交者: GitHub

fix loss reduce from dict to list (#679)

* fix loss reduce from dict to list

* remove note
上级 42d2962d
...@@ -143,8 +143,8 @@ def create_metric(out, ...@@ -143,8 +143,8 @@ def create_metric(out,
out = out[1] out = out[1]
softmax_out = F.softmax(out) softmax_out = F.softmax(out)
fetchs = OrderedDict() fetch_list = []
metric_names = set() metric_names = []
if not multilabel: if not multilabel:
softmax_out = F.softmax(out) softmax_out = F.softmax(out)
...@@ -154,12 +154,11 @@ def create_metric(out, ...@@ -154,12 +154,11 @@ def create_metric(out,
k = min(topk, classes_num) k = min(topk, classes_num)
topk = paddle.metric.accuracy(softmax_out, label=label, k=k) topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
metric_names.add("top1") metric_names.append("top1")
metric_names.add("top{}".format(k)) metric_names.append("top{}".format(k))
fetchs['top1'] = top1 fetch_list.append(top1)
topk_name = "top{}".format(k) fetch_list.append(topk)
fetchs[topk_name] = topk
else: else:
out = F.sigmoid(out) out = F.sigmoid(out)
preds = multi_hot_encode(out.numpy()) preds = multi_hot_encode(out.numpy())
...@@ -169,19 +168,22 @@ def create_metric(out, ...@@ -169,19 +168,22 @@ def create_metric(out,
ham_dist_name = "hamming_distance" ham_dist_name = "hamming_distance"
accuracy_name = "multilabel_accuracy" accuracy_name = "multilabel_accuracy"
metric_names.add(ham_dist_name) metric_names.append(ham_dist_name)
metric_names.add(accuracy_name) metric_names.append(accuracy_name)
fetchs[accuracy_name] = accuracy fetch_list.append(accuracy)
fetchs[ham_dist_name] = ham_dist fetch_list.append(ham_dist)
# multi cards' eval # multi cards' eval
if mode != "train" and paddle.distributed.get_world_size() > 1: if mode != "train" and paddle.distributed.get_world_size() > 1:
for metric_name in metric_names: for idx, fetch in enumerate(fetch_list):
fetchs[metric_name] = paddle.distributed.all_reduce( fetch_list[idx] = paddle.distributed.all_reduce(
fetchs[metric_name], op=paddle.distributed.ReduceOp. fetch, op=paddle.distributed.ReduceOp.
SUM) / paddle.distributed.get_world_size() SUM) / paddle.distributed.get_world_size()
fetchs = OrderedDict()
for idx, name in enumerate(metric_names):
fetchs[name] = fetch_list[idx]
return fetchs return fetchs
...@@ -282,7 +284,8 @@ def create_feeds(batch, use_mix, num_classes, multilabel=False): ...@@ -282,7 +284,8 @@ def create_feeds(batch, use_mix, num_classes, multilabel=False):
if not multilabel: if not multilabel:
label = to_tensor(batch[1].numpy().astype("int64").reshape(-1, 1)) label = to_tensor(batch[1].numpy().astype("int64").reshape(-1, 1))
else: else:
label = to_tensor(batch[1].numpy().astype('float32').reshape(-1, num_classes)) label = to_tensor(batch[1].numpy().astype('float32').reshape(
-1, num_classes))
feeds = {"image": image, "label": label} feeds = {"image": image, "label": label}
return feeds return feeds
...@@ -336,10 +339,12 @@ def run(dataloader, ...@@ -336,10 +339,12 @@ def run(dataloader,
0, ("top1", AverageMeter( 0, ("top1", AverageMeter(
"top1", '.5f', postfix=","))) "top1", '.5f', postfix=",")))
else: else:
metric_list.insert(0, ("multilabel_accuracy", AverageMeter( metric_list.insert(
"multilabel_accuracy", '.5f', postfix=","))) 0, ("multilabel_accuracy", AverageMeter(
metric_list.insert(0, ("hamming_distance", AverageMeter( "multilabel_accuracy", '.5f', postfix=",")))
"hamming_distance", '.5f', postfix=","))) metric_list.insert(
0, ("hamming_distance", AverageMeter(
"hamming_distance", '.5f', postfix=",")))
metric_list = OrderedDict(metric_list) metric_list = OrderedDict(metric_list)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册