未验证 提交 d8437062 编写于 作者: A AshburnLee 提交者: GitHub

add shape check for fill_constant OP and remove doc of type error (#26919)

上级 2660ea37
...@@ -29,6 +29,7 @@ from ..data_feeder import check_variable_and_dtype, check_type, check_dtype, con ...@@ -29,6 +29,7 @@ from ..data_feeder import check_variable_and_dtype, check_type, check_dtype, con
from paddle.utils import deprecated from paddle.utils import deprecated
import numpy import numpy
import warnings import warnings
from .utils import check_shape
__all__ = [ __all__ = [
'create_tensor', 'create_parameter', 'create_global_var', 'cast', 'create_tensor', 'create_parameter', 'create_global_var', 'cast',
...@@ -657,12 +658,6 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -657,12 +658,6 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
Returns: Returns:
Tensor: Tensor which is created according to shape and dtype. Tensor: Tensor which is created according to shape and dtype.
Raises:
TypeError: The dtype must be one of bool, float16, float32, float64, int32 and int64
and the data type of ``out`` must be the same as the ``dtype``.
TypeError: The shape must be one of list, tuple and Tensor, the data type of ``shape``
must be int32 or int64 when ``shape`` is a Tensor
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -718,14 +713,12 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -718,14 +713,12 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
value = cast(value, dtype) value = cast(value, dtype)
inputs['ValueTensor'] = value inputs['ValueTensor'] = value
check_shape(shape)
check_dtype(dtype, 'dtype', check_dtype(dtype, 'dtype',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'fill_constant') 'fill_constant')
check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant') check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant')
if isinstance(shape, Variable):
check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'fill_constant')
if out is not None: if out is not None:
check_variable_and_dtype(out, 'out', [convert_dtype(dtype)], check_variable_and_dtype(out, 'out', [convert_dtype(dtype)],
'fill_constant') 'fill_constant')
...@@ -1050,10 +1043,6 @@ def ones(shape, dtype, force_cpu=False): ...@@ -1050,10 +1043,6 @@ def ones(shape, dtype, force_cpu=False):
Returns: Returns:
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1. Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1.
Raises:
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64.
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
be int32 or int64 when it's a Tensor.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -1086,10 +1075,6 @@ def zeros(shape, dtype, force_cpu=False, name=None): ...@@ -1086,10 +1075,6 @@ def zeros(shape, dtype, force_cpu=False, name=None):
Returns: Returns:
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0. Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0.
Raises:
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64.
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
be int32 or int64 when it's a Tensor.
Examples: Examples:
.. code-block:: python .. code-block:: python
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
from ..framework import Variable from ..framework import Variable
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from sys import version_info
def convert_to_list(value, n, name, dtype=np.int): def convert_to_list(value, n, name, dtype=np.int):
...@@ -358,3 +359,22 @@ def convert_shape_to_list(shape): ...@@ -358,3 +359,22 @@ def convert_shape_to_list(shape):
else: else:
shape = list(shape.numpy().astype(int)) shape = list(shape.numpy().astype(int))
return shape return shape
def check_shape(shape):
"""
Check shape type and shape elements type before passing it to fill_constant
"""
if isinstance(shape, Variable):
check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'fill_constant')
else:
for ele in shape:
if not isinstance(ele, Variable):
if ele < 0:
raise ValueError(
"All elements in ``shape`` must be positive when it's a list or tuple"
)
if not isinstance(ele, six.integer_types):
raise TypeError(
"All elements in ``shape`` must be integers when it's a list or tuple"
)
...@@ -350,6 +350,14 @@ class TestFillConstantOpError(unittest.TestCase): ...@@ -350,6 +350,14 @@ class TestFillConstantOpError(unittest.TestCase):
dtype='int16', dtype='int16',
out=x1) out=x1)
self.assertRaises(
TypeError,
fluid.layers.fill_constant,
shape=[1.1],
value=5,
dtype='float32',
out=x1)
# The argument dtype of fill_constant_op must be one of bool, float16, # The argument dtype of fill_constant_op must be one of bool, float16,
#float32, float64, int32 or int64 #float32, float64, int32 or int64
x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32") x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册