提交 78a4d0bf 编写于 作者: W wangguanzhong 提交者: GitHub

fix rcnn eval on voc (#3344)

* fix rcnn eval on voc

* update comment
上级 28e19603
...@@ -30,8 +30,8 @@ from ppdet.data.transform.operators import ( ...@@ -30,8 +30,8 @@ from ppdet.data.transform.operators import (
Permute) Permute)
from ppdet.data.transform.arrange_sample import ( from ppdet.data.transform.arrange_sample import (
ArrangeRCNN, ArrangeTestRCNN, ArrangeSSD, ArrangeEvalSSD, ArrangeTestSSD, ArrangeRCNN, ArrangeEvalRCNN, ArrangeTestRCNN, ArrangeSSD, ArrangeEvalSSD,
ArrangeYOLO, ArrangeEvalYOLO, ArrangeTestYOLO) ArrangeTestSSD, ArrangeYOLO, ArrangeEvalYOLO, ArrangeTestYOLO)
__all__ = [ __all__ = [
'PadBatch', 'MultiScale', 'RandomShape', 'DataSet', 'CocoDataSet', 'PadBatch', 'MultiScale', 'RandomShape', 'DataSet', 'CocoDataSet',
...@@ -476,7 +476,8 @@ class FasterRCNNEvalFeed(DataFeed): ...@@ -476,7 +476,8 @@ class FasterRCNNEvalFeed(DataFeed):
def __init__(self, def __init__(self,
dataset=CocoDataSet(COCO_VAL_ANNOTATION, dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'], fields=['image', 'im_info', 'im_id', 'im_shape', 'gt_box',
'gt_label', 'is_difficult'],
image_shape=[3, 800, 1333], image_shape=[3, 800, 1333],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
...@@ -494,7 +495,7 @@ class FasterRCNNEvalFeed(DataFeed): ...@@ -494,7 +495,7 @@ class FasterRCNNEvalFeed(DataFeed):
drop_last=False, drop_last=False,
num_workers=2, num_workers=2,
use_padded_im_info=True): use_padded_im_info=True):
sample_transforms.append(ArrangeTestRCNN()) sample_transforms.append(ArrangeEvalRCNN())
super(FasterRCNNEvalFeed, self).__init__( super(FasterRCNNEvalFeed, self).__init__(
dataset, dataset,
fields, fields,
......
...@@ -90,6 +90,47 @@ class ArrangeRCNN(BaseOperator): ...@@ -90,6 +90,47 @@ class ArrangeRCNN(BaseOperator):
return outs return outs
@register_op
class ArrangeEvalRCNN(BaseOperator):
    """
    Arrange a sample dict into the flat tuple consumed during evaluation.
    """

    def __init__(self):
        super(ArrangeEvalRCNN, self).__init__()

    def __call__(self, sample, context=None):
        """
        Args:
            sample: a dict holding the image together with its image
                    info and annotation info.
            context: a dict which contains additional info; it is not
                     read by this operator.
        Returns:
            a tuple laid out as:
                (image, im_info, im_id, im_shape, gt_bbox,
                 gt_class, difficult)
        Raises:
            KeyError: if `sample` has no 'im_info' entry (or lacks any
                      of the other keys read below).
        """
        image = sample['image']
        # Guard clause: fail with an explicit message when im_info is absent.
        if 'im_info' not in sample:
            raise KeyError("The dataset doesn't have 'im_info' key.")
        im_info = sample['im_info']

        # The original image size is packed as (h, w, 1): per the upstream
        # note, rcnn models clip predicted boxes against it in the eval and
        # infer stages, and the box clip op takes its input in [N, 3]
        # im_info format, hence the trailing 1.
        # Lookups below stay in the order: im_id, h, w, gt_bbox, gt_class,
        # difficult.
        return (
            image,
            im_info,
            sample['im_id'],
            np.array((sample['h'], sample['w'], 1), dtype=np.float32),
            sample['gt_bbox'],
            sample['gt_class'],
            sample['difficult'],
        )
@register_op @register_op
class ArrangeTestRCNN(BaseOperator): class ArrangeTestRCNN(BaseOperator):
""" """
...@@ -152,6 +193,7 @@ class ArrangeSSD(BaseOperator): ...@@ -152,6 +193,7 @@ class ArrangeSSD(BaseOperator):
outs = (im, gt_bbox, gt_class) outs = (im, gt_bbox, gt_class)
return outs return outs
@register_op @register_op
class ArrangeEvalSSD(BaseOperator): class ArrangeEvalSSD(BaseOperator):
""" """
...@@ -184,6 +226,7 @@ class ArrangeEvalSSD(BaseOperator): ...@@ -184,6 +226,7 @@ class ArrangeEvalSSD(BaseOperator):
return outs return outs
@register_op @register_op
class ArrangeTestSSD(BaseOperator): class ArrangeTestSSD(BaseOperator):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册