diff --git a/deploy/configs/cls_demo/person/inference_person_cls.yaml b/deploy/configs/cls_demo/person/inference_person_cls.yaml index cf7a75659dbfe5690fb903e91b6df775bc506814..9c5161c7c4cc81e77e8c562c9b22a8f7848cebd1 100644 --- a/deploy/configs/cls_demo/person/inference_person_cls.yaml +++ b/deploy/configs/cls_demo/person/inference_person_cls.yaml @@ -27,9 +27,10 @@ PreProcess: - ToCHWImage: PostProcess: - main_indicator: Topk - Topk: - topk: 5 - class_id_map_file: "../ppcls/utils/cls_demo/person_label_list.txt" + main_indicator: ThreshOutput + ThreshOutput: + threshold: 0.9 + label_0: invalid + label_1: valid SavePreLabel: save_dir: ./pre_label/ diff --git a/deploy/python/postprocess.py b/deploy/python/postprocess.py index d26cbaa9a8558ffb7f96115eef0a0bd9481fe47a..4f4d005fdff2bf17e04265e136443d0cd837f10e 100644 --- a/deploy/python/postprocess.py +++ b/deploy/python/postprocess.py @@ -53,6 +53,26 @@ class PostProcesser(object): return rtn +class ThreshOutput(object): + def __init__(self, threshold, label_0="0", label_1="1"): + self.threshold = threshold + self.label_0 = label_0 + self.label_1 = label_1 + + def __call__(self, x, file_names=None): + y = [] + for idx, probs in enumerate(x): + score = probs[1] + if score < self.threshold: + result = {"class_ids": [0], "scores": [1 - score], "label_names": [self.label_0]} + else: + result = {"class_ids": [1], "scores": [score], "label_names": [self.label_1]} + if file_names is not None: + result["file_name"] = file_names[idx] + y.append(result) + return y + + class Topk(object): def __init__(self, topk=1, class_id_map_file=None): assert isinstance(topk, (int, ))