未验证 提交 02700af9 编写于 作者: W wangguanzhong 提交者: GitHub

cherry-pick fix data_feed (#68)

上级 3923340d
...@@ -453,7 +453,7 @@ class FasterRCNNTrainFeed(DataFeed): ...@@ -453,7 +453,7 @@ class FasterRCNNTrainFeed(DataFeed):
'image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd' 'is_crowd'
], ],
image_shape=[3, 800, 1333], image_shape=[None, 3, None, None],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5), RandomFlipImage(prob=0.5),
...@@ -505,7 +505,7 @@ class FasterRCNNEvalFeed(DataFeed): ...@@ -505,7 +505,7 @@ class FasterRCNNEvalFeed(DataFeed):
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape', 'gt_box', fields=['image', 'im_info', 'im_id', 'im_shape', 'gt_box',
'gt_label', 'is_difficult'], 'gt_label', 'is_difficult'],
image_shape=[3, 800, 1333], image_shape=[None, 3, None, None],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406], NormalizeImage(mean=[0.485, 0.456, 0.406],
...@@ -552,7 +552,7 @@ class FasterRCNNTestFeed(DataFeed): ...@@ -552,7 +552,7 @@ class FasterRCNNTestFeed(DataFeed):
dataset=SimpleDataSet(COCO_VAL_ANNOTATION, dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'], fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 800, 1333], image_shape=[None, 3, None, None],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406], NormalizeImage(mean=[0.485, 0.456, 0.406],
...@@ -600,7 +600,7 @@ class MaskRCNNTrainFeed(DataFeed): ...@@ -600,7 +600,7 @@ class MaskRCNNTrainFeed(DataFeed):
'image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'image', 'im_info', 'im_id', 'gt_box', 'gt_label',
'is_crowd', 'gt_mask' 'is_crowd', 'gt_mask'
], ],
image_shape=[3, 800, 1333], image_shape=[None, 3, None, None],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
RandomFlipImage(prob=0.5, is_mask_flip=True), RandomFlipImage(prob=0.5, is_mask_flip=True),
...@@ -646,7 +646,7 @@ class MaskRCNNEvalFeed(DataFeed): ...@@ -646,7 +646,7 @@ class MaskRCNNEvalFeed(DataFeed):
dataset=CocoDataSet(COCO_VAL_ANNOTATION, dataset=CocoDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'], fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 800, 1333], image_shape=[None, 3, None, None],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
NormalizeImage(mean=[0.485, 0.456, 0.406], NormalizeImage(mean=[0.485, 0.456, 0.406],
...@@ -698,7 +698,7 @@ class MaskRCNNTestFeed(DataFeed): ...@@ -698,7 +698,7 @@ class MaskRCNNTestFeed(DataFeed):
dataset=SimpleDataSet(COCO_VAL_ANNOTATION, dataset=SimpleDataSet(COCO_VAL_ANNOTATION,
COCO_VAL_IMAGE_DIR).__dict__, COCO_VAL_IMAGE_DIR).__dict__,
fields=['image', 'im_info', 'im_id', 'im_shape'], fields=['image', 'im_info', 'im_id', 'im_shape'],
image_shape=[3, 800, 1333], image_shape=[None, 3, None, None],
sample_transforms=[ sample_transforms=[
DecodeImage(to_rgb=True), DecodeImage(to_rgb=True),
NormalizeImage( NormalizeImage(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册