Unverified commit 94b1b920, authored by littletomatodonkey, committed by GitHub

fix ml train (#1092)

* fix sklearn params

* fix sklearn params

* fix metric order
Parent bbdc1cdb
@@ -25,7 +25,10 @@ from sklearn.preprocessing import binarize
 import numpy as np
-__all__ = ["multi_hot_encode", "hamming_distance", "accuracy_score", "precision_recall_fscore", "mean_average_precision"]
+__all__ = [
+    "multi_hot_encode", "hamming_distance", "accuracy_score",
+    "precision_recall_fscore", "mean_average_precision"
+]
 def multi_hot_encode(logits, threshold=0.5):
@@ -33,7 +36,7 @@ def multi_hot_encode(logits, threshold=0.5):
     Encode logits to multi-hot by elementwise for multilabel
     """
-    return binarize(logits, threshold)
+    return binarize(logits, threshold=threshold)
 def hamming_distance(output, target):
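Note on the binarize change above: in recent scikit-learn releases the threshold argument of sklearn.preprocessing.binarize is keyword-only, so passing it positionally is deprecated or rejected outright, which is what the "fix sklearn params" commits address. A minimal sketch of the fixed call, using made-up logits:

import numpy as np
from sklearn.preprocessing import binarize

# Made-up sigmoid scores for 2 samples x 3 labels.
logits = np.array([[0.1, 0.7, 0.4],
                   [0.9, 0.2, 0.6]])

# threshold must be passed by keyword on newer scikit-learn;
# the positional form binarize(logits, 0.5) is not accepted there.
multi_hot = binarize(logits, threshold=0.5)
print(multi_hot)
# [[0. 1. 0.]
#  [1. 0. 1.]]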
@@ -70,7 +73,8 @@ def accuracy_score(output, target, base="sample"):
         tps = mcm[:, 1, 1]
         fps = mcm[:, 0, 1]
-        accuracy = (sum(tps) + sum(tns)) / (sum(tps) + sum(tns) + sum(fns) + sum(fps))
+        accuracy = (sum(tps) + sum(tns)) / (
+            sum(tps) + sum(tns) + sum(fns) + sum(fps))
     return accuracy
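The reflowed expression above computes label-based accuracy from the multilabel confusion matrix. For reference, a self-contained sketch of the same computation with toy inputs (the tns/fns rows are defined just above the context shown in this hunk):

import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

target = np.array([[1, 0, 1],
                   [0, 1, 0]])
output = np.array([[1, 0, 0],
                   [0, 1, 1]])

# mcm[i] is the 2x2 matrix [[tn, fp], [fn, tp]] for label i.
mcm = multilabel_confusion_matrix(target, output)
tns = mcm[:, 0, 0]
fns = mcm[:, 1, 0]
tps = mcm[:, 1, 1]
fps = mcm[:, 0, 1]

accuracy = (sum(tps) + sum(tns)) / (
    sum(tps) + sum(tns) + sum(fns) + sum(fps))
print(accuracy)  # 4 / 6 ≈ 0.667 for these toy inputs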
@@ -84,7 +88,8 @@ def precision_recall_fscore(output, target):
        fscores:
    """
-    precisions, recalls, fscores, _ = precision_recall_fscore_support(target, output)
+    precisions, recalls, fscores, _ = precision_recall_fscore_support(target,
+                                                                      output)
     return precisions, recalls, fscores
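The wrap above is purely cosmetic, but the argument order is worth noting: scikit-learn's precision_recall_fscore_support takes the ground truth first and the predictions second, and without an average argument it returns per-label arrays. A small usage sketch with toy inputs:

import numpy as np
from sklearn.metrics import precision_recall_fscore_support

target = np.array([[1, 0, 1],
                   [0, 1, 0]])
output = np.array([[1, 0, 0],
                   [0, 1, 1]])

# y_true comes first, y_pred second; each returned array holds one
# value per label, plus the per-label support counts.
precisions, recalls, fscores, support = precision_recall_fscore_support(
    target, output)
print(precisions, recalls, fscores, support)
# [1. 1. 0.] [1. 1. 0.] [1. 1. 0.] [1 1 1]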
@@ -172,8 +172,8 @@ def create_metric(out,
         metric_names.append(ham_dist_name)
         metric_names.append(accuracy_name)
-        fetch_list.append(accuracy)
         fetch_list.append(ham_dist)
+        fetch_list.append(accuracy)
     # multi cards' eval
     if not use_xpu:
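On the "fix metric order" part: the hunk above appends to fetch_list in the same order as the metric_names appends (ham_dist first, then accuracy), which matters if the two lists are later walked in parallel, as the matching names suggest, e.g. when logging results. A hypothetical illustration of the failure mode the swap avoids:

# Hypothetical illustration of why append order matters when two
# parallel lists are later zipped together for reporting.
metric_names = ["hamming_distance", "accuracy"]

ham_dist, accuracy = 0.12, 0.85  # dummy values

# Wrong order: each value ends up under the other metric's name.
fetch_list_bad = [accuracy, ham_dist]
print(dict(zip(metric_names, fetch_list_bad)))
# {'hamming_distance': 0.85, 'accuracy': 0.12}

# Fixed order, matching metric_names:
fetch_list = [ham_dist, accuracy]
print(dict(zip(metric_names, fetch_list)))
# {'hamming_distance': 0.12, 'accuracy': 0.85}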