未验证 提交 7c454715 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add support for None for fluid.data (#20228)

上级 e9205c38
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
import six
from . import core from . import core
from .layer_helper import LayerHelper from .layer_helper import LayerHelper
...@@ -86,6 +87,11 @@ def data(name, shape, dtype='float32', lod_level=0): ...@@ -86,6 +87,11 @@ def data(name, shape, dtype='float32', lod_level=0):
""" """
helper = LayerHelper('data', **locals()) helper = LayerHelper('data', **locals())
shape = list(shape)
for i in six.moves.range(len(shape)):
if shape[i] is None:
shape[i] = -1
return helper.create_global_variable( return helper.create_global_variable(
name=name, name=name,
shape=shape, shape=shape,
......
...@@ -83,7 +83,7 @@ class TestFeedData(unittest.TestCase): ...@@ -83,7 +83,7 @@ class TestFeedData(unittest.TestCase):
def _test_feed_data_shape_mismatch(self, use_cuda, use_parallel_executor): def _test_feed_data_shape_mismatch(self, use_cuda, use_parallel_executor):
batch_size = self._get_batch_size(use_cuda, use_parallel_executor) batch_size = self._get_batch_size(use_cuda, use_parallel_executor)
in_size = [-1, 3, 4, 8] in_size = [None, 3, 4, 8]
feed_in_data = np.random.uniform( feed_in_data = np.random.uniform(
size=[batch_size, 3, 4, 5]).astype(np.float32) size=[batch_size, 3, 4, 5]).astype(np.float32)
label_size = [-1, 1] label_size = [-1, 1]
...@@ -97,7 +97,7 @@ class TestFeedData(unittest.TestCase): ...@@ -97,7 +97,7 @@ class TestFeedData(unittest.TestCase):
in_size = [-1, 3, 4, 5] in_size = [-1, 3, 4, 5]
feed_in_data = np.random.uniform( feed_in_data = np.random.uniform(
size=[batch_size, 3, 4, 5]).astype(np.float32) size=[batch_size, 3, 4, 5]).astype(np.float32)
label_size = (-1, 1) label_size = (None, 1)
feed_label = np.random.randint( feed_label = np.random.randint(
low=0, high=self.class_num, size=[batch_size, 1]).astype(np.int64) low=0, high=self.class_num, size=[batch_size, 1]).astype(np.int64)
self._feed_data_in_executor(in_size, label_size, feed_in_data, self._feed_data_in_executor(in_size, label_size, feed_in_data,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册