未验证 提交 9229209b 编写于 作者: W wangguanzhong 提交者: GitHub

fix config for fluid.data (#37)

上级 3af2e211
...@@ -453,7 +453,7 @@ class FasterRCNNTrainFeed(DataFeed): ...@@ -453,7 +453,7 @@ class FasterRCNNTrainFeed(DataFeed):
'image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd' 'is_crowd'
], ],
image_shape=[None, 3, None, None], image_shape=[3, None, None],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5), RandomFlipImage(prob=0.5),
...@@ -505,7 +505,7 @@ class FasterRCNNEvalFeed(DataFeed): ...@@ -505,7 +505,7 @@ class FasterRCNNEvalFeed(DataFeed):
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape', 'gt_box', fields=['image', 'im_info', 'im_id', 'im_shape', 'gt_box',
'gt_label', 'is_difficult'], 'gt_label', 'is_difficult'],
image_shape=[None, 3, None, None], image_shape=[3, None, None],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406], NormalizeImage(mean=[0.485, 0.456, 0.406],
...@@ -552,7 +552,7 @@ class FasterRCNNTestFeed(DataFeed): ...@@ -552,7 +552,7 @@ class FasterRCNNTestFeed(DataFeed):
dataset=SimpleDataSet(COCO_VAL_ANNOTATION, dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'], fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[None, 3, None, None], image_shape=[3, None, None],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406], NormalizeImage(mean=[0.485, 0.456, 0.406],
...@@ -600,7 +600,7 @@ class MaskRCNNTrainFeed(DataFeed): ...@@ -600,7 +600,7 @@ class MaskRCNNTrainFeed(DataFeed):
'image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd', 'gt_mask' 'is_crowd', 'gt_mask'
], ],
image_shape=[None, 3, None, None], image_shape=[3, None, None],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5, is_mask_flip=True), RandomFlipImage(prob=0.5, is_mask_flip=True),
...@@ -646,7 +646,7 @@ class MaskRCNNEvalFeed(DataFeed): ...@@ -646,7 +646,7 @@ class MaskRCNNEvalFeed(DataFeed):
dataset=CocoDataSet(COCO_VAL_ANNOTATION, dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'], fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[None, 3, None, None], image_shape=[3, None, None],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406], NormalizeImage(mean=[0.485, 0.456, 0.406],
...@@ -698,7 +698,7 @@ class MaskRCNNTestFeed(DataFeed): ...@@ -698,7 +698,7 @@ class MaskRCNNTestFeed(DataFeed):
dataset=SimpleDataSet(COCO_VAL_ANNOTATION, dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'], fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[None, 3, None, None], image_shape=[3, None, None],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
NormalizeImage( NormalizeImage(
...@@ -743,7 +743,7 @@ class SSDTrainFeed(DataFeed): ...@@ -743,7 +743,7 @@ class SSDTrainFeed(DataFeed):
def __init__(self, def __init__(self,
dataset=VocDataSet().__dict__, dataset=VocDataSet().__dict__,
fields=['image', 'gt_box', 'gt_label'], fields=['image', 'gt_box', 'gt_label'],
image_shape=[None, 3, 300, 300], image_shape=[3, 300, 300],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True, with_mixup=False), DecodeImage(to_rgb=True, with_mixup=False),
NormalizeBox(), NormalizeBox(),
...@@ -802,7 +802,7 @@ class SSDEvalFeed(DataFeed): ...@@ -802,7 +802,7 @@ class SSDEvalFeed(DataFeed):
dataset=VocDataSet(VOC_VAL_ANNOTATION).__dict__, dataset=VocDataSet(VOC_VAL_ANNOTATION).__dict__,
fields=['image', 'im_shape', 'im_id', 'gt_box', fields=['image', 'im_shape', 'im_id', 'gt_box',
'gt_label', 'is_difficult'], 'gt_label', 'is_difficult'],
image_shape=[None, 3, 300, 300], image_shape=[3, 300, 300],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True, with_mixup=False), DecodeImage(to_rgb=True, with_mixup=False),
NormalizeBox(), NormalizeBox(),
...@@ -847,7 +847,7 @@ class SSDTestFeed(DataFeed): ...@@ -847,7 +847,7 @@ class SSDTestFeed(DataFeed):
def __init__(self, def __init__(self,
dataset=SimpleDataSet(VOC_VAL_ANNOTATION).__dict__, dataset=SimpleDataSet(VOC_VAL_ANNOTATION).__dict__,
fields=['image', 'im_id', 'im_shape'], fields=['image', 'im_id', 'im_shape'],
image_shape=[None, 3, 300, 300], image_shape=[3, 300, 300],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
ResizeImage(target_size=300, use_cv2=False, interp=1), ResizeImage(target_size=300, use_cv2=False, interp=1),
...@@ -893,7 +893,7 @@ class YoloTrainFeed(DataFeed): ...@@ -893,7 +893,7 @@ class YoloTrainFeed(DataFeed):
def __init__(self, def __init__(self,
dataset=CocoDataSet().__dict__, dataset=CocoDataSet().__dict__,
fields=['image', 'gt_box', 'gt_label', 'gt_score'], fields=['image', 'gt_box', 'gt_label', 'gt_score'],
image_shape=[None, 3, 608, 608], image_shape=[3, 608, 608],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True, with_mixup=True), DecodeImage(to_rgb=True, with_mixup=True),
MixupImage(alpha=1.5, beta=1.5), MixupImage(alpha=1.5, beta=1.5),
...@@ -955,7 +955,7 @@ class YoloEvalFeed(DataFeed): ...@@ -955,7 +955,7 @@ class YoloEvalFeed(DataFeed):
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_size', 'im_id', 'gt_box', fields=['image', 'im_size', 'im_id', 'gt_box',
'gt_label', 'is_difficult'], 'gt_label', 'is_difficult'],
image_shape=[None, 3, 608, 608], image_shape=[3, 608, 608],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
ResizeImage(target_size=608, interp=2), ResizeImage(target_size=608, interp=2),
...@@ -1013,7 +1013,7 @@ class YoloTestFeed(DataFeed): ...@@ -1013,7 +1013,7 @@ class YoloTestFeed(DataFeed):
dataset=SimpleDataSet(COCO_VAL_ANNOTATION, dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_size', 'im_id'], fields=['image', 'im_size', 'im_id'],
image_shape=[None, 3, 608, 608], image_shape=[3, 608, 608],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
ResizeImage(target_size=608, interp=2), ResizeImage(target_size=608, interp=2),
......
...@@ -40,7 +40,7 @@ feed_var_def = [ ...@@ -40,7 +40,7 @@ feed_var_def = [
def create_feed(feed, iterable=False, sub_prog_feed=False): def create_feed(feed, iterable=False, sub_prog_feed=False):
image_shape = feed.image_shape image_shape = [None] + feed.image_shape
feed_var_map = {var['name']: var for var in feed_var_def} feed_var_map = {var['name']: var for var in feed_var_def}
feed_var_map['image'] = { feed_var_map['image'] = {
'name': 'image', 'name': 'image',
...@@ -98,14 +98,14 @@ def create_feed(feed, iterable=False, sub_prog_feed=False): ...@@ -98,14 +98,14 @@ def create_feed(feed, iterable=False, sub_prog_feed=False):
'lod_level': 0 'lod_level': 0
} }
image_name_list.append(name) image_name_list.append(name)
feed_var_map['im_info']['shape'] = [feed.num_scale * 3] feed_var_map['im_info']['shape'] = [None, feed.num_scale * 3]
feed.fields = image_name_list + feed.fields[1:] feed.fields = image_name_list + feed.fields[1:]
if sub_prog_feed: if sub_prog_feed:
box_names = ['bbox', 'bbox_flip'] box_names = ['bbox', 'bbox_flip']
for box_name in box_names: for box_name in box_names:
sub_prog_feed = { sub_prog_feed = {
'name': box_name, 'name': box_name,
'shape': [6], 'shape': [None, 6],
'dtype': 'float32', 'dtype': 'float32',
'lod_level': 1 'lod_level': 1
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册