未验证 提交 bfe82c09 编写于 作者: Y Yang Zhang 提交者: GitHub

Reorder data feed (#2592)

上级 4553b9e1
...@@ -81,7 +81,7 @@ def create_reader(feed, max_iter=0): ...@@ -81,7 +81,7 @@ def create_reader(feed, max_iter=0):
'TYPE': type(feed.dataset).__source__ 'TYPE': type(feed.dataset).__source__
} }
} }
if len(getattr(feed.dataset, 'images', [])) > 0: if len(getattr(feed.dataset, 'images', [])) > 0:
data_config[mode]['IMAGES'] = feed.dataset.images data_config[mode]['IMAGES'] = feed.dataset.images
...@@ -448,43 +448,33 @@ class FasterRCNNTrainFeed(DataFeed): ...@@ -448,43 +448,33 @@ class FasterRCNNTrainFeed(DataFeed):
self.mode = 'TRAIN' self.mode = 'TRAIN'
# XXX currently use two presets, in the future, these should be combined into a
# single `RCNNTrainFeed`. Mask (and keypoint) should be processed
# automatically if `gt_mask` (or `gt_keypoints`) is in the required fields
@register @register
class MaskRCNNTrainFeed(DataFeed): class FasterRCNNEvalFeed(DataFeed):
__doc__ = DataFeed.__doc__ __doc__ = DataFeed.__doc__
def __init__(self, def __init__(self,
dataset=CocoDataSet().__dict__, dataset=CocoDataSet(COCO_VAL_ANNOTATION,
fields=[ COCO_VAL_IMAGE_DIR).__dict__,
'image', 'im_info', 'im_id', 'gt_box', 'gt_label', fields=['image', 'im_info', 'im_id', 'im_shape'],
'is_crowd', 'gt_mask'
],
image_shape=[3, 1333, 800], image_shape=[3, 1333, 800],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5, is_mask_flip=True),
NormalizeImage(mean=[0.485, 0.456, 0.406], NormalizeImage(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225], std=[0.229, 0.224, 0.225],
is_scale=True, is_scale=True,
is_channel_first=False), is_channel_first=False),
ResizeImage(target_size=800, ResizeImage(target_size=800, max_size=1333, interp=1),
max_size=1333, Permute(to_bgr=False)
interp=1,
use_cv2=True),
Permute(to_bgr=False, channel_first=True)
], ],
batch_transforms=[PadBatch()], batch_transforms=[PadBatch()],
batch_size=1, batch_size=1,
shuffle=True, shuffle=False,
samples=-1, samples=-1,
drop_last=False, drop_last=False,
num_workers=2, num_workers=2,
use_process=False, use_padded_im_info=True):
use_padded_im_info=False): sample_transforms.append(ArrangeTestRCNN())
sample_transforms.append(ArrangeRCNN(is_mask=True)) super(FasterRCNNEvalFeed, self).__init__(
super(MaskRCNNTrainFeed, self).__init__(
dataset, dataset,
fields, fields,
image_shape, image_shape,
...@@ -495,17 +485,17 @@ class MaskRCNNTrainFeed(DataFeed): ...@@ -495,17 +485,17 @@ class MaskRCNNTrainFeed(DataFeed):
samples=samples, samples=samples,
drop_last=drop_last, drop_last=drop_last,
num_workers=num_workers, num_workers=num_workers,
use_process=use_process) use_padded_im_info=use_padded_im_info)
self.mode = 'TRAIN' self.mode = 'VAL'
@register @register
class FasterRCNNEvalFeed(DataFeed): class FasterRCNNTestFeed(DataFeed):
__doc__ = DataFeed.__doc__ __doc__ = DataFeed.__doc__
def __init__(self, def __init__(self,
dataset=CocoDataSet(COCO_VAL_ANNOTATION, dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'], fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800], image_shape=[3, 1333, 800],
sample_transforms=[ sample_transforms=[
...@@ -514,7 +504,6 @@ class FasterRCNNEvalFeed(DataFeed): ...@@ -514,7 +504,6 @@ class FasterRCNNEvalFeed(DataFeed):
std=[0.229, 0.224, 0.225], std=[0.229, 0.224, 0.225],
is_scale=True, is_scale=True,
is_channel_first=False), is_channel_first=False),
ResizeImage(target_size=800, max_size=1333, interp=1),
Permute(to_bgr=False) Permute(to_bgr=False)
], ],
batch_transforms=[PadBatch()], batch_transforms=[PadBatch()],
...@@ -525,7 +514,9 @@ class FasterRCNNEvalFeed(DataFeed): ...@@ -525,7 +514,9 @@ class FasterRCNNEvalFeed(DataFeed):
num_workers=2, num_workers=2,
use_padded_im_info=True): use_padded_im_info=True):
sample_transforms.append(ArrangeTestRCNN()) sample_transforms.append(ArrangeTestRCNN())
super(FasterRCNNEvalFeed, self).__init__( if isinstance(dataset, dict):
dataset = SimpleDataSet(**dataset)
super(FasterRCNNTestFeed, self).__init__(
dataset, dataset,
fields, fields,
image_shape, image_shape,
...@@ -537,37 +528,46 @@ class FasterRCNNEvalFeed(DataFeed): ...@@ -537,37 +528,46 @@ class FasterRCNNEvalFeed(DataFeed):
drop_last=drop_last, drop_last=drop_last,
num_workers=num_workers, num_workers=num_workers,
use_padded_im_info=use_padded_im_info) use_padded_im_info=use_padded_im_info)
self.mode = 'VAL' self.mode = 'TEST'
# XXX currently use two presets, in the future, these should be combined into a
# single `RCNNTrainFeed`. Mask (and keypoint) should be processed
# automatically if `gt_mask` (or `gt_keypoints`) is in the required fields
@register @register
class FasterRCNNTestFeed(DataFeed): class MaskRCNNTrainFeed(DataFeed):
__doc__ = DataFeed.__doc__ __doc__ = DataFeed.__doc__
def __init__(self, def __init__(self,
dataset=SimpleDataSet(COCO_VAL_ANNOTATION, dataset=CocoDataSet().__dict__,
COCO_VAL_IMAGE_DIR).__dict__, fields=[
fields=['image', 'im_info', 'im_id', 'im_shape'], 'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd', 'gt_mask'
],
image_shape=[3, 1333, 800], image_shape=[3, 1333, 800],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5, is_mask_flip=True),
NormalizeImage(mean=[0.485, 0.456, 0.406], NormalizeImage(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225], std=[0.229, 0.224, 0.225],
is_scale=True, is_scale=True,
is_channel_first=False), is_channel_first=False),
Permute(to_bgr=False) ResizeImage(target_size=800,
max_size=1333,
interp=1,
use_cv2=True),
Permute(to_bgr=False, channel_first=True)
], ],
batch_transforms=[PadBatch()], batch_transforms=[PadBatch()],
batch_size=1, batch_size=1,
shuffle=False, shuffle=True,
samples=-1, samples=-1,
drop_last=False, drop_last=False,
num_workers=2, num_workers=2,
use_padded_im_info=True): use_process=False,
sample_transforms.append(ArrangeTestRCNN()) use_padded_im_info=False):
if isinstance(dataset, dict): sample_transforms.append(ArrangeRCNN(is_mask=True))
dataset = SimpleDataSet(**dataset) super(MaskRCNNTrainFeed, self).__init__(
super(FasterRCNNTestFeed, self).__init__(
dataset, dataset,
fields, fields,
image_shape, image_shape,
...@@ -578,8 +578,8 @@ class FasterRCNNTestFeed(DataFeed): ...@@ -578,8 +578,8 @@ class FasterRCNNTestFeed(DataFeed):
samples=samples, samples=samples,
drop_last=drop_last, drop_last=drop_last,
num_workers=num_workers, num_workers=num_workers,
use_padded_im_info=use_padded_im_info) use_process=use_process)
self.mode = 'TEST' self.mode = 'TRAIN'
@register @register
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册