提交 9cf1abae 编写于 作者: Z zhiboniu

support attr infer

上级 4091592c
Global:
infer_imgs: "./images/Pedestrain_Attr.jpg"
inference_model_dir: "../inference/"
batch_size: 1
use_gpu: True
enable_mkldnn: False
cpu_num_threads: 10
enable_benchmark: True
use_fp16: False
ir_optim: True
use_tensorrt: False
gpu_mem: 8000
enable_profile: False
PreProcess:
transform_ops:
- ResizeImage:
size: [192, 256]
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
channel_num: 3
- ToCHWImage:
PostProcess:
main_indicator: Attribute
Attribute:
threshold: 0.5
glasses_threshold: 0.3
hold_threshold: 0.6
\ No newline at end of file
......@@ -179,3 +179,90 @@ class Binarize(object):
byte[:, i:i + 1] = np.dot(x[:, i * 8:(i + 1) * 8], self.unit)
return byte
class Attribute(object):
def __init__(self,
threshold=0.5,
glasses_threshold=0.3,
hold_threshold=0.6):
self.threshold = threshold
self.glasses_threshold = glasses_threshold
self.hold_threshold = hold_threshold
def __call__(self, batch_preds, file_names=None):
# postprocess output of predictor
age_list = ['AgeLess18', 'Age18-60', 'AgeOver60']
direct_list = ['Front', 'Side', 'Back']
bag_list = ['HandBag', 'ShoulderBag', 'Backpack']
upper_list = ['UpperStride', 'UpperLogo', 'UpperPlaid', 'UpperSplice']
lower_list = [
'LowerStripe', 'LowerPattern', 'LongCoat', 'Trousers', 'Shorts',
'Skirt&Dress'
]
batch_res = []
for res in batch_preds:
res = res.tolist()
label_res = []
# gender
gender = 'Female' if res[22] > self.threshold else 'Male'
label_res.append(gender)
# age
age = age_list[np.argmax(res[19:22])]
label_res.append(age)
# direction
direction = direct_list[np.argmax(res[23:])]
label_res.append(direction)
# glasses
glasses = 'Glasses: '
if res[1] > self.glasses_threshold:
glasses += 'True'
else:
glasses += 'False'
label_res.append(glasses)
# hat
hat = 'Hat: '
if res[0] > self.threshold:
hat += 'True'
else:
hat += 'False'
label_res.append(hat)
# hold obj
hold_obj = 'HoldObjectsInFront: '
if res[18] > self.hold_threshold:
hold_obj += 'True'
else:
hold_obj += 'False'
label_res.append(hold_obj)
# bag
bag = bag_list[np.argmax(res[15:18])]
bag_score = res[15 + np.argmax(res[15:18])]
bag_label = bag if bag_score > self.threshold else 'No bag'
label_res.append(bag_label)
# upper
upper_res = res[4:8]
upper_label = 'Upper:'
sleeve = 'LongSleeve' if res[3] > res[2] else 'ShortSleeve'
upper_label += ' {}'.format(sleeve)
for i, r in enumerate(upper_res):
if r > self.threshold:
upper_label += ' {}'.format(upper_list[i])
label_res.append(upper_label)
# lower
lower_res = res[8:14]
lower_label = 'Lower: '
has_lower = False
for i, l in enumerate(lower_res):
if l > self.threshold:
lower_label += ' {}'.format(lower_list[i])
has_lower = True
if not has_lower:
lower_label += ' {}'.format(lower_list[np.argmax(lower_res)])
label_res.append(lower_label)
# shoe
shoe = 'Boots' if res[14] > self.threshold else 'No boots'
label_res.append(shoe)
batch_res.append(label_res)
return batch_res
......@@ -138,12 +138,19 @@ def main(config):
continue
batch_results = cls_predictor.predict(batch_imgs)
for number, result_dict in enumerate(batch_results):
if "Attribute" in config["PostProcess"]:
filename = batch_names[number]
attr_message = result_dict
print("{}:\tclass id(s): {}".format(filename,
attr_message))
else:
filename = batch_names[number]
clas_ids = result_dict["class_ids"]
scores_str = "[{}]".format(", ".join("{:.2f}".format(
r) for r in result_dict["scores"]))
label_names = result_dict["label_names"]
print("{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".
print(
"{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".
format(filename, clas_ids, scores_str, label_names))
batch_imgs = []
batch_names = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册