未验证 提交 9229209b 编写于 作者: W wangguanzhong 提交者: GitHub

fix config for fluid.data (#37)

上级 3af2e211
......@@ -453,7 +453,7 @@ class FasterRCNNTrainFeed(DataFeed):
'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd'
],
image_shape=[None, 3, None, None],
image_shape=[3, None, None],
sample_transforms=[
DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5),
......@@ -505,7 +505,7 @@ class FasterRCNNEvalFeed(DataFeed):
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape', 'gt_box',
'gt_label', 'is_difficult'],
image_shape=[None, 3, None, None],
image_shape=[3, None, None],
sample_transforms=[
DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406],
......@@ -552,7 +552,7 @@ class FasterRCNNTestFeed(DataFeed):
dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[None, 3, None, None],
image_shape=[3, None, None],
sample_transforms=[
DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406],
......@@ -600,7 +600,7 @@ class MaskRCNNTrainFeed(DataFeed):
'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd', 'gt_mask'
],
image_shape=[None, 3, None, None],
image_shape=[3, None, None],
sample_transforms=[
DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5, is_mask_flip=True),
......@@ -646,7 +646,7 @@ class MaskRCNNEvalFeed(DataFeed):
dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[None, 3, None, None],
image_shape=[3, None, None],
sample_transforms=[
DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406],
......@@ -698,7 +698,7 @@ class MaskRCNNTestFeed(DataFeed):
dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[None, 3, None, None],
image_shape=[3, None, None],
sample_transforms=[
DecodeImage(to_rgb=True),
NormalizeImage(
......@@ -743,7 +743,7 @@ class SSDTrainFeed(DataFeed):
def __init__(self,
dataset=VocDataSet().__dict__,
fields=['image', 'gt_box', 'gt_label'],
image_shape=[None, 3, 300, 300],
image_shape=[3, 300, 300],
sample_transforms=[
DecodeImage(to_rgb=True, with_mixup=False),
NormalizeBox(),
......@@ -802,7 +802,7 @@ class SSDEvalFeed(DataFeed):
dataset=VocDataSet(VOC_VAL_ANNOTATION).__dict__,
fields=['image', 'im_shape', 'im_id', 'gt_box',
'gt_label', 'is_difficult'],
image_shape=[None, 3, 300, 300],
image_shape=[3, 300, 300],
sample_transforms=[
DecodeImage(to_rgb=True, with_mixup=False),
NormalizeBox(),
......@@ -847,7 +847,7 @@ class SSDTestFeed(DataFeed):
def __init__(self,
dataset=SimpleDataSet(VOC_VAL_ANNOTATION).__dict__,
fields=['image', 'im_id', 'im_shape'],
image_shape=[None, 3, 300, 300],
image_shape=[3, 300, 300],
sample_transforms=[
DecodeImage(to_rgb=True),
ResizeImage(target_size=300, use_cv2=False, interp=1),
......@@ -893,7 +893,7 @@ class YoloTrainFeed(DataFeed):
def __init__(self,
dataset=CocoDataSet().__dict__,
fields=['image', 'gt_box', 'gt_label', 'gt_score'],
image_shape=[None, 3, 608, 608],
image_shape=[3, 608, 608],
sample_transforms=[
DecodeImage(to_rgb=True, with_mixup=True),
MixupImage(alpha=1.5, beta=1.5),
......@@ -955,7 +955,7 @@ class YoloEvalFeed(DataFeed):
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_size', 'im_id', 'gt_box',
'gt_label', 'is_difficult'],
image_shape=[None, 3, 608, 608],
image_shape=[3, 608, 608],
sample_transforms=[
DecodeImage(to_rgb=True),
ResizeImage(target_size=608, interp=2),
......@@ -1013,7 +1013,7 @@ class YoloTestFeed(DataFeed):
dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_size', 'im_id'],
image_shape=[None, 3, 608, 608],
image_shape=[3, 608, 608],
sample_transforms=[
DecodeImage(to_rgb=True),
ResizeImage(target_size=608, interp=2),
......
......@@ -40,7 +40,7 @@ feed_var_def = [
def create_feed(feed, iterable=False, sub_prog_feed=False):
image_shape = feed.image_shape
image_shape = [None] + feed.image_shape
feed_var_map = {var['name']: var for var in feed_var_def}
feed_var_map['image'] = {
'name': 'image',
......@@ -98,14 +98,14 @@ def create_feed(feed, iterable=False, sub_prog_feed=False):
'lod_level': 0
}
image_name_list.append(name)
feed_var_map['im_info']['shape'] = [feed.num_scale * 3]
feed_var_map['im_info']['shape'] = [None, feed.num_scale * 3]
feed.fields = image_name_list + feed.fields[1:]
if sub_prog_feed:
box_names = ['bbox', 'bbox_flip']
for box_name in box_names:
sub_prog_feed = {
'name': box_name,
'shape': [6],
'shape': [None, 6],
'dtype': 'float32',
'lod_level': 1
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册