From 7c45471579c10862b5c6d9fa143cca5058b6bbd1 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Wed, 9 Oct 2019 15:47:39 +0800 Subject: [PATCH] Add support for None for fluid.data (#20228) --- python/paddle/fluid/data.py | 6 ++++++ .../tests/unittests/test_feed_data_check_shape_type.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/data.py b/python/paddle/fluid/data.py index b90a681ed8e..008765c3f53 100644 --- a/python/paddle/fluid/data.py +++ b/python/paddle/fluid/data.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +import six from . import core from .layer_helper import LayerHelper @@ -86,6 +87,11 @@ def data(name, shape, dtype='float32', lod_level=0): """ helper = LayerHelper('data', **locals()) + shape = list(shape) + for i in six.moves.range(len(shape)): + if shape[i] is None: + shape[i] = -1 + return helper.create_global_variable( name=name, shape=shape, diff --git a/python/paddle/fluid/tests/unittests/test_feed_data_check_shape_type.py b/python/paddle/fluid/tests/unittests/test_feed_data_check_shape_type.py index 2489ae8e266..9f70c63c69d 100644 --- a/python/paddle/fluid/tests/unittests/test_feed_data_check_shape_type.py +++ b/python/paddle/fluid/tests/unittests/test_feed_data_check_shape_type.py @@ -83,7 +83,7 @@ class TestFeedData(unittest.TestCase): def _test_feed_data_shape_mismatch(self, use_cuda, use_parallel_executor): batch_size = self._get_batch_size(use_cuda, use_parallel_executor) - in_size = [-1, 3, 4, 8] + in_size = [None, 3, 4, 8] feed_in_data = np.random.uniform( size=[batch_size, 3, 4, 5]).astype(np.float32) label_size = [-1, 1] @@ -97,7 +97,7 @@ class TestFeedData(unittest.TestCase): in_size = [-1, 3, 4, 5] feed_in_data = np.random.uniform( size=[batch_size, 3, 4, 5]).astype(np.float32) - label_size = (-1, 1) + label_size = (None, 1) feed_label = np.random.randint( low=0, high=self.class_num, size=[batch_size, 1]).astype(np.int64) self._feed_data_in_executor(in_size, label_size, feed_in_data, -- GitLab