diff --git a/ppdet/data/data_feed.py b/ppdet/data/data_feed.py index 568f00de9867d6220dbce3d459cd1be66c651e22..4f67bed1a61f4ab7597bac608b542e0347fad3c9 100644 --- a/ppdet/data/data_feed.py +++ b/ppdet/data/data_feed.py @@ -30,8 +30,8 @@ from ppdet.data.transform.operators import ( Permute) from ppdet.data.transform.arrange_sample import ( - ArrangeRCNN, ArrangeTestRCNN, ArrangeSSD, ArrangeEvalSSD, ArrangeTestSSD, - ArrangeYOLO, ArrangeEvalYOLO, ArrangeTestYOLO) + ArrangeRCNN, ArrangeEvalRCNN, ArrangeTestRCNN, ArrangeSSD, ArrangeEvalSSD, + ArrangeTestSSD, ArrangeYOLO, ArrangeEvalYOLO, ArrangeTestYOLO) __all__ = [ 'PadBatch', 'MultiScale', 'RandomShape', 'DataSet', 'CocoDataSet', @@ -476,7 +476,8 @@ class FasterRCNNEvalFeed(DataFeed): def __init__(self, dataset=CocoDataSet(COCO_VAL_ANNOTATION, COCO_VAL_IMAGE_DIR).__dict__, - fields=['image', 'im_info', 'im_id', 'im_shape'], + fields=['image', 'im_info', 'im_id', 'im_shape', 'gt_box', + 'gt_label', 'is_difficult'], image_shape=[3, 800, 1333], sample_transforms=[ DecodeImage(to_rgb=True), @@ -494,7 +495,7 @@ class FasterRCNNEvalFeed(DataFeed): drop_last=False, num_workers=2, use_padded_im_info=True): - sample_transforms.append(ArrangeTestRCNN()) + sample_transforms.append(ArrangeEvalRCNN()) super(FasterRCNNEvalFeed, self).__init__( dataset, fields, diff --git a/ppdet/data/transform/arrange_sample.py b/ppdet/data/transform/arrange_sample.py index 13f70bfa30dc8710ff0277004a67c3c75fdad65e..697995cd72d37a7dee29f6c18ca527ffa6ad5077 100644 --- a/ppdet/data/transform/arrange_sample.py +++ b/ppdet/data/transform/arrange_sample.py @@ -90,6 +90,47 @@ class ArrangeRCNN(BaseOperator): return outs +@register_op +class ArrangeEvalRCNN(BaseOperator): + """ + Transform dict to the tuple format needed for evaluation. + """ + + def __init__(self): + super(ArrangeEvalRCNN, self).__init__() + + def __call__(self, sample, context=None): + """ + Args: + sample: a dict which contains image + info and annotation info. + context: a dict which contains additional info. + Returns: + sample: a tuple containing the following items: + (image, im_info, im_id, im_shape, gt_bbox, + gt_class, difficult) + """ + im = sample['image'] + keys = list(sample.keys()) + if 'im_info' in keys: + im_info = sample['im_info'] + else: + raise KeyError("The dataset doesn't have 'im_info' key.") + im_id = sample['im_id'] + h = sample['h'] + w = sample['w'] + # For rcnn models in eval and infer stage, original image size + # is needed to clip the bounding boxes. And box clip op in + # bbox prediction needs im_info as input in format of [N, 3], + # so im_shape is appended by 1 to match dimension. + im_shape = np.array((h, w, 1), dtype=np.float32) + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + difficult = sample['difficult'] + outs = (im, im_info, im_id, im_shape, gt_bbox, gt_class, difficult) + return outs + + @register_op class ArrangeTestRCNN(BaseOperator): """ @@ -152,6 +193,7 @@ class ArrangeSSD(BaseOperator): outs = (im, gt_bbox, gt_class) return outs + @register_op class ArrangeEvalSSD(BaseOperator): """ @@ -184,6 +226,7 @@ class ArrangeEvalSSD(BaseOperator): return outs + @register_op class ArrangeTestSSD(BaseOperator): """