未验证 提交 980227f9 编写于 作者: Z Zhang Zheng 提交者: GitHub

Support uint8_t for fill_constant_op (#31911)

上级 07741593
...@@ -149,6 +149,7 @@ REGISTER_OPERATOR( ...@@ -149,6 +149,7 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel<float>, REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>, ops::FillConstantKernel<double>,
ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<int64_t>,
ops::FillConstantKernel<int>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>, ops::FillConstantKernel<bool>,
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel<float>, REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>, ops::FillConstantKernel<double>,
ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<int64_t>,
ops::FillConstantKernel<int>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>, ops::FillConstantKernel<bool>,
......
...@@ -51,6 +51,7 @@ template struct SetConstant<platform::CPUDeviceContext, platform::complex128>; ...@@ -51,6 +51,7 @@ template struct SetConstant<platform::CPUDeviceContext, platform::complex128>;
template struct SetConstant<platform::XPUDeviceContext, platform::float16>; template struct SetConstant<platform::XPUDeviceContext, platform::float16>;
template struct SetConstant<platform::XPUDeviceContext, float>; template struct SetConstant<platform::XPUDeviceContext, float>;
template struct SetConstant<platform::XPUDeviceContext, double>; template struct SetConstant<platform::XPUDeviceContext, double>;
template struct SetConstant<platform::XPUDeviceContext, uint8_t>;
template struct SetConstant<platform::XPUDeviceContext, int>; template struct SetConstant<platform::XPUDeviceContext, int>;
template struct SetConstant<platform::XPUDeviceContext, int64_t>; template struct SetConstant<platform::XPUDeviceContext, int64_t>;
template struct SetConstant<platform::XPUDeviceContext, bool>; template struct SetConstant<platform::XPUDeviceContext, bool>;
......
...@@ -35,6 +35,7 @@ using complex128 = paddle::platform::complex128; ...@@ -35,6 +35,7 @@ using complex128 = paddle::platform::complex128;
template struct SetConstant<platform::CUDADeviceContext, platform::float16>; template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, float>; template struct SetConstant<platform::CUDADeviceContext, float>;
template struct SetConstant<platform::CUDADeviceContext, double>; template struct SetConstant<platform::CUDADeviceContext, double>;
template struct SetConstant<platform::CUDADeviceContext, uint8_t>;
template struct SetConstant<platform::CUDADeviceContext, int>; template struct SetConstant<platform::CUDADeviceContext, int>;
template struct SetConstant<platform::CUDADeviceContext, int64_t>; template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>; template struct SetConstant<platform::CUDADeviceContext, bool>;
......
...@@ -635,7 +635,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -635,7 +635,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1]. If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
If ``shape`` is an Tensor, it should be an 1-D Tensor with date type int32 or int64. If ``shape`` is an Tensor, it should be an 1-D Tensor with date type int32 or int64.
dtype(np.dtype|str): Data type of the output Tensor which can dtype(np.dtype|str): Data type of the output Tensor which can
be float16, float32, float64, int32, int64. be float16, float32, float64, uint8, int32, int64.
value(bool|float|int|Tensor): The constant value used to initialize value(bool|float|int|Tensor): The constant value used to initialize
the Tensor to be created. If ``value`` is an Tensor, it should be an 1-D Tensor. the Tensor to be created. If ``value`` is an Tensor, it should be an 1-D Tensor.
force_cpu(bool, optional): data should be on CPU if it's true, default value is False. force_cpu(bool, optional): data should be on CPU if it's true, default value is False.
...@@ -673,7 +673,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -673,7 +673,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
attrs = {'force_cpu': force_cpu} attrs = {'force_cpu': force_cpu}
dtype = convert_dtype(dtype) dtype = convert_dtype(dtype)
if not isinstance(value, Variable): if not isinstance(value, Variable):
if dtype in ['int64', 'int32']: if dtype in ['uint8', 'int64', 'int32']:
attrs['str_value'] = str(int(value)) attrs['str_value'] = str(int(value))
attrs['value'] = int(value) attrs['value'] = int(value)
else: else:
...@@ -686,7 +686,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -686,7 +686,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
out = _varbase_creator(dtype=dtype) out = _varbase_creator(dtype=dtype)
if isinstance(value, Variable): if isinstance(value, Variable):
if dtype in ['int64', 'int32']: if dtype in ['uint8', 'int64', 'int32']:
attrs['str_value'] = str(int(value.numpy().item(0))) attrs['str_value'] = str(int(value.numpy().item(0)))
else: else:
attrs['str_value'] = str(float(value.numpy().item(0))) attrs['str_value'] = str(float(value.numpy().item(0)))
...@@ -706,8 +706,9 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -706,8 +706,9 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
inputs['ValueTensor'] = value inputs['ValueTensor'] = value
check_shape(shape) check_shape(shape)
check_dtype(dtype, 'dtype', check_dtype(
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], dtype, 'dtype',
['bool', 'float16', 'float32', 'float64', 'uint8', 'int32', 'int64'],
'fill_constant') 'fill_constant')
check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant') check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant')
......
...@@ -375,15 +375,9 @@ class TestFillConstantOpError(unittest.TestCase): ...@@ -375,15 +375,9 @@ class TestFillConstantOpError(unittest.TestCase):
out=x1) out=x1)
# The argument dtype of fill_constant_op must be one of bool, float16, # The argument dtype of fill_constant_op must be one of bool, float16,
#float32, float64, int32 or int64 #float32, float64, uint8, int32 or int64
x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32") x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32")
self.assertRaises(
TypeError,
fluid.layers.fill_constant,
shape=[1],
value=5,
dtype='uint8')
self.assertRaises( self.assertRaises(
TypeError, TypeError,
fluid.layers.fill_constant, fluid.layers.fill_constant,
......
...@@ -84,10 +84,7 @@ class TestFullOpError(unittest.TestCase): ...@@ -84,10 +84,7 @@ class TestFullOpError(unittest.TestCase):
TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint4') TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint4')
# The argument dtype of full must be one of bool, float16, # The argument dtype of full must be one of bool, float16,
#float32, float64, int32 or int64 #float32, float64, uint8, int32 or int64
self.assertRaises(
TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint8')
# The argument shape's type of full_op must be list, tuple or Variable. # The argument shape's type of full_op must be list, tuple or Variable.
def test_shape_type(): def test_shape_type():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册