From 980227f9740b4f656e43bc99e0cc84a13185d5c1 Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Thu, 1 Apr 2021 11:24:33 +0800 Subject: [PATCH] Support uint8_t for fill_constant_op (#31911) --- paddle/fluid/operators/fill_constant_op.cc | 1 + paddle/fluid/operators/fill_constant_op.cu.cc | 1 + paddle/fluid/operators/math/math_function.cc | 1 + paddle/fluid/operators/math/math_function.cu | 1 + python/paddle/fluid/layers/tensor.py | 13 +++++++------ .../fluid/tests/unittests/test_fill_constant_op.py | 8 +------- python/paddle/fluid/tests/unittests/test_full_op.py | 5 +---- 7 files changed, 13 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 8a96d057cb..caa2930990 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -149,6 +149,7 @@ REGISTER_OPERATOR( REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, diff --git a/paddle/fluid/operators/fill_constant_op.cu.cc b/paddle/fluid/operators/fill_constant_op.cu.cc index 78c62a4053..e784c20b8b 100644 --- a/paddle/fluid/operators/fill_constant_op.cu.cc +++ b/paddle/fluid/operators/fill_constant_op.cu.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel, ops::FillConstantKernel, + ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, ops::FillConstantKernel, diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index a61b50faa7..5242d03c11 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -51,6 +51,7 @@ template struct SetConstant; template struct SetConstant; template struct SetConstant; template struct SetConstant; +template struct SetConstant; template struct SetConstant; template struct SetConstant; template struct SetConstant; diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index cc8925fcf8..2b93cd9260 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -35,6 +35,7 @@ using complex128 = paddle::platform::complex128; template struct SetConstant; template struct SetConstant; template struct SetConstant; +template struct SetConstant; template struct SetConstant; template struct SetConstant; template struct SetConstant; diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 84f99962e8..7458466b02 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -635,7 +635,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1]. If ``shape`` is an Tensor, it should be an 1-D Tensor with date type int32 or int64. dtype(np.dtype|str): Data type of the output Tensor which can - be float16, float32, float64, int32, int64. + be float16, float32, float64, uint8, int32, int64. value(bool|float|int|Tensor): The constant value used to initialize the Tensor to be created. If ``value`` is an Tensor, it should be an 1-D Tensor. force_cpu(bool, optional): data should be on CPU if it's true, default value is False. @@ -673,7 +673,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): attrs = {'force_cpu': force_cpu} dtype = convert_dtype(dtype) if not isinstance(value, Variable): - if dtype in ['int64', 'int32']: + if dtype in ['uint8', 'int64', 'int32']: attrs['str_value'] = str(int(value)) attrs['value'] = int(value) else: @@ -686,7 +686,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): out = _varbase_creator(dtype=dtype) if isinstance(value, Variable): - if dtype in ['int64', 'int32']: + if dtype in ['uint8', 'int64', 'int32']: attrs['str_value'] = str(int(value.numpy().item(0))) else: attrs['str_value'] = str(float(value.numpy().item(0))) @@ -706,9 +706,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): inputs['ValueTensor'] = value check_shape(shape) - check_dtype(dtype, 'dtype', - ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], - 'fill_constant') + check_dtype( + dtype, 'dtype', + ['bool', 'float16', 'float32', 'float64', 'uint8', 'int32', 'int64'], + 'fill_constant') check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant') if out is not None: diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py index c305f71aa5..0dd78ea53c 100644 --- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py +++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py @@ -375,15 +375,9 @@ class TestFillConstantOpError(unittest.TestCase): out=x1) # The argument dtype of fill_constant_op must be one of bool, float16, - #float32, float64, int32 or int64 + #float32, float64, uint8, int32 or int64 x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32") - self.assertRaises( - TypeError, - fluid.layers.fill_constant, - shape=[1], - value=5, - dtype='uint8') self.assertRaises( TypeError, fluid.layers.fill_constant, diff --git a/python/paddle/fluid/tests/unittests/test_full_op.py b/python/paddle/fluid/tests/unittests/test_full_op.py index 2d850db783..19944aba46 100644 --- a/python/paddle/fluid/tests/unittests/test_full_op.py +++ b/python/paddle/fluid/tests/unittests/test_full_op.py @@ -84,10 +84,7 @@ class TestFullOpError(unittest.TestCase): TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint4') # The argument dtype of full must be one of bool, float16, - #float32, float64, int32 or int64 - - self.assertRaises( - TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint8') + #float32, float64, uint8, int32 or int64 # The argument shape's type of full_op must be list, tuple or Variable. def test_shape_type(): -- GitLab