提交 f9534359 编写于 作者: C cuicheng01

fix format

上级 18acc25f
...@@ -322,16 +322,15 @@ class VehicleAttribute(object): ...@@ -322,16 +322,15 @@ class VehicleAttribute(object):
return batch_res return batch_res
class TableAttribute(object): class TableAttribute(object):
def __init__(self, def __init__(
source_threshold=0.5, self,
number_threshold=0.5, source_threshold=0.5,
color_threshold=0.5, number_threshold=0.5,
clarity_threshold=0.5, color_threshold=0.5,
obstruction_threshold=0.5, clarity_threshold=0.5,
angle_threshold=0.5, obstruction_threshold=0.5,
): angle_threshold=0.5, ):
self.source_threshold = source_threshold self.source_threshold = source_threshold
self.number_threshold = number_threshold self.number_threshold = number_threshold
self.color_threshold = color_threshold self.color_threshold = color_threshold
...@@ -342,19 +341,27 @@ class TableAttribute(object): ...@@ -342,19 +341,27 @@ class TableAttribute(object):
def __call__(self, batch_preds, file_names=None): def __call__(self, batch_preds, file_names=None):
# postprocess output of predictor # postprocess output of predictor
batch_res = [] batch_res = []
for res in batch_preds: for res in batch_preds:
res = res.tolist() res = res.tolist()
label_res = [] label_res = []
source = 'Scanned' if res[0] > self.source_threshold else 'Photo' source = 'Scanned' if res[0] > self.source_threshold else 'Photo'
number = 'Little' if res[1] > self.number_threshold else 'Numerous' number = 'Little' if res[1] > self.number_threshold else 'Numerous'
color = 'Black-and-White' if res[2] > self.color_threshold else 'Multicolor' color = 'Black-and-White' if res[
2] > self.color_threshold else 'Multicolor'
clarity = 'Clear' if res[3] > self.clarity_threshold else 'Blurry' clarity = 'Clear' if res[3] > self.clarity_threshold else 'Blurry'
obstruction = 'Without-Obstacles' if res[4] > self.number_threshold else 'With-Obstacles' obstruction = 'Without-Obstacles' if res[
angle = 'Horizontal' if res[5] > self.number_threshold else 'Tilted' 4] > self.number_threshold else 'With-Obstacles'
angle = 'Horizontal' if res[
5] > self.number_threshold else 'Tilted'
label_res = [source, number, color, clarity, obstruction, angle] label_res = [source, number, color, clarity, obstruction, angle]
threshold_list = [self.source_threshold, self.number_threshold, self.color_threshold, self.clarity_threshold, self.obstruction_threshold, self.angle_threshold] threshold_list = [
self.source_threshold, self.number_threshold,
self.color_threshold, self.clarity_threshold,
self.obstruction_threshold, self.angle_threshold
]
pred_res = (np.array(res) > np.array(threshold_list) pred_res = (np.array(res) > np.array(threshold_list)
).astype(np.int8).tolist() ).astype(np.int8).tolist()
batch_res.append({"attributes": label_res, "output": pred_res}) batch_res.append({"attributes": label_res, "output": pred_res})
......
...@@ -71,7 +71,6 @@ class VehicleAttribute(object): ...@@ -71,7 +71,6 @@ class VehicleAttribute(object):
return batch_res return batch_res
class PersonAttribute(object): class PersonAttribute(object):
def __init__(self, def __init__(self,
threshold=0.5, threshold=0.5,
...@@ -173,14 +172,14 @@ class PersonAttribute(object): ...@@ -173,14 +172,14 @@ class PersonAttribute(object):
class TableAttribute(object): class TableAttribute(object):
def __init__(self, def __init__(
source_threshold=0.5, self,
number_threshold=0.5, source_threshold=0.5,
color_threshold=0.5, number_threshold=0.5,
clarity_threshold=0.5, color_threshold=0.5,
obstruction_threshold=0.5, clarity_threshold=0.5,
angle_threshold=0.5, obstruction_threshold=0.5,
): angle_threshold=0.5, ):
self.source_threshold = source_threshold self.source_threshold = source_threshold
self.number_threshold = number_threshold self.number_threshold = number_threshold
self.color_threshold = color_threshold self.color_threshold = color_threshold
...@@ -195,6 +194,7 @@ class TableAttribute(object): ...@@ -195,6 +194,7 @@ class TableAttribute(object):
if file_names is not None: if file_names is not None:
assert x.shape[0] == len(file_names) assert x.shape[0] == len(file_names)
x = F.sigmoid(x).numpy() x = F.sigmoid(x).numpy()
# postprocess output of predictor # postprocess output of predictor
batch_res = [] batch_res = []
for idx, res in enumerate(x): for idx, res in enumerate(x):
...@@ -202,16 +202,26 @@ class TableAttribute(object): ...@@ -202,16 +202,26 @@ class TableAttribute(object):
label_res = [] label_res = []
source = 'Scanned' if res[0] > self.source_threshold else 'Photo' source = 'Scanned' if res[0] > self.source_threshold else 'Photo'
number = 'Little' if res[1] > self.number_threshold else 'Numerous' number = 'Little' if res[1] > self.number_threshold else 'Numerous'
color = 'Black-and-White' if res[2] > self.color_threshold else 'Multicolor' color = 'Black-and-White' if res[
2] > self.color_threshold else 'Multicolor'
clarity = 'Clear' if res[3] > self.clarity_threshold else 'Blurry' clarity = 'Clear' if res[3] > self.clarity_threshold else 'Blurry'
obstruction = 'Without-Obstacles' if res[4] > self.number_threshold else 'With-Obstacles' obstruction = 'Without-Obstacles' if res[
angle = 'Horizontal' if res[5] > self.number_threshold else 'Tilted' 4] > self.number_threshold else 'With-Obstacles'
angle = 'Horizontal' if res[
5] > self.number_threshold else 'Tilted'
label_res = [source, number, color, clarity, obstruction, angle] label_res = [source, number, color, clarity, obstruction, angle]
threshold_list = [self.source_threshold, self.number_threshold, self.color_threshold, self.clarity_threshold, self.obstruction_threshold, self.angle_threshold] threshold_list = [
self.source_threshold, self.number_threshold,
self.color_threshold, self.clarity_threshold,
self.obstruction_threshold, self.angle_threshold
]
pred_res = (np.array(res) > np.array(threshold_list) pred_res = (np.array(res) > np.array(threshold_list)
).astype(np.int8).tolist() ).astype(np.int8).tolist()
batch_res.append({"attributes": label_res, "output": pred_res, "file_name": file_names[idx]}) batch_res.append({
"attributes": label_res,
"output": pred_res,
"file_name": file_names[idx]
})
return batch_res return batch_res
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册