From 939a35d6056d8caa20a0845aa1b004a9196980d8 Mon Sep 17 00:00:00 2001
From: zhiboniu
Date: Thu, 26 May 2022 07:14:10 +0000
Subject: [PATCH] add more details

---
 deploy/configs/inference_attr.yaml |  6 +++---
 deploy/python/postprocess.py       | 20 +++++++++++++++++---
 deploy/python/predict_cls.py       |  7 ++++---
 ppcls/metric/metrics.py            |  1 +
 4 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/deploy/configs/inference_attr.yaml b/deploy/configs/inference_attr.yaml
index d11164ee..b49e2af6 100644
--- a/deploy/configs/inference_attr.yaml
+++ b/deploy/configs/inference_attr.yaml
@@ -27,7 +27,7 @@ PreProcess:
 
 PostProcess:
   main_indicator: Attribute
   Attribute:
-    threshold: 0.5
-    glasses_threshold: 0.3
-    hold_threshold: 0.6
+    threshold: 0.5 #default threshold
+    glasses_threshold: 0.3 #threshold only for glasses
+    hold_threshold: 0.6 #threshold only for hold
\ No newline at end of file
diff --git a/deploy/python/postprocess.py b/deploy/python/postprocess.py
index f58434ad..1107b805 100644
--- a/deploy/python/postprocess.py
+++ b/deploy/python/postprocess.py
@@ -64,9 +64,17 @@ class ThreshOutput(object):
         for idx, probs in enumerate(x):
             score = probs[1]
             if score < self.threshold:
-                result = {"class_ids": [0], "scores": [1 - score], "label_names": [self.label_0]}
+                result = {
+                    "class_ids": [0],
+                    "scores": [1 - score],
+                    "label_names": [self.label_0]
+                }
             else:
-                result = {"class_ids": [1], "scores": [score], "label_names": [self.label_1]}
+                result = {
+                    "class_ids": [1],
+                    "scores": [score],
+                    "label_names": [self.label_1]
+                }
             if file_names is not None:
                 result["file_name"] = file_names[idx]
             y.append(result)
@@ -264,5 +272,11 @@ class Attribute(object):
             shoe = 'Boots' if res[14] > self.threshold else 'No boots'
             label_res.append(shoe)
 
-            batch_res.append(label_res)
+            threshold_list = [0.5] * len(res)
+            threshold_list[1] = self.glasses_threshold
+            threshold_list[18] = self.hold_threshold
+            pred_res = (np.array(res) > np.array(threshold_list)
+                        ).astype(np.int8).tolist()
+
+            batch_res.append([label_res, pred_res])
         return batch_res
diff --git a/deploy/python/predict_cls.py b/deploy/python/predict_cls.py
index d7da15ea..41b46090 100644
--- a/deploy/python/predict_cls.py
+++ b/deploy/python/predict_cls.py
@@ -140,9 +140,10 @@ def main(config):
             for number, result_dict in enumerate(batch_results):
                 if "Attribute" in config["PostProcess"]:
                     filename = batch_names[number]
-                    attr_message = result_dict
-                    print("{}:\tclass id(s): {}".format(filename,
-                                                        attr_message))
+                    attr_message = result_dict[0]
+                    pred_res = result_dict[1]
+                    print("{}:\t attributes: {}, \npredict output: {}".format(
+                        filename, attr_message, pred_res))
                 else:
                     filename = batch_names[number]
                     clas_ids = result_dict["class_ids"]
diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py
index fb087db1..1130fd49 100644
--- a/ppcls/metric/metrics.py
+++ b/ppcls/metric/metrics.py
@@ -391,6 +391,7 @@ class AccuracyScore(MultiLabelMetric):
 def get_attr_metrics(gt_label, preds_probs, threshold):
     """
     index: evaluated label index
+    adapted from "https://github.com/valencebond/Rethinking_of_PAR/blob/master/metrics/pedestrian_metrics.py"
     """
     pred_label = (preds_probs > threshold).astype(int)
 
--
GitLab
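
Note on the new Attribute postprocess output: each sample now yields [label_res, pred_res], where pred_res binarizes the raw attribute scores against a per-index threshold list (index 1 uses glasses_threshold, index 18 uses hold_threshold, everything else the 0.5 default). Below is a minimal standalone sketch of just that thresholding step; the helper name attribute_pred and the 26-dim random scores are illustrative assumptions, not part of the patch.

import numpy as np

def attribute_pred(res, glasses_threshold=0.3, hold_threshold=0.6):
    # Default threshold for every attribute score, then override the two
    # special indices, mirroring the new logic in Attribute.__call__.
    threshold_list = [0.5] * len(res)
    threshold_list[1] = glasses_threshold   # glasses score
    threshold_list[18] = hold_threshold     # hold score
    # Binarize: 1 if the score exceeds its threshold, else 0.
    return (np.array(res) > np.array(threshold_list)).astype(np.int8).tolist()

# Illustrative only: random scores standing in for one model output vector.
scores = np.random.rand(26).tolist()
print(attribute_pred(scores))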