提交 84e25e9f 编写于 作者: F FlyingQianMM

rewrite postprocess for rcnn predict

上级 5b145d45
......@@ -409,14 +409,7 @@ class FasterRCNN(BaseAPI):
return im, im_resize_info, im_shape
@staticmethod
def _postprocess(results, test_outputs_keys, batch_size, num_classes,
labels):
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(test_outputs_keys), results)
}
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [])
def _postprocess(res, batch_size, num_classes, labels):
clsid2catid = dict({i: i for i in range(num_classes)})
xywh_results = bbox2out([res], clsid2catid)
preds = [[] for i in range(batch_size)]
......@@ -463,8 +456,13 @@ class FasterRCNN(BaseAPI):
return_numpy=False,
use_program_cache=True)
preds = FasterRCNN._postprocess(result,
list(self.test_outputs.keys()),
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(test_outputs_keys), results)
}
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [])
preds = FasterRCNN._postprocess(res,
len(images), self.num_classes,
self.labels)
......@@ -507,8 +505,13 @@ class FasterRCNN(BaseAPI):
return_numpy=False,
use_program_cache=True)
preds = FasterRCNN._postprocess(result,
list(self.test_outputs.keys()),
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(test_outputs_keys), results)
}
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [])
preds = FasterRCNN._postprocess(res,
len(img_file_list), self.num_classes,
self.labels)
......
......@@ -338,15 +338,8 @@ class MaskRCNN(FasterRCNN):
return metrics
@staticmethod
def _postprocess(results, im_shape, test_outputs_keys, batch_size,
num_classes, mask_head_resolution, labels):
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(test_outputs_keys), results)
}
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [])
res['im_shape'] = (np.array(im_shape), [])
def _postprocess(res, batch_size, num_classes, mask_head_resolution,
labels):
clsid2catid = dict({i: i for i in range(num_classes)})
xywh_results = bbox2out([res], clsid2catid)
segm_results = mask2out([res], clsid2catid, mask_head_resolution)
......@@ -398,8 +391,14 @@ class MaskRCNN(FasterRCNN):
return_numpy=False,
use_program_cache=True)
preds = MaskRCNN._postprocess(result, im_shape,
list(self.test_outputs.keys()),
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(test_outputs_keys), results)
}
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [])
res['im_shape'] = (np.array(im_shape), [])
preds = MaskRCNN._postprocess(res,
len(images), self.num_classes,
self.mask_head_resolution, self.labels)
......@@ -442,9 +441,14 @@ class MaskRCNN(FasterRCNN):
return_numpy=False,
use_program_cache=True)
preds = MaskRCNN._postprocess(result, im_shape,
list(self.test_outputs.keys()),
len(img_file_list), self.num_classes,
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(test_outputs_keys), results)
}
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [])
res['im_shape'] = (np.array(im_shape), [])
preds = MaskRCNN._postprocess(res,
len(images), self.num_classes,
self.mask_head_resolution, self.labels)
return preds
......@@ -155,23 +155,42 @@ class Predictor:
res['im_info'] = im_info
return res
def postprocess(self, results, topk=1, batch_size=1, im_shape=None):
def postprocess(self,
results,
topk=1,
batch_size=1,
im_shape=None,
im_info=None):
def offset_to_lengths(lod):
offset = lod[0]
lengths = [
offset[i + 1] - offset[i] for i in range(len(offset) - 1)
]
return [lengths]
if self.model_type == "classifier":
true_topk = min(self.num_classes, topk)
preds = BaseClassifier._postprocess(results, true_topk,
preds = BaseClassifier._postprocess([results[0][0]], true_topk,
self.labels)
elif self.model_type == "detector":
res = {'bbox': (results[0][0], offset_to_lengths(results[0][1])), }
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [[]])
if self.model_name == "YOLOv3":
preds = YOLOv3._postprocess(results, ['bbox'], batch_size,
self.num_classes, self.labels)
preds = YOLOv3._postprocess(res, batch_size, self.num_classes,
self.labels)
elif self.model_name == "FasterRCNN":
preds = FasterRCNN._postprocess(results, ['bbox'], batch_size,
preds = FasterRCNN._postprocess(res, batch_size,
self.num_classes, self.labels)
elif self.model_name == "MaskRCNN":
res['mask'] = (results[1][0], offset_to_lengths(results[1][1]))
res['im_shape'] = (im_shape, [])
preds = MaskRCNN._postprocess(
results, ['bbox', 'mask'], batch_size, self.num_classes,
res, batch_size, self.num_classes,
self.mask_head_resolution, self.labels)
elif self.model_type == "segmenter":
res = [results[0][0], results[1][0]]
preds = DeepLabv3p._postprocess(res, im_info)
return preds
def raw_predict(self, inputs):
......@@ -191,7 +210,9 @@ class Predictor:
output_results = list()
for name in output_names:
output_tensor = self.predictor.get_output_tensor(name)
output_results.append(output_tensor.copy_to_cpu())
output_tensor_lod = output_tensor.lod()
output_results.append(
[output_tensor.copy_to_cpu(), output_tensor_lod])
return output_results
def predict(self, image, topk=1):
......@@ -207,8 +228,14 @@ class Predictor:
model_pred = self.raw_predict(preprocessed_input)
im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[
'im_shape']
im_info = None if 'im_info' not in preprocessed_input else preprocessed_input[
'im_info']
results = self.postprocess(
model_pred, topk=topk, batch_size=1, im_shape=im_shape)
model_pred,
topk=topk,
batch_size=1,
im_shape=im_shape,
im_info=im_info)
return results[0]
......@@ -223,9 +250,15 @@ class Predictor:
"""
preprocessed_input = self.preprocess(image_list)
model_pred = self.raw_predict(preprocessed_input)
im_shape = None if 'im_shape' in preprocessed_input else preprocessed_input[
im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[
'im_shape']
im_info = None if 'im_info' not in preprocessed_input else preprocessed_input[
'im_info']
results = self.postprocess(
model_pred, topk=topk, batch_size=1, im_shape=im_shape)
model_pred,
topk=topk,
batch_size=len(image_list),
im_shape=im_shape,
im_info=im_info)
return results
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册