From 94b1b9203d9dda642db8265e3271989c882b1c8e Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Thu, 29 Jul 2021 13:38:29 +0800 Subject: [PATCH] fix ml train (#1092) * fix sklearn params * fix sklearn params * fix metric order --- ppcls/utils/metrics.py | 13 +++++++++---- tools/program.py | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/ppcls/utils/metrics.py b/ppcls/utils/metrics.py index 724cae2d..bb449180 100644 --- a/ppcls/utils/metrics.py +++ b/ppcls/utils/metrics.py @@ -25,7 +25,10 @@ from sklearn.preprocessing import binarize import numpy as np -__all__ = ["multi_hot_encode", "hamming_distance", "accuracy_score", "precision_recall_fscore", "mean_average_precision"] +__all__ = [ + "multi_hot_encode", "hamming_distance", "accuracy_score", + "precision_recall_fscore", "mean_average_precision" +] def multi_hot_encode(logits, threshold=0.5): @@ -33,7 +36,7 @@ def multi_hot_encode(logits, threshold=0.5): Encode logits to multi-hot by elementwise for multilabel """ - return binarize(logits, threshold) + return binarize(logits, threshold=threshold) def hamming_distance(output, target): @@ -70,7 +73,8 @@ def accuracy_score(output, target, base="sample"): tps = mcm[:, 1, 1] fps = mcm[:, 0, 1] - accuracy = (sum(tps) + sum(tns)) / (sum(tps) + sum(tns) + sum(fns) + sum(fps)) + accuracy = (sum(tps) + sum(tns)) / ( + sum(tps) + sum(tns) + sum(fns) + sum(fps)) return accuracy @@ -84,7 +88,8 @@ def precision_recall_fscore(output, target): fscores: """ - precisions, recalls, fscores, _ = precision_recall_fscore_support(target, output) + precisions, recalls, fscores, _ = precision_recall_fscore_support(target, + output) return precisions, recalls, fscores diff --git a/tools/program.py b/tools/program.py index 4666d962..652aec3c 100644 --- a/tools/program.py +++ b/tools/program.py @@ -172,8 +172,8 @@ def create_metric(out, metric_names.append(ham_dist_name) metric_names.append(accuracy_name) - fetch_list.append(accuracy) fetch_list.append(ham_dist) + fetch_list.append(accuracy) # multi cards' eval if not use_xpu: -- GitLab