提交 84e25e9f 编写于 作者: F FlyingQianMM

rewrite postprocess for rcnn predict

上级 5b145d45
...@@ -409,14 +409,7 @@ class FasterRCNN(BaseAPI): ...@@ -409,14 +409,7 @@ class FasterRCNN(BaseAPI):
return im, im_resize_info, im_shape return im, im_resize_info, im_shape
@staticmethod @staticmethod
def _postprocess(results, test_outputs_keys, batch_size, num_classes, def _postprocess(res, batch_size, num_classes, labels):
labels):
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(test_outputs_keys), results)
}
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [])
clsid2catid = dict({i: i for i in range(num_classes)}) clsid2catid = dict({i: i for i in range(num_classes)})
xywh_results = bbox2out([res], clsid2catid) xywh_results = bbox2out([res], clsid2catid)
preds = [[] for i in range(batch_size)] preds = [[] for i in range(batch_size)]
...@@ -463,8 +456,13 @@ class FasterRCNN(BaseAPI): ...@@ -463,8 +456,13 @@ class FasterRCNN(BaseAPI):
return_numpy=False, return_numpy=False,
use_program_cache=True) use_program_cache=True)
preds = FasterRCNN._postprocess(result, res = {
list(self.test_outputs.keys()), k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(test_outputs_keys), results)
}
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [])
preds = FasterRCNN._postprocess(res,
len(images), self.num_classes, len(images), self.num_classes,
self.labels) self.labels)
...@@ -507,8 +505,13 @@ class FasterRCNN(BaseAPI): ...@@ -507,8 +505,13 @@ class FasterRCNN(BaseAPI):
return_numpy=False, return_numpy=False,
use_program_cache=True) use_program_cache=True)
preds = FasterRCNN._postprocess(result, res = {
list(self.test_outputs.keys()), k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(test_outputs_keys), results)
}
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [])
preds = FasterRCNN._postprocess(res,
len(img_file_list), self.num_classes, len(img_file_list), self.num_classes,
self.labels) self.labels)
......
...@@ -338,15 +338,8 @@ class MaskRCNN(FasterRCNN): ...@@ -338,15 +338,8 @@ class MaskRCNN(FasterRCNN):
return metrics return metrics
@staticmethod @staticmethod
def _postprocess(results, im_shape, test_outputs_keys, batch_size, def _postprocess(res, batch_size, num_classes, mask_head_resolution,
num_classes, mask_head_resolution, labels): labels):
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(test_outputs_keys), results)
}
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [])
res['im_shape'] = (np.array(im_shape), [])
clsid2catid = dict({i: i for i in range(num_classes)}) clsid2catid = dict({i: i for i in range(num_classes)})
xywh_results = bbox2out([res], clsid2catid) xywh_results = bbox2out([res], clsid2catid)
segm_results = mask2out([res], clsid2catid, mask_head_resolution) segm_results = mask2out([res], clsid2catid, mask_head_resolution)
...@@ -398,8 +391,14 @@ class MaskRCNN(FasterRCNN): ...@@ -398,8 +391,14 @@ class MaskRCNN(FasterRCNN):
return_numpy=False, return_numpy=False,
use_program_cache=True) use_program_cache=True)
preds = MaskRCNN._postprocess(result, im_shape, res = {
list(self.test_outputs.keys()), k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(list(test_outputs_keys), results)
}
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [])
res['im_shape'] = (np.array(im_shape), [])
preds = MaskRCNN._postprocess(res,
len(images), self.num_classes, len(images), self.num_classes,
self.mask_head_resolution, self.labels) self.mask_head_resolution, self.labels)
...@@ -442,9 +441,14 @@ class MaskRCNN(FasterRCNN): ...@@ -442,9 +441,14 @@ class MaskRCNN(FasterRCNN):
return_numpy=False, return_numpy=False,
use_program_cache=True) use_program_cache=True)
preds = MaskRCNN._postprocess(result, im_shape, res = {
list(self.test_outputs.keys()), k: (np.array(v), v.recursive_sequence_lengths())
len(img_file_list), self.num_classes, for k, v in zip(list(test_outputs_keys), results)
}
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [])
res['im_shape'] = (np.array(im_shape), [])
preds = MaskRCNN._postprocess(res,
len(images), self.num_classes,
self.mask_head_resolution, self.labels) self.mask_head_resolution, self.labels)
return preds return preds
...@@ -155,23 +155,42 @@ class Predictor: ...@@ -155,23 +155,42 @@ class Predictor:
res['im_info'] = im_info res['im_info'] = im_info
return res return res
def postprocess(self, results, topk=1, batch_size=1, im_shape=None): def postprocess(self,
results,
topk=1,
batch_size=1,
im_shape=None,
im_info=None):
def offset_to_lengths(lod):
offset = lod[0]
lengths = [
offset[i + 1] - offset[i] for i in range(len(offset) - 1)
]
return [lengths]
if self.model_type == "classifier": if self.model_type == "classifier":
true_topk = min(self.num_classes, topk) true_topk = min(self.num_classes, topk)
preds = BaseClassifier._postprocess(results, true_topk, preds = BaseClassifier._postprocess([results[0][0]], true_topk,
self.labels) self.labels)
elif self.model_type == "detector": elif self.model_type == "detector":
res = {'bbox': (results[0][0], offset_to_lengths(results[0][1])), }
res['im_id'] = (np.array(
[[i] for i in range(batch_size)]).astype('int32'), [[]])
if self.model_name == "YOLOv3": if self.model_name == "YOLOv3":
preds = YOLOv3._postprocess(results, ['bbox'], batch_size, preds = YOLOv3._postprocess(res, batch_size, self.num_classes,
self.num_classes, self.labels) self.labels)
elif self.model_name == "FasterRCNN": elif self.model_name == "FasterRCNN":
preds = FasterRCNN._postprocess(results, ['bbox'], batch_size, preds = FasterRCNN._postprocess(res, batch_size,
self.num_classes, self.labels) self.num_classes, self.labels)
elif self.model_name == "MaskRCNN": elif self.model_name == "MaskRCNN":
res['mask'] = (results[1][0], offset_to_lengths(results[1][1]))
res['im_shape'] = (im_shape, [])
preds = MaskRCNN._postprocess( preds = MaskRCNN._postprocess(
results, ['bbox', 'mask'], batch_size, self.num_classes, res, batch_size, self.num_classes,
self.mask_head_resolution, self.labels) self.mask_head_resolution, self.labels)
elif self.model_type == "segmenter":
res = [results[0][0], results[1][0]]
preds = DeepLabv3p._postprocess(res, im_info)
return preds return preds
def raw_predict(self, inputs): def raw_predict(self, inputs):
...@@ -191,7 +210,9 @@ class Predictor: ...@@ -191,7 +210,9 @@ class Predictor:
output_results = list() output_results = list()
for name in output_names: for name in output_names:
output_tensor = self.predictor.get_output_tensor(name) output_tensor = self.predictor.get_output_tensor(name)
output_results.append(output_tensor.copy_to_cpu()) output_tensor_lod = output_tensor.lod()
output_results.append(
[output_tensor.copy_to_cpu(), output_tensor_lod])
return output_results return output_results
def predict(self, image, topk=1): def predict(self, image, topk=1):
...@@ -207,8 +228,14 @@ class Predictor: ...@@ -207,8 +228,14 @@ class Predictor:
model_pred = self.raw_predict(preprocessed_input) model_pred = self.raw_predict(preprocessed_input)
im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[ im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[
'im_shape'] 'im_shape']
im_info = None if 'im_info' not in preprocessed_input else preprocessed_input[
'im_info']
results = self.postprocess( results = self.postprocess(
model_pred, topk=topk, batch_size=1, im_shape=im_shape) model_pred,
topk=topk,
batch_size=1,
im_shape=im_shape,
im_info=im_info)
return results[0] return results[0]
...@@ -223,9 +250,15 @@ class Predictor: ...@@ -223,9 +250,15 @@ class Predictor:
""" """
preprocessed_input = self.preprocess(image_list) preprocessed_input = self.preprocess(image_list)
model_pred = self.raw_predict(preprocessed_input) model_pred = self.raw_predict(preprocessed_input)
im_shape = None if 'im_shape' in preprocessed_input else preprocessed_input[ im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[
'im_shape'] 'im_shape']
im_info = None if 'im_info' not in preprocessed_input else preprocessed_input[
'im_info']
results = self.postprocess( results = self.postprocess(
model_pred, topk=topk, batch_size=1, im_shape=im_shape) model_pred,
topk=topk,
batch_size=len(image_list),
im_shape=im_shape,
im_info=im_info)
return results return results
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册