diff --git a/deploy/configs/inference_attr.yaml b/deploy/configs/inference_attr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b49e2af6482e72e01716faceefb8676d87c08347 --- /dev/null +++ b/deploy/configs/inference_attr.yaml @@ -0,0 +1,33 @@ +Global: + infer_imgs: "./images/Pedestrain_Attr.jpg" + inference_model_dir: "../inference/" + batch_size: 1 + use_gpu: True + enable_mkldnn: False + cpu_num_threads: 10 + enable_benchmark: True + use_fp16: False + ir_optim: True + use_tensorrt: False + gpu_mem: 8000 + enable_profile: False + +PreProcess: + transform_ops: + - ResizeImage: + size: [192, 256] + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + channel_num: 3 + - ToCHWImage: + +PostProcess: + main_indicator: Attribute + Attribute: + threshold: 0.5 #default threshold + glasses_threshold: 0.3 #threshold only for glasses + hold_threshold: 0.6 #threshold only for hold + \ No newline at end of file diff --git a/deploy/images/Pedestrain_Attr.jpg b/deploy/images/Pedestrain_Attr.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6a87e856af8c17a3b93617b93ea517b91c508619 Binary files /dev/null and b/deploy/images/Pedestrain_Attr.jpg differ diff --git a/deploy/python/postprocess.py b/deploy/python/postprocess.py index 4f4d005fdff2bf17e04265e136443d0cd837f10e..1107b805085531de74ca1c34d25c98a5d226d531 100644 --- a/deploy/python/postprocess.py +++ b/deploy/python/postprocess.py @@ -64,9 +64,17 @@ class ThreshOutput(object): for idx, probs in enumerate(x): score = probs[1] if score < self.threshold: - result = {"class_ids": [0], "scores": [1 - score], "label_names": [self.label_0]} + result = { + "class_ids": [0], + "scores": [1 - score], + "label_names": [self.label_0] + } else: - result = {"class_ids": [1], "scores": [score], "label_names": [self.label_1]} + result = { + "class_ids": [1], + "scores": [score], + "label_names": [self.label_1] + } if file_names is not None: 
class Attribute(object):
    """Turn per-image attribute scores into readable labels and a 0/1 vector.

    Each prediction is indexed as follows (positions are taken directly from
    the lookups in ``__call__``):

    * 0 hat, 1 glasses, 2 short sleeve, 3 long sleeve
    * 4-7 upper-body styles, 8-13 lower-body styles, 14 boots
    * 15-17 bag type, 18 holding objects in front
    * 19-21 age group, 22 gender (female when above threshold)
    * 23 onwards: viewing direction

    NOTE(review): this layout implies 26 scores per image; shorter vectors
    raise ``IndexError``. Confirm against the exported model's label order.

    Args:
        threshold: default cut-off applied to every attribute that has no
            dedicated threshold.
        glasses_threshold: cut-off used only for the glasses score (index 1).
        hold_threshold: cut-off used only for the hold-object score
            (index 18).
    """

    # Label vocabularies. Kept at class level so they are built once rather
    # than on every call. Spellings are runtime output and must not change.
    _AGE_LIST = ['AgeLess18', 'Age18-60', 'AgeOver60']
    _DIRECT_LIST = ['Front', 'Side', 'Back']
    _BAG_LIST = ['HandBag', 'ShoulderBag', 'Backpack']
    _UPPER_LIST = ['UpperStride', 'UpperLogo', 'UpperPlaid', 'UpperSplice']
    _LOWER_LIST = [
        'LowerStripe', 'LowerPattern', 'LongCoat', 'Trousers', 'Shorts',
        'Skirt&Dress'
    ]

    def __init__(self,
                 threshold=0.5,
                 glasses_threshold=0.3,
                 hold_threshold=0.6):
        self.threshold = threshold
        self.glasses_threshold = glasses_threshold
        self.hold_threshold = hold_threshold

    def __call__(self, batch_preds, file_names=None):
        """Post-process a batch of predictor outputs.

        Args:
            batch_preds: iterable of per-image score vectors (numpy arrays or
                plain sequences of floats).
            file_names: accepted for interface compatibility with the other
                post-processors; not used here.

        Returns:
            A list with one ``[label_res, pred_res]`` pair per image, where
            ``label_res`` is a list of ten label strings and ``pred_res`` is
            a list of 0/1 ints (one per score) obtained by thresholding.
        """
        batch_res = []
        for res in batch_preds:
            # np.asarray lets plain lists through as well as numpy arrays.
            res = np.asarray(res).tolist()
            label_res = []

            # gender
            label_res.append('Female' if res[22] > self.threshold else 'Male')
            # age: the highest-scoring of the three age groups wins
            label_res.append(self._AGE_LIST[int(np.argmax(res[19:22]))])
            # direction
            label_res.append(self._DIRECT_LIST[int(np.argmax(res[23:]))])
            # glasses (dedicated threshold)
            label_res.append('Glasses: ' + (
                'True' if res[1] > self.glasses_threshold else 'False'))
            # hat
            label_res.append('Hat: ' + (
                'True' if res[0] > self.threshold else 'False'))
            # hold obj (dedicated threshold)
            label_res.append('HoldObjectsInFront: ' + (
                'True' if res[18] > self.hold_threshold else 'False'))

            # bag: best bag type, reported only if it clears the threshold
            bag_idx = int(np.argmax(res[15:18]))
            bag_score = res[15 + bag_idx]
            label_res.append(self._BAG_LIST[bag_idx]
                             if bag_score > self.threshold else 'No bag')

            # upper: sleeve length is always reported, styles only when
            # they clear the threshold
            sleeve = 'LongSleeve' if res[3] > res[2] else 'ShortSleeve'
            upper_label = 'Upper:' + ' {}'.format(sleeve)
            for name, score in zip(self._UPPER_LIST, res[4:8]):
                if score > self.threshold:
                    upper_label += ' {}'.format(name)
            label_res.append(upper_label)

            # lower: every style above threshold, or the single best one
            # when none qualifies. The prefix keeps its trailing space, so
            # the output has two spaces after the colon, as before.
            lower_res = res[8:14]
            lower_names = [
                name for name, score in zip(self._LOWER_LIST, lower_res)
                if score > self.threshold
            ]
            if not lower_names:
                lower_names = [self._LOWER_LIST[int(np.argmax(lower_res))]]
            lower_label = 'Lower: ' + ''.join(
                ' {}'.format(name) for name in lower_names)
            label_res.append(lower_label)

            # shoe
            label_res.append('Boots'
                             if res[14] > self.threshold else 'No boots')

            # Binary vector: use the configured default threshold (not a
            # hard-coded 0.5) so it always agrees with the labels above.
            threshold_list = [self.threshold] * len(res)
            threshold_list[1] = self.glasses_threshold
            threshold_list[18] = self.hold_threshold
            pred_res = (np.array(res) > np.array(threshold_list)
                        ).astype(np.int8).tolist()

            batch_res.append([label_res, pred_res])
        return batch_res
- format(filename, clas_ids, scores_str, label_names)) + if "Attribute" in config["PostProcess"]: + filename = batch_names[number] + attr_message = result_dict[0] + pred_res = result_dict[1] + print("{}:\t attributes: {}, \npredict output: {}".format( + filename, attr_message, pred_res)) + else: + filename = batch_names[number] + clas_ids = result_dict["class_ids"] + scores_str = "[{}]".format(", ".join("{:.2f}".format( + r) for r in result_dict["scores"])) + label_names = result_dict["label_names"] + print( + "{}:\tclass id(s): {}, score(s): {}, label_name(s): {}". + format(filename, clas_ids, scores_str, label_names)) batch_imgs = [] batch_names = [] if cls_predictor.benchmark: diff --git a/ppcls/arch/backbone/legendary_models/resnet.py b/ppcls/arch/backbone/legendary_models/resnet.py index ca75c2eaa4f2d7f4a604a312ed591c10811105c4..4a3d40f37fb3ed0008777643469841ef3ac38b80 100644 --- a/ppcls/arch/backbone/legendary_models/resnet.py +++ b/ppcls/arch/backbone/legendary_models/resnet.py @@ -137,8 +137,11 @@ class ConvBNLayer(TheseusLayer): weight_attr = ParamAttr(learning_rate=lr_mult, trainable=True) bias_attr = ParamAttr(learning_rate=lr_mult, trainable=True) - self.bn = BatchNorm2D( - num_filters, weight_attr=weight_attr, bias_attr=bias_attr) + self.bn = BatchNorm( + num_filters, + param_attr=ParamAttr(learning_rate=lr_mult), + bias_attr=ParamAttr(learning_rate=lr_mult), + data_layout=data_format) self.relu = nn.ReLU() def forward(self, x): @@ -287,7 +290,8 @@ class ResNet(TheseusLayer): data_format="NCHW", input_image_channel=3, return_patterns=None, - return_stages=None): + return_stages=None, + **kargs): super().__init__() self.cfg = config diff --git a/ppcls/configs/Attr/StrongBaselineAttr.yaml b/ppcls/configs/Attr/StrongBaselineAttr.yaml index 7501669bc5707fa2577c7d0b573a3b23cd2a0213..2324015d667a09a56570677713792b16f1b2ed03 100644 --- a/ppcls/configs/Attr/StrongBaselineAttr.yaml +++ b/ppcls/configs/Attr/StrongBaselineAttr.yaml @@ -20,6 +20,7 @@ Arch: name: 
"ResNet50" pretrained: True class_num: 26 + infer_add_softmax: False # loss function config for traing/eval process Loss: @@ -110,5 +111,3 @@ DataLoader: Metric: Eval: - ATTRMetric: - - diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 60a77c95ece010bc6c146d8665b83a7c01124679..2c0ab83f4d4a875901b6655e9ccf91af1737cc73 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -458,7 +458,9 @@ class Engine(object): def export(self): assert self.mode == "export" - use_multilabel = self.config["Global"].get("use_multilabel", False) + use_multilabel = self.config["Global"].get( + "use_multilabel", + False) and not "ATTRMetric" in self.config["Metric"]["Eval"][0] model = ExportModel(self.config["Arch"], self.model, use_multilabel) if self.config["Global"]["pretrained_model"] is not None: load_dygraph_pretrain(model.base_model, diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 4087cd4d4fd4eca0830d0ce253082dbbbbf16ec0..2161ca86ae51c1c1aa551dd08c1924adc3d9c59b 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -390,6 +390,7 @@ class AccuracyScore(MultiLabelMetric): def get_attr_metrics(gt_label, preds_probs, threshold): """ index: evaluated label index + adapted from "https://github.com/valencebond/Rethinking_of_PAR/blob/master/metrics/pedestrian_metrics.py" """ pred_label = (preds_probs > threshold).astype(int)