diff --git a/deploy/configs/inference_attr.yaml b/deploy/configs/inference_attr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d11164ee65c9ba25ada2e3268ec87198d03461a2 --- /dev/null +++ b/deploy/configs/inference_attr.yaml @@ -0,0 +1,33 @@ +Global: + infer_imgs: "./images/Pedestrain_Attr.jpg" + inference_model_dir: "../inference/" + batch_size: 1 + use_gpu: True + enable_mkldnn: False + cpu_num_threads: 10 + enable_benchmark: True + use_fp16: False + ir_optim: True + use_tensorrt: False + gpu_mem: 8000 + enable_profile: False + +PreProcess: + transform_ops: + - ResizeImage: + size: [192, 256] + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + channel_num: 3 + - ToCHWImage: + +PostProcess: + main_indicator: Attribute + Attribute: + threshold: 0.5 + glasses_threshold: 0.3 + hold_threshold: 0.6 + \ No newline at end of file diff --git a/deploy/images/Pedestrain_Attr.jpg b/deploy/images/Pedestrain_Attr.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6a87e856af8c17a3b93617b93ea517b91c508619 Binary files /dev/null and b/deploy/images/Pedestrain_Attr.jpg differ diff --git a/deploy/python/postprocess.py b/deploy/python/postprocess.py index 4f4d005fdff2bf17e04265e136443d0cd837f10e..f58434ad1d4d1c227c861bfe7e28fcf623ef793c 100644 --- a/deploy/python/postprocess.py +++ b/deploy/python/postprocess.py @@ -179,3 +179,90 @@ class Binarize(object): byte[:, i:i + 1] = np.dot(x[:, i * 8:(i + 1) * 8], self.unit) return byte + + +class Attribute(object): + def __init__(self, + threshold=0.5, + glasses_threshold=0.3, + hold_threshold=0.6): + self.threshold = threshold + self.glasses_threshold = glasses_threshold + self.hold_threshold = hold_threshold + + def __call__(self, batch_preds, file_names=None): + # postprocess output of predictor + age_list = ['AgeLess18', 'Age18-60', 'AgeOver60'] + direct_list = ['Front', 'Side', 'Back'] + bag_list = ['HandBag', 'ShoulderBag', 'Backpack'] + upper_list = ['UpperStride', 'UpperLogo', 'UpperPlaid', 'UpperSplice'] + lower_list = [ + 'LowerStripe', 'LowerPattern', 'LongCoat', 'Trousers', 'Shorts', + 'Skirt&Dress' + ] + batch_res = [] + for res in batch_preds: + res = res.tolist() + label_res = [] + # gender + gender = 'Female' if res[22] > self.threshold else 'Male' + label_res.append(gender) + # age + age = age_list[np.argmax(res[19:22])] + label_res.append(age) + # direction + direction = direct_list[np.argmax(res[23:])] + label_res.append(direction) + # glasses + glasses = 'Glasses: ' + if res[1] > self.glasses_threshold: + glasses += 'True' + else: + glasses += 'False' + label_res.append(glasses) + # hat + hat = 'Hat: ' + if res[0] > self.threshold: + hat += 'True' + else: + hat += 'False' + label_res.append(hat) + # hold obj + hold_obj = 'HoldObjectsInFront: ' + if res[18] > self.hold_threshold: + hold_obj += 'True' + else: + hold_obj += 'False' + label_res.append(hold_obj) + # bag + bag = bag_list[np.argmax(res[15:18])] + bag_score = res[15 + np.argmax(res[15:18])] + bag_label = bag if bag_score > self.threshold else 'No bag' + label_res.append(bag_label) + # upper + upper_res = res[4:8] + upper_label = 'Upper:' + sleeve = 'LongSleeve' if res[3] > res[2] else 'ShortSleeve' + upper_label += ' {}'.format(sleeve) + for i, r in enumerate(upper_res): + if r > self.threshold: + upper_label += ' {}'.format(upper_list[i]) + label_res.append(upper_label) + # lower + lower_res = res[8:14] + lower_label = 'Lower: ' + has_lower = False + for i, l in enumerate(lower_res): + if l > self.threshold: + lower_label += ' {}'.format(lower_list[i]) + has_lower = True + if not has_lower: + lower_label += ' {}'.format(lower_list[np.argmax(lower_res)]) + + label_res.append(lower_label) + # shoe + shoe = 'Boots' if res[14] > self.threshold else 'No boots' + label_res.append(shoe) + + batch_res.append(label_res) + return batch_res diff --git a/deploy/python/predict_cls.py b/deploy/python/predict_cls.py index 64c07ea875eaa2c456393328183b7270080a64d1..d7da15eae804cf2939da9514a11739e13c6db858 100644 --- a/deploy/python/predict_cls.py +++ b/deploy/python/predict_cls.py @@ -138,13 +138,20 @@ def main(config): continue batch_results = cls_predictor.predict(batch_imgs) for number, result_dict in enumerate(batch_results): - filename = batch_names[number] - clas_ids = result_dict["class_ids"] - scores_str = "[{}]".format(", ".join("{:.2f}".format( - r) for r in result_dict["scores"])) - label_names = result_dict["label_names"] - print("{}:\tclass id(s): {}, score(s): {}, label_name(s): {}". - format(filename, clas_ids, scores_str, label_names)) + if "Attribute" in config["PostProcess"]: + filename = batch_names[number] + attr_message = result_dict + print("{}:\tclass id(s): {}".format(filename, + attr_message)) + else: + filename = batch_names[number] + clas_ids = result_dict["class_ids"] + scores_str = "[{}]".format(", ".join("{:.2f}".format( + r) for r in result_dict["scores"])) + label_names = result_dict["label_names"] + print( + "{}:\tclass id(s): {}, score(s): {}, label_name(s): {}". + format(filename, clas_ids, scores_str, label_names)) batch_imgs = [] batch_names = [] if cls_predictor.benchmark: