提交 4af0d9bb 编写于 作者: G gaotingquan 提交者: cuicheng01

support ThreshOutput for binary classification

上级 6e22efcd
......@@ -80,46 +80,71 @@ class ThreshOutput(object):
def __init__(self,
threshold=0,
default_label_index=0,
label_0="0",
label_1="1",
class_id_map_file=None,
delimiter=None,
label_0=None,
label_1=None):
delimiter=None):
self.threshold = threshold
self.default_label_index = default_label_index
self.label_0 = label_0
self.label_1 = label_1
delimiter = delimiter if delimiter is not None else " "
self.class_id_map = parse_class_id_map(class_id_map_file, delimiter)
if label_0 is not None or label_1 is not None:
print(
"[WARNING] The arguments \"label_0\" and \"label_1\" have been deprecated. Please use \"default_label_index\" instead."
)
def __call__(self, x, file_names=None):
def binary_classification(x):
y = []
for idx, probs in enumerate(x):
score = probs[1]
if score < self.threshold:
result = {
"class_ids": [0],
"scores": [1 - score],
"label_names": [self.label_0]
}
else:
result = {
"class_ids": [1],
"scores": [score],
"label_names": [self.label_1]
}
if file_names is not None:
result["file_name"] = file_names[idx]
y.append(result)
return y
def multi_classification(x):
y = []
for idx, probs in enumerate(x):
index = probs.argsort(axis=0)[::-1].astype("int32")
top1_id = index[0]
top1_score = probs[top1_id]
if top1_score > self.threshold:
rtn_id = top1_id
else:
rtn_id = self.default_label_index
label_name = self.class_id_map[
rtn_id] if self.class_id_map is not None else ""
result = {
"class_ids": [rtn_id],
"scores": [probs[rtn_id]],
"label_names": [label_name]
}
if file_names is not None:
result["file_name"] = file_names[idx]
y.append(result)
return y
if file_names is not None:
assert x.shape[0] == len(file_names)
y = []
for idx, probs in enumerate(x):
index = probs.argsort(axis=0)[::-1].astype("int32")
top1_id = index[0]
top1_score = probs[top1_id]
if top1_score > self.threshold:
rtn_id = top1_id
else:
rtn_id = self.default_label_index
label_name = self.class_id_map[
rtn_id] if self.class_id_map is not None else ""
result = {
"class_ids": [rtn_id],
"scores": [probs[rtn_id]],
"label_names": [label_name]
}
if file_names is not None:
result["file_name"] = file_names[idx]
y.append(result)
return y
if x.shape[1] == 2:
return binary_classification(x)
else:
return multi_classification(x)
class Topk(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册