提交 0520640d 编写于 作者: Q qingqing01 提交者: GitHub

Change data shape in data_feed.py (#3026)

上级 b2b359c3
...@@ -28,10 +28,9 @@ from ppdet.data.transform.operators import ( ...@@ -28,10 +28,9 @@ from ppdet.data.transform.operators import (
DecodeImage, MixupImage, NormalizeBox, NormalizeImage, RandomDistort, DecodeImage, MixupImage, NormalizeBox, NormalizeImage, RandomDistort,
RandomFlipImage, RandomInterpImage, ResizeImage, ExpandImage, CropImage, RandomFlipImage, RandomInterpImage, ResizeImage, ExpandImage, CropImage,
Permute) Permute)
from ppdet.data.transform.arrange_sample import (ArrangeRCNN, ArrangeTestRCNN, from ppdet.data.transform.arrange_sample import (
ArrangeSSD, ArrangeTestSSD, ArrangeRCNN, ArrangeTestRCNN, ArrangeSSD, ArrangeTestSSD, ArrangeYOLO,
ArrangeYOLO, ArrangeEvalYOLO, ArrangeEvalYOLO, ArrangeTestYOLO)
ArrangeTestYOLO)
__all__ = [ __all__ = [
'PadBatch', 'MultiScale', 'RandomShape', 'DataSet', 'CocoDataSet', 'PadBatch', 'MultiScale', 'RandomShape', 'DataSet', 'CocoDataSet',
...@@ -138,8 +137,8 @@ def create_reader(feed, max_iter=0, args_path=None, my_source=None): ...@@ -138,8 +137,8 @@ def create_reader(feed, max_iter=0, args_path=None, my_source=None):
ops.append(op_dict) ops.append(op_dict)
transform_config['OPS'] = ops transform_config['OPS'] = ops
return Reader.create(feed.mode, data_config, return Reader.create(feed.mode, data_config, transform_config, max_iter,
transform_config, max_iter, my_source) my_source)
# XXX batch transforms are only stubs for now, actually handled by `post_map` # XXX batch transforms are only stubs for now, actually handled by `post_map`
...@@ -412,6 +411,7 @@ class TestFeed(DataFeed): ...@@ -412,6 +411,7 @@ class TestFeed(DataFeed):
num_workers=num_workers) num_workers=num_workers)
# yapf: disable
@register @register
class FasterRCNNTrainFeed(DataFeed): class FasterRCNNTrainFeed(DataFeed):
__doc__ = DataFeed.__doc__ __doc__ = DataFeed.__doc__
...@@ -422,7 +422,7 @@ class FasterRCNNTrainFeed(DataFeed): ...@@ -422,7 +422,7 @@ class FasterRCNNTrainFeed(DataFeed):
'image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd' 'is_crowd'
], ],
image_shape=[3, 1333, 800], image_shape=[3, 800, 1333],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5), RandomFlipImage(prob=0.5),
...@@ -467,7 +467,7 @@ class FasterRCNNEvalFeed(DataFeed): ...@@ -467,7 +467,7 @@ class FasterRCNNEvalFeed(DataFeed):
dataset=CocoDataSet(COCO_VAL_ANNOTATION, dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'], fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800], image_shape=[3, 800, 1333],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406], NormalizeImage(mean=[0.485, 0.456, 0.406],
...@@ -508,7 +508,7 @@ class FasterRCNNTestFeed(DataFeed): ...@@ -508,7 +508,7 @@ class FasterRCNNTestFeed(DataFeed):
dataset=SimpleDataSet(COCO_VAL_ANNOTATION, dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'], fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800], image_shape=[3, 800, 1333],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406], NormalizeImage(mean=[0.485, 0.456, 0.406],
...@@ -555,7 +555,7 @@ class MaskRCNNTrainFeed(DataFeed): ...@@ -555,7 +555,7 @@ class MaskRCNNTrainFeed(DataFeed):
'image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd', 'gt_mask' 'is_crowd', 'gt_mask'
], ],
image_shape=[3, 1333, 800], image_shape=[3, 800, 1333],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5, is_mask_flip=True), RandomFlipImage(prob=0.5, is_mask_flip=True),
...@@ -601,7 +601,7 @@ class MaskRCNNEvalFeed(DataFeed): ...@@ -601,7 +601,7 @@ class MaskRCNNEvalFeed(DataFeed):
dataset=CocoDataSet(COCO_VAL_ANNOTATION, dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'], fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800], image_shape=[3, 800, 1333],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406], NormalizeImage(mean=[0.485, 0.456, 0.406],
...@@ -647,7 +647,7 @@ class MaskRCNNTestFeed(DataFeed): ...@@ -647,7 +647,7 @@ class MaskRCNNTestFeed(DataFeed):
dataset=SimpleDataSet(COCO_VAL_ANNOTATION, dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'], fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800], image_shape=[3, 800, 1333],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
NormalizeImage( NormalizeImage(
...@@ -985,3 +985,4 @@ class YoloTestFeed(DataFeed): ...@@ -985,3 +985,4 @@ class YoloTestFeed(DataFeed):
use_process=use_process) use_process=use_process)
self.mode = 'TEST' self.mode = 'TEST'
self.bufsize = 128 self.bufsize = 128
# yapf: enable
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册