提交 f9534359 编写于 作者: C cuicheng01

fix format

上级 18acc25f
......@@ -322,16 +322,15 @@ class VehicleAttribute(object):
return batch_res
class TableAttribute(object):
def __init__(self,
source_threshold=0.5,
number_threshold=0.5,
color_threshold=0.5,
clarity_threshold=0.5,
obstruction_threshold=0.5,
angle_threshold=0.5,
):
def __init__(
self,
source_threshold=0.5,
number_threshold=0.5,
color_threshold=0.5,
clarity_threshold=0.5,
obstruction_threshold=0.5,
angle_threshold=0.5, ):
self.source_threshold = source_threshold
self.number_threshold = number_threshold
self.color_threshold = color_threshold
......@@ -342,19 +341,27 @@ class TableAttribute(object):
def __call__(self, batch_preds, file_names=None):
# postprocess output of predictor
batch_res = []
for res in batch_preds:
res = res.tolist()
label_res = []
source = 'Scanned' if res[0] > self.source_threshold else 'Photo'
number = 'Little' if res[1] > self.number_threshold else 'Numerous'
color = 'Black-and-White' if res[2] > self.color_threshold else 'Multicolor'
color = 'Black-and-White' if res[
2] > self.color_threshold else 'Multicolor'
clarity = 'Clear' if res[3] > self.clarity_threshold else 'Blurry'
obstruction = 'Without-Obstacles' if res[4] > self.number_threshold else 'With-Obstacles'
angle = 'Horizontal' if res[5] > self.number_threshold else 'Tilted'
obstruction = 'Without-Obstacles' if res[
4] > self.number_threshold else 'With-Obstacles'
angle = 'Horizontal' if res[
5] > self.number_threshold else 'Tilted'
label_res = [source, number, color, clarity, obstruction, angle]
threshold_list = [self.source_threshold, self.number_threshold, self.color_threshold, self.clarity_threshold, self.obstruction_threshold, self.angle_threshold]
threshold_list = [
self.source_threshold, self.number_threshold,
self.color_threshold, self.clarity_threshold,
self.obstruction_threshold, self.angle_threshold
]
pred_res = (np.array(res) > np.array(threshold_list)
).astype(np.int8).tolist()
batch_res.append({"attributes": label_res, "output": pred_res})
......
......@@ -71,7 +71,6 @@ class VehicleAttribute(object):
return batch_res
class PersonAttribute(object):
def __init__(self,
threshold=0.5,
......@@ -173,14 +172,14 @@ class PersonAttribute(object):
class TableAttribute(object):
def __init__(self,
source_threshold=0.5,
number_threshold=0.5,
color_threshold=0.5,
clarity_threshold=0.5,
obstruction_threshold=0.5,
angle_threshold=0.5,
):
def __init__(
self,
source_threshold=0.5,
number_threshold=0.5,
color_threshold=0.5,
clarity_threshold=0.5,
obstruction_threshold=0.5,
angle_threshold=0.5, ):
self.source_threshold = source_threshold
self.number_threshold = number_threshold
self.color_threshold = color_threshold
......@@ -195,6 +194,7 @@ class TableAttribute(object):
if file_names is not None:
assert x.shape[0] == len(file_names)
x = F.sigmoid(x).numpy()
# postprocess output of predictor
batch_res = []
for idx, res in enumerate(x):
......@@ -202,16 +202,26 @@ class TableAttribute(object):
label_res = []
source = 'Scanned' if res[0] > self.source_threshold else 'Photo'
number = 'Little' if res[1] > self.number_threshold else 'Numerous'
color = 'Black-and-White' if res[2] > self.color_threshold else 'Multicolor'
color = 'Black-and-White' if res[
2] > self.color_threshold else 'Multicolor'
clarity = 'Clear' if res[3] > self.clarity_threshold else 'Blurry'
obstruction = 'Without-Obstacles' if res[4] > self.number_threshold else 'With-Obstacles'
angle = 'Horizontal' if res[5] > self.number_threshold else 'Tilted'
obstruction = 'Without-Obstacles' if res[
4] > self.number_threshold else 'With-Obstacles'
angle = 'Horizontal' if res[
5] > self.number_threshold else 'Tilted'
label_res = [source, number, color, clarity, obstruction, angle]
threshold_list = [self.source_threshold, self.number_threshold, self.color_threshold, self.clarity_threshold, self.obstruction_threshold, self.angle_threshold]
threshold_list = [
self.source_threshold, self.number_threshold,
self.color_threshold, self.clarity_threshold,
self.obstruction_threshold, self.angle_threshold
]
pred_res = (np.array(res) > np.array(threshold_list)
).astype(np.int8).tolist()
batch_res.append({"attributes": label_res, "output": pred_res, "file_name": file_names[idx]})
batch_res.append({
"attributes": label_res,
"output": pred_res,
"file_name": file_names[idx]
})
return batch_res
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册