未验证 提交 e8edfc7d 编写于 作者: A AshburnLee 提交者: GitHub

add shape check for fill_constant OP and remove doc of type error (#26919) (#27080)

上级 66144efc
......@@ -29,6 +29,7 @@ from ..data_feeder import check_variable_and_dtype, check_type, check_dtype, con
from paddle.utils import deprecated
import numpy
import warnings
from .utils import check_shape
__all__ = [
'create_tensor', 'create_parameter', 'create_global_var', 'cast',
......@@ -652,12 +653,6 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
Returns:
Tensor: Tensor which is created according to shape and dtype.
Raises:
TypeError: The dtype must be one of bool, float16, float32, float64, int32 and int64
and the data type of ``out`` must be the same as the ``dtype``.
TypeError: The shape must be one of list, tuple and Tensor, the data type of ``shape``
must be int32 or int64 when ``shape`` is a Tensor
Examples:
.. code-block:: python
......@@ -713,14 +708,12 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
value = cast(value, dtype)
inputs['ValueTensor'] = value
check_shape(shape)
check_dtype(dtype, 'dtype',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'fill_constant')
check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant')
if isinstance(shape, Variable):
check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'fill_constant')
if out is not None:
check_variable_and_dtype(out, 'out', [convert_dtype(dtype)],
'fill_constant')
......@@ -1045,10 +1038,6 @@ def ones(shape, dtype, force_cpu=False):
Returns:
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1.
Raises:
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64.
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
be int32 or int64 when it's a Tensor.
Examples:
.. code-block:: python
......@@ -1081,10 +1070,6 @@ def zeros(shape, dtype, force_cpu=False, name=None):
Returns:
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0.
Raises:
TypeError: The ``dtype`` must be one of bool, float16, float32, float64, int32, int64.
TypeError: The ``shape`` must be one of list, tuple and Tensor. The data type of ``shape`` must
be int32 or int64 when it's a Tensor.
Examples:
.. code-block:: python
......
......@@ -20,6 +20,7 @@ import numpy as np
from ..framework import Variable
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ..layer_helper import LayerHelper
from sys import version_info
def convert_to_list(value, n, name, dtype=np.int):
......@@ -358,3 +359,22 @@ def convert_shape_to_list(shape):
else:
shape = list(shape.numpy().astype(int))
return shape
def check_shape(shape):
"""
Check shape type and shape elements type before passing it to fill_constant
"""
if isinstance(shape, Variable):
check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'fill_constant')
else:
for ele in shape:
if not isinstance(ele, Variable):
if ele < 0:
raise ValueError(
"All elements in ``shape`` must be positive when it's a list or tuple"
)
if not isinstance(ele, six.integer_types):
raise TypeError(
"All elements in ``shape`` must be integers when it's a list or tuple"
)
......@@ -350,6 +350,14 @@ class TestFillConstantOpError(unittest.TestCase):
dtype='int16',
out=x1)
self.assertRaises(
TypeError,
fluid.layers.fill_constant,
shape=[1.1],
value=5,
dtype='float32',
out=x1)
# The argument dtype of fill_constant_op must be one of bool, float16,
#float32, float64, int32 or int64
x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册