提交 939a35d6 编写于 作者: Z zhiboniu

add more details

上级 a4e1da66
...@@ -27,7 +27,7 @@ PreProcess: ...@@ -27,7 +27,7 @@ PreProcess:
PostProcess: PostProcess:
main_indicator: Attribute main_indicator: Attribute
Attribute: Attribute:
threshold: 0.5 threshold: 0.5 #default threshold
glasses_threshold: 0.3 glasses_threshold: 0.3 #threshold only for glasses
hold_threshold: 0.6 hold_threshold: 0.6 #threshold only for hold
\ No newline at end of file
...@@ -64,9 +64,17 @@ class ThreshOutput(object): ...@@ -64,9 +64,17 @@ class ThreshOutput(object):
for idx, probs in enumerate(x): for idx, probs in enumerate(x):
score = probs[1] score = probs[1]
if score < self.threshold: if score < self.threshold:
result = {"class_ids": [0], "scores": [1 - score], "label_names": [self.label_0]} result = {
"class_ids": [0],
"scores": [1 - score],
"label_names": [self.label_0]
}
else: else:
result = {"class_ids": [1], "scores": [score], "label_names": [self.label_1]} result = {
"class_ids": [1],
"scores": [score],
"label_names": [self.label_1]
}
if file_names is not None: if file_names is not None:
result["file_name"] = file_names[idx] result["file_name"] = file_names[idx]
y.append(result) y.append(result)
...@@ -264,5 +272,11 @@ class Attribute(object): ...@@ -264,5 +272,11 @@ class Attribute(object):
shoe = 'Boots' if res[14] > self.threshold else 'No boots' shoe = 'Boots' if res[14] > self.threshold else 'No boots'
label_res.append(shoe) label_res.append(shoe)
batch_res.append(label_res) threshold_list = [0.5] * len(res)
threshold_list[1] = self.glasses_threshold
threshold_list[18] = self.hold_threshold
pred_res = (np.array(res) > np.array(threshold_list)
).astype(np.int8).tolist()
batch_res.append([label_res, pred_res])
return batch_res return batch_res
...@@ -140,9 +140,10 @@ def main(config): ...@@ -140,9 +140,10 @@ def main(config):
for number, result_dict in enumerate(batch_results): for number, result_dict in enumerate(batch_results):
if "Attribute" in config["PostProcess"]: if "Attribute" in config["PostProcess"]:
filename = batch_names[number] filename = batch_names[number]
attr_message = result_dict attr_message = result_dict[0]
print("{}:\tclass id(s): {}".format(filename, pred_res = result_dict[1]
attr_message)) print("{}:\t attributes: {}, \npredict output: {}".format(
filename, attr_message, pred_res))
else: else:
filename = batch_names[number] filename = batch_names[number]
clas_ids = result_dict["class_ids"] clas_ids = result_dict["class_ids"]
......
...@@ -391,6 +391,7 @@ class AccuracyScore(MultiLabelMetric): ...@@ -391,6 +391,7 @@ class AccuracyScore(MultiLabelMetric):
def get_attr_metrics(gt_label, preds_probs, threshold): def get_attr_metrics(gt_label, preds_probs, threshold):
""" """
index: evaluated label index index: evaluated label index
adapted from "https://github.com/valencebond/Rethinking_of_PAR/blob/master/metrics/pedestrian_metrics.py"
""" """
pred_label = (preds_probs > threshold).astype(int) pred_label = (preds_probs > threshold).astype(int)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册