提交 939a35d6 编写于 作者: Z zhiboniu

add more details

上级 a4e1da66
......@@ -27,7 +27,7 @@ PreProcess:
PostProcess:
main_indicator: Attribute
Attribute:
threshold: 0.5
glasses_threshold: 0.3
hold_threshold: 0.6
threshold: 0.5 #default threshold
glasses_threshold: 0.3 #threshold only for glasses
hold_threshold: 0.6 #threshold only for hold
\ No newline at end of file
......@@ -64,9 +64,17 @@ class ThreshOutput(object):
for idx, probs in enumerate(x):
score = probs[1]
if score < self.threshold:
result = {"class_ids": [0], "scores": [1 - score], "label_names": [self.label_0]}
result = {
"class_ids": [0],
"scores": [1 - score],
"label_names": [self.label_0]
}
else:
result = {"class_ids": [1], "scores": [score], "label_names": [self.label_1]}
result = {
"class_ids": [1],
"scores": [score],
"label_names": [self.label_1]
}
if file_names is not None:
result["file_name"] = file_names[idx]
y.append(result)
......@@ -264,5 +272,11 @@ class Attribute(object):
shoe = 'Boots' if res[14] > self.threshold else 'No boots'
label_res.append(shoe)
batch_res.append(label_res)
threshold_list = [0.5] * len(res)
threshold_list[1] = self.glasses_threshold
threshold_list[18] = self.hold_threshold
pred_res = (np.array(res) > np.array(threshold_list)
).astype(np.int8).tolist()
batch_res.append([label_res, pred_res])
return batch_res
......@@ -140,9 +140,10 @@ def main(config):
for number, result_dict in enumerate(batch_results):
if "Attribute" in config["PostProcess"]:
filename = batch_names[number]
attr_message = result_dict
print("{}:\tclass id(s): {}".format(filename,
attr_message))
attr_message = result_dict[0]
pred_res = result_dict[1]
print("{}:\t attributes: {}, \npredict output: {}".format(
filename, attr_message, pred_res))
else:
filename = batch_names[number]
clas_ids = result_dict["class_ids"]
......
......@@ -391,6 +391,7 @@ class AccuracyScore(MultiLabelMetric):
def get_attr_metrics(gt_label, preds_probs, threshold):
"""
index: evaluated label index
adapted from "https://github.com/valencebond/Rethinking_of_PAR/blob/master/metrics/pedestrian_metrics.py"
"""
pred_label = (preds_probs > threshold).astype(int)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册