提交 78a4d0bf 编写于 作者: W wangguanzhong 提交者: GitHub

fix rcnn eval on voc (#3344)

* fix rcnn eval on voc

* update comment
上级 28e19603
......@@ -30,8 +30,8 @@ from ppdet.data.transform.operators import (
Permute)
from ppdet.data.transform.arrange_sample import (
ArrangeRCNN, ArrangeTestRCNN, ArrangeSSD, ArrangeEvalSSD, ArrangeTestSSD,
ArrangeYOLO, ArrangeEvalYOLO, ArrangeTestYOLO)
ArrangeRCNN, ArrangeEvalRCNN, ArrangeTestRCNN, ArrangeSSD, ArrangeEvalSSD,
ArrangeTestSSD, ArrangeYOLO, ArrangeEvalYOLO, ArrangeTestYOLO)
__all__ = [
'PadBatch', 'MultiScale', 'RandomShape', 'DataSet', 'CocoDataSet',
......@@ -476,7 +476,8 @@ class FasterRCNNEvalFeed(DataFeed):
def __init__(self,
dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
fields=['image', 'im_info', 'im_id', 'im_shape', 'gt_box',
'gt_label', 'is_difficult'],
image_shape=[3, 800, 1333],
sample_transforms=[
DecodeImage(to_rgb=True),
......@@ -494,7 +495,7 @@ class FasterRCNNEvalFeed(DataFeed):
drop_last=False,
num_workers=2,
use_padded_im_info=True):
sample_transforms.append(ArrangeTestRCNN())
sample_transforms.append(ArrangeEvalRCNN())
super(FasterRCNNEvalFeed, self).__init__(
dataset,
fields,
......
......@@ -90,6 +90,47 @@ class ArrangeRCNN(BaseOperator):
return outs
@register_op
class ArrangeEvalRCNN(BaseOperator):
    """
    Pack an evaluation sample dict into the flat tuple layout used for
    RCNN evaluation, keeping the ground-truth annotations next to the
    image data.
    """

    def __init__(self):
        super(ArrangeEvalRCNN, self).__init__()

    def __call__(self, sample, context=None):
        """
        Args:
            sample: a dict holding the image together with its
                'im_info', 'im_id', 'h', 'w', 'gt_bbox', 'gt_class'
                and 'difficult' entries.
            context: a dict which contains additional info; it is not
                read by this operator.
        Returns:
            tuple: (image, im_info, im_id, im_shape, gt_bbox,
                gt_class, difficult), where im_shape is a float32
                array built from (h, w, 1).
        Raises:
            KeyError: if 'im_info' or any other required key is absent
                from the sample.
        """
        image = sample['image']
        # Fail early with an explicit message when the dataset did not
        # provide 'im_info'.
        if 'im_info' not in sample:
            raise KeyError("The dataset doesn't have 'im_info' key.")
        im_info, im_id = sample['im_info'], sample['im_id']
        height, width = sample['h'], sample['w']
        # RCNN models need the original image size in the eval and infer
        # stages to clip the predicted boxes. The box clip op takes its
        # size input in [N, 3] format, so a trailing 1 is appended to
        # (h, w) to match that dimension.
        im_shape = np.array((height, width, 1), dtype=np.float32)
        annotations = tuple(
            sample[key] for key in ('gt_bbox', 'gt_class', 'difficult'))
        return (image, im_info, im_id, im_shape) + annotations
@register_op
class ArrangeTestRCNN(BaseOperator):
"""
......@@ -152,6 +193,7 @@ class ArrangeSSD(BaseOperator):
outs = (im, gt_bbox, gt_class)
return outs
@register_op
class ArrangeEvalSSD(BaseOperator):
"""
......@@ -184,6 +226,7 @@ class ArrangeEvalSSD(BaseOperator):
return outs
@register_op
class ArrangeTestSSD(BaseOperator):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册