diff --git a/ppdet/data/data_feed.py b/ppdet/data/data_feed.py index 110fd68524263a2bd80f0d02b47d1f9ea5e1903c..26bd297a00926ad6e0ff852b3f1fb0de5c6fec3a 100644 --- a/ppdet/data/data_feed.py +++ b/ppdet/data/data_feed.py @@ -453,7 +453,7 @@ class FasterRCNNTrainFeed(DataFeed): 'image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'is_crowd' ], - image_shape=[None, 3, None, None], + image_shape=[3, None, None], sample_transforms=[ DecodeImage(to_rgb=True), RandomFlipImage(prob=0.5), @@ -505,7 +505,7 @@ class FasterRCNNEvalFeed(DataFeed): COCO_VAL_IMAGE_DIR).__dict__, fields=['image', 'im_info', 'im_id', 'im_shape', 'gt_box', 'gt_label', 'is_difficult'], - image_shape=[None, 3, None, None], + image_shape=[3, None, None], sample_transforms=[ DecodeImage(to_rgb=True), NormalizeImage(mean=[0.485, 0.456, 0.406], @@ -552,7 +552,7 @@ class FasterRCNNTestFeed(DataFeed): dataset=SimpleDataSet(COCO_VAL_ANNOTATION, COCO_VAL_IMAGE_DIR).__dict__, fields=['image', 'im_info', 'im_id', 'im_shape'], - image_shape=[None, 3, None, None], + image_shape=[3, None, None], sample_transforms=[ DecodeImage(to_rgb=True), NormalizeImage(mean=[0.485, 0.456, 0.406], @@ -600,7 +600,7 @@ class MaskRCNNTrainFeed(DataFeed): 'image', 'im_info', 'im_id', 'gt_box', 'gt_label', 'is_crowd', 'gt_mask' ], - image_shape=[None, 3, None, None], + image_shape=[3, None, None], sample_transforms=[ DecodeImage(to_rgb=True), RandomFlipImage(prob=0.5, is_mask_flip=True), @@ -646,7 +646,7 @@ class MaskRCNNEvalFeed(DataFeed): dataset=CocoDataSet(COCO_VAL_ANNOTATION, COCO_VAL_IMAGE_DIR).__dict__, fields=['image', 'im_info', 'im_id', 'im_shape'], - image_shape=[None, 3, None, None], + image_shape=[3, None, None], sample_transforms=[ DecodeImage(to_rgb=True), NormalizeImage(mean=[0.485, 0.456, 0.406], @@ -698,7 +698,7 @@ class MaskRCNNTestFeed(DataFeed): dataset=SimpleDataSet(COCO_VAL_ANNOTATION, COCO_VAL_IMAGE_DIR).__dict__, fields=['image', 'im_info', 'im_id', 'im_shape'], - image_shape=[None, 3, None, None], + image_shape=[3, None, None], sample_transforms=[ DecodeImage(to_rgb=True), NormalizeImage( @@ -743,7 +743,7 @@ class SSDTrainFeed(DataFeed): def __init__(self, dataset=VocDataSet().__dict__, fields=['image', 'gt_box', 'gt_label'], - image_shape=[None, 3, 300, 300], + image_shape=[3, 300, 300], sample_transforms=[ DecodeImage(to_rgb=True, with_mixup=False), NormalizeBox(), @@ -802,7 +802,7 @@ class SSDEvalFeed(DataFeed): dataset=VocDataSet(VOC_VAL_ANNOTATION).__dict__, fields=['image', 'im_shape', 'im_id', 'gt_box', 'gt_label', 'is_difficult'], - image_shape=[None, 3, 300, 300], + image_shape=[3, 300, 300], sample_transforms=[ DecodeImage(to_rgb=True, with_mixup=False), NormalizeBox(), @@ -847,7 +847,7 @@ class SSDTestFeed(DataFeed): def __init__(self, dataset=SimpleDataSet(VOC_VAL_ANNOTATION).__dict__, fields=['image', 'im_id', 'im_shape'], - image_shape=[None, 3, 300, 300], + image_shape=[3, 300, 300], sample_transforms=[ DecodeImage(to_rgb=True), ResizeImage(target_size=300, use_cv2=False, interp=1), @@ -893,7 +893,7 @@ class YoloTrainFeed(DataFeed): def __init__(self, dataset=CocoDataSet().__dict__, fields=['image', 'gt_box', 'gt_label', 'gt_score'], - image_shape=[None, 3, 608, 608], + image_shape=[3, 608, 608], sample_transforms=[ DecodeImage(to_rgb=True, with_mixup=True), MixupImage(alpha=1.5, beta=1.5), @@ -955,7 +955,7 @@ class YoloEvalFeed(DataFeed): COCO_VAL_IMAGE_DIR).__dict__, fields=['image', 'im_size', 'im_id', 'gt_box', 'gt_label', 'is_difficult'], - image_shape=[None, 3, 608, 608], + image_shape=[3, 608, 608], sample_transforms=[ DecodeImage(to_rgb=True), ResizeImage(target_size=608, interp=2), @@ -1013,7 +1013,7 @@ class YoloTestFeed(DataFeed): dataset=SimpleDataSet(COCO_VAL_ANNOTATION, COCO_VAL_IMAGE_DIR).__dict__, fields=['image', 'im_size', 'im_id'], - image_shape=[None, 3, 608, 608], + image_shape=[3, 608, 608], sample_transforms=[ DecodeImage(to_rgb=True), ResizeImage(target_size=608, interp=2), diff --git a/ppdet/modeling/model_input.py b/ppdet/modeling/model_input.py index 376438963a1d53cf0fe2126592af0421cd82a508..1ded8b20e785097a26be93d782ac7aa2cd64c8a6 100644 --- a/ppdet/modeling/model_input.py +++ b/ppdet/modeling/model_input.py @@ -40,7 +40,7 @@ feed_var_def = [ def create_feed(feed, iterable=False, sub_prog_feed=False): - image_shape = feed.image_shape + image_shape = [None] + feed.image_shape feed_var_map = {var['name']: var for var in feed_var_def} feed_var_map['image'] = { 'name': 'image', @@ -98,14 +98,14 @@ def create_feed(feed, iterable=False, sub_prog_feed=False): 'lod_level': 0 } image_name_list.append(name) - feed_var_map['im_info']['shape'] = [feed.num_scale * 3] + feed_var_map['im_info']['shape'] = [None, feed.num_scale * 3] feed.fields = image_name_list + feed.fields[1:] if sub_prog_feed: box_names = ['bbox', 'bbox_flip'] for box_name in box_names: sub_prog_feed = { 'name': box_name, - 'shape': [6], + 'shape': [None, 6], 'dtype': 'float32', 'lod_level': 1 }