diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 3b6ded344462a959c589b021a2d1b3ac4242a969..a90551c1b7b4fd45ae9a0e1cfa225a87db811295 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -29,6 +29,7 @@ from ..data_feeder import check_variable_and_dtype, check_type, check_dtype, con from paddle.utils import deprecated import numpy import warnings +from .utils import check_shape __all__ = [ 'create_tensor', 'create_parameter', 'create_global_var', 'cast', @@ -652,12 +653,6 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): Returns: Tensor: Tensor which is created according to shape and dtype. - Raises: - TypeError: The dtype must be one of bool, float16, float32, float64, int32 and int64 - and the data type of ``out`` must be the same as the ``dtype``. - TypeError: The shape must be one of list, tuple and Tensor, the data type of ``shape`` - must be int32 or int64 when ``shape`` is a Tensor - Examples: .. code-block:: python @@ -713,14 +708,12 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): value = cast(value, dtype) inputs['ValueTensor'] = value + check_shape(shape) check_dtype(dtype, 'dtype', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], 'fill_constant') check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant') - if isinstance(shape, Variable): - check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'fill_constant') - if out is not None: check_variable_and_dtype(out, 'out', [convert_dtype(dtype)], 'fill_constant') @@ -1045,10 +1038,6 @@ def ones(shape, dtype, force_cpu=False): Returns: Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1. - Raises: - TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64. - TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must - be int32 or int64 when it's a Tensor. Examples: .. code-block:: python @@ -1081,10 +1070,6 @@ def zeros(shape, dtype, force_cpu=False, name=None): Returns: Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0. - Raises: - TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64. - TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must - be int32 or int64 when it's a Tensor. Examples: .. code-block:: python diff --git a/python/paddle/fluid/layers/utils.py b/python/paddle/fluid/layers/utils.py index f6b947dd5e817d36060e1e6ea73d888da823ce96..2095c9957e75b94396e573eba341f4cfded5dbc8 100644 --- a/python/paddle/fluid/layers/utils.py +++ b/python/paddle/fluid/layers/utils.py @@ -20,6 +20,7 @@ import numpy as np from ..framework import Variable from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype from ..layer_helper import LayerHelper +from sys import version_info def convert_to_list(value, n, name, dtype=np.int): @@ -358,3 +359,22 @@ def convert_shape_to_list(shape): else: shape = list(shape.numpy().astype(int)) return shape + + +def check_shape(shape): + """ + Check shape type and shape elements type before passing it to fill_constant + """ + if isinstance(shape, Variable): + check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'fill_constant') + else: + for ele in shape: + if not isinstance(ele, Variable): + if ele < 0: + raise ValueError( + "All elements in ``shape`` must be positive when it's a list or tuple" + ) + if not isinstance(ele, six.integer_types): + raise TypeError( + "All elements in ``shape`` must be integers when it's a list or tuple" + ) diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py index 3475320eeebc55a14dd569410610b70ae35e65a3..43069470680c7d49071ce54bf3649962c56f06ea 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py @@ -350,6 +350,14 @@ class TestFillConstantOpError(unittest.TestCase): dtype='int16', out=x1) + self.assertRaises( + TypeError, + fluid.layers.fill_constant, + shape=[1.1], + value=5, + dtype='float32', + out=x1) + # The argument dtype of fill_constant_op must be one of bool, float16, #float32, float64, int32 or int64 x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32")