未验证 提交 94b1b920 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix ml train (#1092)

* fix sklearn params

* fix sklearn params

* fix metric order
上级 bbdc1cdb
...@@ -25,7 +25,10 @@ from sklearn.preprocessing import binarize ...@@ -25,7 +25,10 @@ from sklearn.preprocessing import binarize
import numpy as np import numpy as np
__all__ = ["multi_hot_encode", "hamming_distance", "accuracy_score", "precision_recall_fscore", "mean_average_precision"] __all__ = [
"multi_hot_encode", "hamming_distance", "accuracy_score",
"precision_recall_fscore", "mean_average_precision"
]
def multi_hot_encode(logits, threshold=0.5): def multi_hot_encode(logits, threshold=0.5):
...@@ -33,7 +36,7 @@ def multi_hot_encode(logits, threshold=0.5): ...@@ -33,7 +36,7 @@ def multi_hot_encode(logits, threshold=0.5):
Encode logits to multi-hot by elementwise for multilabel Encode logits to multi-hot by elementwise for multilabel
""" """
return binarize(logits, threshold) return binarize(logits, threshold=threshold)
def hamming_distance(output, target): def hamming_distance(output, target):
...@@ -70,7 +73,8 @@ def accuracy_score(output, target, base="sample"): ...@@ -70,7 +73,8 @@ def accuracy_score(output, target, base="sample"):
tps = mcm[:, 1, 1] tps = mcm[:, 1, 1]
fps = mcm[:, 0, 1] fps = mcm[:, 0, 1]
accuracy = (sum(tps) + sum(tns)) / (sum(tps) + sum(tns) + sum(fns) + sum(fps)) accuracy = (sum(tps) + sum(tns)) / (
sum(tps) + sum(tns) + sum(fns) + sum(fps))
return accuracy return accuracy
...@@ -84,7 +88,8 @@ def precision_recall_fscore(output, target): ...@@ -84,7 +88,8 @@ def precision_recall_fscore(output, target):
fscores: fscores:
""" """
precisions, recalls, fscores, _ = precision_recall_fscore_support(target, output) precisions, recalls, fscores, _ = precision_recall_fscore_support(target,
output)
return precisions, recalls, fscores return precisions, recalls, fscores
......
...@@ -172,8 +172,8 @@ def create_metric(out, ...@@ -172,8 +172,8 @@ def create_metric(out,
metric_names.append(ham_dist_name) metric_names.append(ham_dist_name)
metric_names.append(accuracy_name) metric_names.append(accuracy_name)
fetch_list.append(accuracy)
fetch_list.append(ham_dist) fetch_list.append(ham_dist)
fetch_list.append(accuracy)
# multi cards' eval # multi cards' eval
if not use_xpu: if not use_xpu:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册