未验证 提交 13759f95 编写于 作者: W wangguanzhong 提交者: GitHub

remove im_shape for evaluation (#1089)

上级 897d86ac
...@@ -122,6 +122,5 @@ class CascadeRCNN(BaseArch): ...@@ -122,6 +122,5 @@ class CascadeRCNN(BaseArch):
'bbox_nums': self.gbd['predicted_bbox_nums'].numpy(), 'bbox_nums': self.gbd['predicted_bbox_nums'].numpy(),
'mask': self.gbd['predicted_mask'].numpy(), 'mask': self.gbd['predicted_mask'].numpy(),
'im_id': self.gbd['im_id'].numpy(), 'im_id': self.gbd['im_id'].numpy(),
'im_shape': self.gbd['im_shape'].numpy()
} }
return inputs return inputs
...@@ -73,7 +73,6 @@ class FasterRCNN(BaseArch): ...@@ -73,7 +73,6 @@ class FasterRCNN(BaseArch):
outs = { outs = {
"bbox": self.gbd['predicted_bbox'].numpy(), "bbox": self.gbd['predicted_bbox'].numpy(),
"bbox_nums": self.gbd['predicted_bbox_nums'].numpy(), "bbox_nums": self.gbd['predicted_bbox_nums'].numpy(),
'im_id': self.gbd['im_id'].numpy(), 'im_id': self.gbd['im_id'].numpy()
'im_shape': self.gbd['im_shape'].numpy()
} }
return outs return outs
...@@ -96,7 +96,6 @@ class MaskRCNN(BaseArch): ...@@ -96,7 +96,6 @@ class MaskRCNN(BaseArch):
'bbox': self.gbd['predicted_bbox'].numpy(), 'bbox': self.gbd['predicted_bbox'].numpy(),
'bbox_nums': self.gbd['predicted_bbox_nums'].numpy(), 'bbox_nums': self.gbd['predicted_bbox_nums'].numpy(),
'mask': self.gbd['predicted_mask'].numpy(), 'mask': self.gbd['predicted_mask'].numpy(),
'im_id': self.gbd['im_id'].numpy(), 'im_id': self.gbd['im_id'].numpy()
'im_shape': self.gbd['im_shape'].numpy()
} }
return inputs return inputs
...@@ -48,6 +48,7 @@ class YOLOv3(BaseArch): ...@@ -48,6 +48,7 @@ class YOLOv3(BaseArch):
def infer(self, ): def infer(self, ):
outs = { outs = {
"bbox": self.gbd['predicted_bbox'].numpy(), "bbox": self.gbd['predicted_bbox'].numpy(),
"bbox_nums": self.gbd['predicted_bbox_nums'] "bbox_nums": self.gbd['predicted_bbox_nums'],
'im_id': self.gbd['im_id'].numpy()
} }
return outs return outs
...@@ -120,12 +120,7 @@ def mask_post_process(bbox_nums, bboxes, masks, im_info): ...@@ -120,12 +120,7 @@ def mask_post_process(bbox_nums, bboxes, masks, im_info):
@jit @jit
def get_det_res(bbox_nums, def get_det_res(bbox_nums, bbox, image_id, num_id_to_cat_id_map, batch_size=1):
bbox,
image_id,
image_shape,
num_id_to_cat_id_map,
batch_size=1):
det_res = [] det_res = []
bbox_v = np.array(bbox) bbox_v = np.array(bbox)
if bbox_v.shape == ( if bbox_v.shape == (
...@@ -139,8 +134,6 @@ def get_det_res(bbox_nums, ...@@ -139,8 +134,6 @@ def get_det_res(bbox_nums,
for i in range(batch_size): for i in range(batch_size):
dt_num_this_img = bbox_nums[i + 1] - bbox_nums[i] dt_num_this_img = bbox_nums[i + 1] - bbox_nums[i]
image_id = int(image_id[i][0]) image_id = int(image_id[i][0])
image_width = int(image_shape[i][1]) #int(data[i][-1][1])
image_height = int(image_shape[i][2]) #int(data[i][-1][2])
for j in range(dt_num_this_img): for j in range(dt_num_this_img):
dt = bbox_v[k] dt = bbox_v[k]
k = k + 1 k = k + 1
......
...@@ -2,6 +2,8 @@ from __future__ import absolute_import ...@@ -2,6 +2,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
def json_eval_results(metric, json_directory=None, dataset=None): def json_eval_results(metric, json_directory=None, dataset=None):
""" """
...@@ -39,14 +41,16 @@ def coco_eval_results(outs_res=None, ...@@ -39,14 +41,16 @@ def coco_eval_results(outs_res=None,
from ppdet.py_op.post_process import get_det_res, get_seg_res from ppdet.py_op.post_process import get_det_res, get_seg_res
anno_file = os.path.join(dataset.dataset_dir, dataset.anno_path) anno_file = os.path.join(dataset.dataset_dir, dataset.anno_path)
cocoGt = COCO(anno_file) cocoGt = COCO(anno_file)
catid = {i + 1: v for i, v in enumerate(cocoGt.getCatIds())} catid = {
i + dataset.with_background: v
for i, v in enumerate(cocoGt.getCatIds())
}
if outs_res is not None and len(outs_res) > 0: if outs_res is not None and len(outs_res) > 0:
det_res = [] det_res = []
for outs in outs_res: for outs in outs_res:
det_res += get_det_res(outs['bbox_nums'], outs['bbox'], det_res += get_det_res(outs['bbox_nums'], outs['bbox'],
outs['im_id'], outs['im_shape'], catid, outs['im_id'], catid, batch_size)
batch_size)
with io.open("bbox_eval.json", 'w') as outfile: with io.open("bbox_eval.json", 'w') as outfile:
encode_func = unicode if six.PY2 else str encode_func = unicode if six.PY2 else str
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册