From 852324d6e0b65b4f3483058c7cfce6d17424599f Mon Sep 17 00:00:00 2001 From: zhiboniu <31800336+zhiboniu@users.noreply.github.com> Date: Wed, 3 Aug 2022 18:49:04 +0800 Subject: [PATCH] update attribute docs (#6574) --- deploy/pipeline/pphuman/attr_infer.py | 7 +- .../customization/attribute.md | 99 +++++++++++++++++++ 2 files changed, 102 insertions(+), 4 deletions(-) diff --git a/deploy/pipeline/pphuman/attr_infer.py b/deploy/pipeline/pphuman/attr_infer.py index f783d5e5c..dfbdd8f68 100644 --- a/deploy/pipeline/pphuman/attr_infer.py +++ b/deploy/pipeline/pphuman/attr_infer.py @@ -142,13 +142,12 @@ class AttrDetector(Detector): bag_label = bag if bag_score > self.threshold else 'No bag' label_res.append(bag_label) # upper - upper_res = res[4:8] upper_label = 'Upper:' sleeve = 'LongSleeve' if res[3] > res[2] else 'ShortSleeve' upper_label += ' {}'.format(sleeve) - for i, r in enumerate(upper_res): - if r > self.threshold: - upper_label += ' {}'.format(upper_list[i]) + upper_res = res[4:8] + if np.max(upper_res) > self.threshold: + upper_label += ' {}'.format(upper_list[np.argmax(upper_res)]) label_res.append(upper_label) # lower lower_res = res[8:14] diff --git a/docs/advanced_tutorials/customization/attribute.md b/docs/advanced_tutorials/customization/attribute.md index b7abe48b1..833dfc322 100644 --- a/docs/advanced_tutorials/customization/attribute.md +++ b/docs/advanced_tutorials/customization/attribute.md @@ -192,3 +192,102 @@ ATTR: 删减属性同理。 例如,如果不需要年龄属性,则位置[19, 20, 21]的数值可以去掉。只需在train.txt中标注的26个数字中全部删除第19-21位数值即可,同时标注数据时也不再需要标注这3位属性值。 + +## 修改后处理代码 + +修改了属性定义后,pipeline后处理部分也需要做相应修改,主要影响结果可视化时的显示结果。 + +相应代码在路径`deploy/pipeline/pphuman/attr_infer.py`文件中`postprocess`函数。 + +其函数实现说明如下: + +``` +# 函数入口 + def postprocess(self, inputs, result): + # postprocess output of predictor + im_results = result['output'] + +# 1) 定义各组属性实际意义,其数量及位置与输出结果中占用位数一一对应。 + labels = self.pred_config.labels + age_list = ['AgeLess18', 'Age18-60', 'AgeOver60'] + direct_list = ['Front', 'Side', 'Back'] + bag_list = ['HandBag', 'ShoulderBag', 'Backpack'] + upper_list = ['UpperStride', 'UpperLogo', 'UpperPlaid', 'UpperSplice'] + lower_list = [ + 'LowerStripe', 'LowerPattern', 'LongCoat', 'Trousers', 'Shorts', + 'Skirt&Dress' + ] +# 2) 部分属性所用阈值与通用值有明显区别,单独设置 + glasses_threshold = 0.3 + hold_threshold = 0.6 + + batch_res = [] + for res in im_results: + res = res.tolist() + label_res = [] + # gender +# 3) 单个位置属性类别,判断该位置是否大于阈值,来分配二分类结果 + gender = 'Female' if res[22] > self.threshold else 'Male' + label_res.append(gender) + # age +# 4)多个位置属性类别,N选一形式,选择得分最高的属性 + age = age_list[np.argmax(res[19:22])] + label_res.append(age) + # direction + direction = direct_list[np.argmax(res[23:])] + label_res.append(direction) + # glasses + glasses = 'Glasses: ' + if res[1] > glasses_threshold: + glasses += 'True' + else: + glasses += 'False' + label_res.append(glasses) + # hat + hat = 'Hat: ' + if res[0] > self.threshold: + hat += 'True' + else: + hat += 'False' + label_res.append(hat) + # hold obj + hold_obj = 'HoldObjectsInFront: ' + if res[18] > hold_threshold: + hold_obj += 'True' + else: + hold_obj += 'False' + label_res.append(hold_obj) + # bag + bag = bag_list[np.argmax(res[15:18])] + bag_score = res[15 + np.argmax(res[15:18])] + bag_label = bag if bag_score > self.threshold else 'No bag' + label_res.append(bag_label) + # upper +# 5)同一类属性,分为两组(这里是款式和花色),每小组内单独选择,相当于两组不同属性。 + upper_label = 'Upper:' + sleeve = 'LongSleeve' if res[3] > res[2] else 'ShortSleeve' + upper_label += ' {}'.format(sleeve) + upper_res = res[4:8] + if np.max(upper_res) > self.threshold: + upper_label += ' {}'.format(upper_list[np.argmax(upper_res)]) + label_res.append(upper_label) + # lower + lower_res = res[8:14] + lower_label = 'Lower: ' + has_lower = False + for i, l in enumerate(lower_res): + if l > self.threshold: + lower_label += ' {}'.format(lower_list[i]) + has_lower = True + if not has_lower: + lower_label += ' {}'.format(lower_list[np.argmax(lower_res)]) + + label_res.append(lower_label) + # shoe + shoe = 'Boots' if res[14] > self.threshold else 'No boots' + label_res.append(shoe) + + batch_res.append(label_res) + result = {'output': batch_res} + return result +``` -- GitLab