提交 0520640d 编写于 作者: Q qingqing01 提交者: GitHub

Change data shape in data_feed.py (#3026)

上级 b2b359c3
......@@ -28,10 +28,9 @@ from ppdet.data.transform.operators import (
DecodeImage, MixupImage, NormalizeBox, NormalizeImage, RandomDistort,
RandomFlipImage, RandomInterpImage, ResizeImage, ExpandImage, CropImage,
Permute)
from ppdet.data.transform.arrange_sample import (ArrangeRCNN, ArrangeTestRCNN,
ArrangeSSD, ArrangeTestSSD,
ArrangeYOLO, ArrangeEvalYOLO,
ArrangeTestYOLO)
from ppdet.data.transform.arrange_sample import (
ArrangeRCNN, ArrangeTestRCNN, ArrangeSSD, ArrangeTestSSD, ArrangeYOLO,
ArrangeEvalYOLO, ArrangeTestYOLO)
__all__ = [
'PadBatch', 'MultiScale', 'RandomShape', 'DataSet', 'CocoDataSet',
......@@ -138,8 +137,8 @@ def create_reader(feed, max_iter=0, args_path=None, my_source=None):
ops.append(op_dict)
transform_config['OPS'] = ops
return Reader.create(feed.mode, data_config,
transform_config, max_iter, my_source)
return Reader.create(feed.mode, data_config, transform_config, max_iter,
my_source)
# XXX batch transforms are only stubs for now, actually handled by `post_map`
......@@ -412,6 +411,7 @@ class TestFeed(DataFeed):
num_workers=num_workers)
# yapf: disable
@register
class FasterRCNNTrainFeed(DataFeed):
__doc__ = DataFeed.__doc__
......@@ -422,7 +422,7 @@ class FasterRCNNTrainFeed(DataFeed):
'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd'
],
image_shape=[3, 1333, 800],
image_shape=[3, 800, 1333],
sample_transforms=[
DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5),
......@@ -467,7 +467,7 @@ class FasterRCNNEvalFeed(DataFeed):
dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800],
image_shape=[3, 800, 1333],
sample_transforms=[
DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406],
......@@ -508,7 +508,7 @@ class FasterRCNNTestFeed(DataFeed):
dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800],
image_shape=[3, 800, 1333],
sample_transforms=[
DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406],
......@@ -555,7 +555,7 @@ class MaskRCNNTrainFeed(DataFeed):
'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd', 'gt_mask'
],
image_shape=[3, 1333, 800],
image_shape=[3, 800, 1333],
sample_transforms=[
DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5, is_mask_flip=True),
......@@ -601,7 +601,7 @@ class MaskRCNNEvalFeed(DataFeed):
dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800],
image_shape=[3, 800, 1333],
sample_transforms=[
DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406],
......@@ -647,7 +647,7 @@ class MaskRCNNTestFeed(DataFeed):
dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800],
image_shape=[3, 800, 1333],
sample_transforms=[
DecodeImage(to_rgb=True),
NormalizeImage(
......@@ -900,7 +900,7 @@ class YoloEvalFeed(DataFeed):
def __init__(self,
dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_size', 'im_id', 'gt_box',
fields=['image', 'im_size', 'im_id', 'gt_box',
'gt_label', 'is_difficult'],
image_shape=[3, 608, 608],
sample_transforms=[
......@@ -985,3 +985,4 @@ class YoloTestFeed(DataFeed):
use_process=use_process)
self.mode = 'TEST'
self.bufsize = 128
# yapf: enable
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册