未验证 提交 bfe82c09 编写于 作者: Y Yang Zhang 提交者: GitHub

Reorder data feed (#2592)

上级 4553b9e1
......@@ -81,7 +81,7 @@ def create_reader(feed, max_iter=0):
'TYPE': type(feed.dataset).__source__
}
}
if len(getattr(feed.dataset, 'images', [])) > 0:
data_config[mode]['IMAGES'] = feed.dataset.images
......@@ -448,43 +448,33 @@ class FasterRCNNTrainFeed(DataFeed):
self.mode = 'TRAIN'
# XXX currently use two presets, in the future, these should be combined into a
# single `RCNNTrainFeed`. Mask (and keypoint) should be processed
# automatically if `gt_mask` (or `gt_keypoints`) is in the required fields
@register
class MaskRCNNTrainFeed(DataFeed):
class FasterRCNNEvalFeed(DataFeed):
__doc__ = DataFeed.__doc__
def __init__(self,
dataset=CocoDataSet().__dict__,
fields=[
'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd', 'gt_mask'
],
dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800],
sample_transforms=[
DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5, is_mask_flip=True),
NormalizeImage(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
is_scale=True,
is_channel_first=False),
ResizeImage(target_size=800,
max_size=1333,
interp=1,
use_cv2=True),
Permute(to_bgr=False, channel_first=True)
ResizeImage(target_size=800, max_size=1333, interp=1),
Permute(to_bgr=False)
],
batch_transforms=[PadBatch()],
batch_size=1,
shuffle=True,
shuffle=False,
samples=-1,
drop_last=False,
num_workers=2,
use_process=False,
use_padded_im_info=False):
sample_transforms.append(ArrangeRCNN(is_mask=True))
super(MaskRCNNTrainFeed, self).__init__(
use_padded_im_info=True):
sample_transforms.append(ArrangeTestRCNN())
super(FasterRCNNEvalFeed, self).__init__(
dataset,
fields,
image_shape,
......@@ -495,17 +485,17 @@ class MaskRCNNTrainFeed(DataFeed):
samples=samples,
drop_last=drop_last,
num_workers=num_workers,
use_process=use_process)
self.mode = 'TRAIN'
use_padded_im_info=use_padded_im_info)
self.mode = 'VAL'
@register
class FasterRCNNEvalFeed(DataFeed):
class FasterRCNNTestFeed(DataFeed):
__doc__ = DataFeed.__doc__
def __init__(self,
dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 1333, 800],
sample_transforms=[
......@@ -514,7 +504,6 @@ class FasterRCNNEvalFeed(DataFeed):
std=[0.229, 0.224, 0.225],
is_scale=True,
is_channel_first=False),
ResizeImage(target_size=800, max_size=1333, interp=1),
Permute(to_bgr=False)
],
batch_transforms=[PadBatch()],
......@@ -525,7 +514,9 @@ class FasterRCNNEvalFeed(DataFeed):
num_workers=2,
use_padded_im_info=True):
sample_transforms.append(ArrangeTestRCNN())
super(FasterRCNNEvalFeed, self).__init__(
if isinstance(dataset, dict):
dataset = SimpleDataSet(**dataset)
super(FasterRCNNTestFeed, self).__init__(
dataset,
fields,
image_shape,
......@@ -537,37 +528,46 @@ class FasterRCNNEvalFeed(DataFeed):
drop_last=drop_last,
num_workers=num_workers,
use_padded_im_info=use_padded_im_info)
self.mode = 'VAL'
self.mode = 'TEST'
# XXX currently use two presets, in the future, these should be combined into a
# single `RCNNTrainFeed`. Mask (and keypoint) should be processed
# automatically if `gt_mask` (or `gt_keypoints`) is in the required fields
@register
class FasterRCNNTestFeed(DataFeed):
class MaskRCNNTrainFeed(DataFeed):
__doc__ = DataFeed.__doc__
def __init__(self,
dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
dataset=CocoDataSet().__dict__,
fields=[
'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd', 'gt_mask'
],
image_shape=[3, 1333, 800],
sample_transforms=[
DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5, is_mask_flip=True),
NormalizeImage(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
is_scale=True,
is_channel_first=False),
Permute(to_bgr=False)
ResizeImage(target_size=800,
max_size=1333,
interp=1,
use_cv2=True),
Permute(to_bgr=False, channel_first=True)
],
batch_transforms=[PadBatch()],
batch_size=1,
shuffle=False,
shuffle=True,
samples=-1,
drop_last=False,
num_workers=2,
use_padded_im_info=True):
sample_transforms.append(ArrangeTestRCNN())
if isinstance(dataset, dict):
dataset = SimpleDataSet(**dataset)
super(FasterRCNNTestFeed, self).__init__(
use_process=False,
use_padded_im_info=False):
sample_transforms.append(ArrangeRCNN(is_mask=True))
super(MaskRCNNTrainFeed, self).__init__(
dataset,
fields,
image_shape,
......@@ -578,8 +578,8 @@ class FasterRCNNTestFeed(DataFeed):
samples=samples,
drop_last=drop_last,
num_workers=num_workers,
use_padded_im_info=use_padded_im_info)
self.mode = 'TEST'
use_process=use_process)
self.mode = 'TRAIN'
@register
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册