未验证 提交 ad71254e 编写于 作者: C cuicheng01 提交者: GitHub

Merge pull request #1960 from zhiboniu/dev_attr

support attribute infer
Global:
infer_imgs: "./images/Pedestrain_Attr.jpg"
inference_model_dir: "../inference/"
batch_size: 1
use_gpu: True
enable_mkldnn: False
cpu_num_threads: 10
enable_benchmark: True
use_fp16: False
ir_optim: True
use_tensorrt: False
gpu_mem: 8000
enable_profile: False
PreProcess:
transform_ops:
- ResizeImage:
size: [192, 256]
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
channel_num: 3
- ToCHWImage:
PostProcess:
main_indicator: Attribute
Attribute:
threshold: 0.5 #default threshold
glasses_threshold: 0.3 #threshold only for glasses
hold_threshold: 0.6 #threshold only for hold
\ No newline at end of file
...@@ -64,9 +64,17 @@ class ThreshOutput(object): ...@@ -64,9 +64,17 @@ class ThreshOutput(object):
for idx, probs in enumerate(x): for idx, probs in enumerate(x):
score = probs[1] score = probs[1]
if score < self.threshold: if score < self.threshold:
result = {"class_ids": [0], "scores": [1 - score], "label_names": [self.label_0]} result = {
"class_ids": [0],
"scores": [1 - score],
"label_names": [self.label_0]
}
else: else:
result = {"class_ids": [1], "scores": [score], "label_names": [self.label_1]} result = {
"class_ids": [1],
"scores": [score],
"label_names": [self.label_1]
}
if file_names is not None: if file_names is not None:
result["file_name"] = file_names[idx] result["file_name"] = file_names[idx]
y.append(result) y.append(result)
...@@ -179,3 +187,96 @@ class Binarize(object): ...@@ -179,3 +187,96 @@ class Binarize(object):
byte[:, i:i + 1] = np.dot(x[:, i * 8:(i + 1) * 8], self.unit) byte[:, i:i + 1] = np.dot(x[:, i * 8:(i + 1) * 8], self.unit)
return byte return byte
class Attribute(object):
def __init__(self,
threshold=0.5,
glasses_threshold=0.3,
hold_threshold=0.6):
self.threshold = threshold
self.glasses_threshold = glasses_threshold
self.hold_threshold = hold_threshold
def __call__(self, batch_preds, file_names=None):
# postprocess output of predictor
age_list = ['AgeLess18', 'Age18-60', 'AgeOver60']
direct_list = ['Front', 'Side', 'Back']
bag_list = ['HandBag', 'ShoulderBag', 'Backpack']
upper_list = ['UpperStride', 'UpperLogo', 'UpperPlaid', 'UpperSplice']
lower_list = [
'LowerStripe', 'LowerPattern', 'LongCoat', 'Trousers', 'Shorts',
'Skirt&Dress'
]
batch_res = []
for res in batch_preds:
res = res.tolist()
label_res = []
# gender
gender = 'Female' if res[22] > self.threshold else 'Male'
label_res.append(gender)
# age
age = age_list[np.argmax(res[19:22])]
label_res.append(age)
# direction
direction = direct_list[np.argmax(res[23:])]
label_res.append(direction)
# glasses
glasses = 'Glasses: '
if res[1] > self.glasses_threshold:
glasses += 'True'
else:
glasses += 'False'
label_res.append(glasses)
# hat
hat = 'Hat: '
if res[0] > self.threshold:
hat += 'True'
else:
hat += 'False'
label_res.append(hat)
# hold obj
hold_obj = 'HoldObjectsInFront: '
if res[18] > self.hold_threshold:
hold_obj += 'True'
else:
hold_obj += 'False'
label_res.append(hold_obj)
# bag
bag = bag_list[np.argmax(res[15:18])]
bag_score = res[15 + np.argmax(res[15:18])]
bag_label = bag if bag_score > self.threshold else 'No bag'
label_res.append(bag_label)
# upper
upper_res = res[4:8]
upper_label = 'Upper:'
sleeve = 'LongSleeve' if res[3] > res[2] else 'ShortSleeve'
upper_label += ' {}'.format(sleeve)
for i, r in enumerate(upper_res):
if r > self.threshold:
upper_label += ' {}'.format(upper_list[i])
label_res.append(upper_label)
# lower
lower_res = res[8:14]
lower_label = 'Lower: '
has_lower = False
for i, l in enumerate(lower_res):
if l > self.threshold:
lower_label += ' {}'.format(lower_list[i])
has_lower = True
if not has_lower:
lower_label += ' {}'.format(lower_list[np.argmax(lower_res)])
label_res.append(lower_label)
# shoe
shoe = 'Boots' if res[14] > self.threshold else 'No boots'
label_res.append(shoe)
threshold_list = [0.5] * len(res)
threshold_list[1] = self.glasses_threshold
threshold_list[18] = self.hold_threshold
pred_res = (np.array(res) > np.array(threshold_list)
).astype(np.int8).tolist()
batch_res.append([label_res, pred_res])
return batch_res
...@@ -138,13 +138,21 @@ def main(config): ...@@ -138,13 +138,21 @@ def main(config):
continue continue
batch_results = cls_predictor.predict(batch_imgs) batch_results = cls_predictor.predict(batch_imgs)
for number, result_dict in enumerate(batch_results): for number, result_dict in enumerate(batch_results):
filename = batch_names[number] if "Attribute" in config["PostProcess"]:
clas_ids = result_dict["class_ids"] filename = batch_names[number]
scores_str = "[{}]".format(", ".join("{:.2f}".format( attr_message = result_dict[0]
r) for r in result_dict["scores"])) pred_res = result_dict[1]
label_names = result_dict["label_names"] print("{}:\t attributes: {}, \npredict output: {}".format(
print("{}:\tclass id(s): {}, score(s): {}, label_name(s): {}". filename, attr_message, pred_res))
format(filename, clas_ids, scores_str, label_names)) else:
filename = batch_names[number]
clas_ids = result_dict["class_ids"]
scores_str = "[{}]".format(", ".join("{:.2f}".format(
r) for r in result_dict["scores"]))
label_names = result_dict["label_names"]
print(
"{}:\tclass id(s): {}, score(s): {}, label_name(s): {}".
format(filename, clas_ids, scores_str, label_names))
batch_imgs = [] batch_imgs = []
batch_names = [] batch_names = []
if cls_predictor.benchmark: if cls_predictor.benchmark:
......
...@@ -137,8 +137,11 @@ class ConvBNLayer(TheseusLayer): ...@@ -137,8 +137,11 @@ class ConvBNLayer(TheseusLayer):
weight_attr = ParamAttr(learning_rate=lr_mult, trainable=True) weight_attr = ParamAttr(learning_rate=lr_mult, trainable=True)
bias_attr = ParamAttr(learning_rate=lr_mult, trainable=True) bias_attr = ParamAttr(learning_rate=lr_mult, trainable=True)
self.bn = BatchNorm2D( self.bn = BatchNorm(
num_filters, weight_attr=weight_attr, bias_attr=bias_attr) num_filters,
param_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=ParamAttr(learning_rate=lr_mult),
data_layout=data_format)
self.relu = nn.ReLU() self.relu = nn.ReLU()
def forward(self, x): def forward(self, x):
...@@ -287,7 +290,8 @@ class ResNet(TheseusLayer): ...@@ -287,7 +290,8 @@ class ResNet(TheseusLayer):
data_format="NCHW", data_format="NCHW",
input_image_channel=3, input_image_channel=3,
return_patterns=None, return_patterns=None,
return_stages=None): return_stages=None,
**kargs):
super().__init__() super().__init__()
self.cfg = config self.cfg = config
......
...@@ -20,6 +20,7 @@ Arch: ...@@ -20,6 +20,7 @@ Arch:
name: "ResNet50" name: "ResNet50"
pretrained: True pretrained: True
class_num: 26 class_num: 26
infer_add_softmax: False
# loss function config for traing/eval process # loss function config for traing/eval process
Loss: Loss:
...@@ -110,5 +111,3 @@ DataLoader: ...@@ -110,5 +111,3 @@ DataLoader:
Metric: Metric:
Eval: Eval:
- ATTRMetric: - ATTRMetric:
...@@ -458,7 +458,9 @@ class Engine(object): ...@@ -458,7 +458,9 @@ class Engine(object):
def export(self): def export(self):
assert self.mode == "export" assert self.mode == "export"
use_multilabel = self.config["Global"].get("use_multilabel", False) use_multilabel = self.config["Global"].get(
"use_multilabel",
False) and not "ATTRMetric" in self.config["Metric"]["Eval"][0]
model = ExportModel(self.config["Arch"], self.model, use_multilabel) model = ExportModel(self.config["Arch"], self.model, use_multilabel)
if self.config["Global"]["pretrained_model"] is not None: if self.config["Global"]["pretrained_model"] is not None:
load_dygraph_pretrain(model.base_model, load_dygraph_pretrain(model.base_model,
......
...@@ -390,6 +390,7 @@ class AccuracyScore(MultiLabelMetric): ...@@ -390,6 +390,7 @@ class AccuracyScore(MultiLabelMetric):
def get_attr_metrics(gt_label, preds_probs, threshold): def get_attr_metrics(gt_label, preds_probs, threshold):
""" """
index: evaluated label index index: evaluated label index
adapted from "https://github.com/valencebond/Rethinking_of_PAR/blob/master/metrics/pedestrian_metrics.py"
""" """
pred_label = (preds_probs > threshold).astype(int) pred_label = (preds_probs > threshold).astype(int)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册