未验证 提交 4b6f8099 编写于 作者: Z Zhang Zheng 提交者: GitHub

Support int16_t in fill_constant_op (#35619)

上级 75d5e3bf
...@@ -149,8 +149,8 @@ REGISTER_OPERATOR( ...@@ -149,8 +149,8 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fill_constant, ops::FillConstantKernel<float>, fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>, ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<int>, ops::FillConstantKernel<int16_t>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>, ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>, ops::FillConstantKernel<paddle::platform::float16>,
ops::FillConstantKernel<paddle::platform::bfloat16>, ops::FillConstantKernel<paddle::platform::bfloat16>,
ops::FillConstantKernel<paddle::platform::complex<float>>, ops::FillConstantKernel<paddle::platform::complex<float>>,
......
...@@ -18,8 +18,8 @@ namespace ops = paddle::operators; ...@@ -18,8 +18,8 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
fill_constant, ops::FillConstantKernel<float>, fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>, ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<int>, ops::FillConstantKernel<int16_t>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>, ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>, ops::FillConstantKernel<paddle::platform::float16>,
ops::FillConstantKernel<paddle::platform::complex<float>>, ops::FillConstantKernel<paddle::platform::complex<float>>,
ops::FillConstantKernel<paddle::platform::complex<double>>); ops::FillConstantKernel<paddle::platform::complex<double>>);
...@@ -41,6 +41,7 @@ template struct SetConstant<platform::CPUDeviceContext, platform::float16>; ...@@ -41,6 +41,7 @@ template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
template struct SetConstant<platform::CPUDeviceContext, platform::bfloat16>; template struct SetConstant<platform::CPUDeviceContext, platform::bfloat16>;
template struct SetConstant<platform::CPUDeviceContext, float>; template struct SetConstant<platform::CPUDeviceContext, float>;
template struct SetConstant<platform::CPUDeviceContext, double>; template struct SetConstant<platform::CPUDeviceContext, double>;
template struct SetConstant<platform::CPUDeviceContext, int16_t>;
template struct SetConstant<platform::CPUDeviceContext, int>; template struct SetConstant<platform::CPUDeviceContext, int>;
template struct SetConstant<platform::CPUDeviceContext, int64_t>; template struct SetConstant<platform::CPUDeviceContext, int64_t>;
template struct SetConstant<platform::CPUDeviceContext, bool>; template struct SetConstant<platform::CPUDeviceContext, bool>;
...@@ -56,6 +57,7 @@ template struct SetConstant<platform::XPUDeviceContext, platform::bfloat16>; ...@@ -56,6 +57,7 @@ template struct SetConstant<platform::XPUDeviceContext, platform::bfloat16>;
template struct SetConstant<platform::XPUDeviceContext, float>; template struct SetConstant<platform::XPUDeviceContext, float>;
template struct SetConstant<platform::XPUDeviceContext, double>; template struct SetConstant<platform::XPUDeviceContext, double>;
template struct SetConstant<platform::XPUDeviceContext, uint8_t>; template struct SetConstant<platform::XPUDeviceContext, uint8_t>;
template struct SetConstant<platform::XPUDeviceContext, int16_t>;
template struct SetConstant<platform::XPUDeviceContext, int>; template struct SetConstant<platform::XPUDeviceContext, int>;
template struct SetConstant<platform::XPUDeviceContext, int64_t>; template struct SetConstant<platform::XPUDeviceContext, int64_t>;
template struct SetConstant<platform::XPUDeviceContext, bool>; template struct SetConstant<platform::XPUDeviceContext, bool>;
......
...@@ -35,6 +35,7 @@ template struct SetConstant<platform::CUDADeviceContext, float>; ...@@ -35,6 +35,7 @@ template struct SetConstant<platform::CUDADeviceContext, float>;
template struct SetConstant<platform::CUDADeviceContext, double>; template struct SetConstant<platform::CUDADeviceContext, double>;
template struct SetConstant<platform::CUDADeviceContext, uint8_t>; template struct SetConstant<platform::CUDADeviceContext, uint8_t>;
template struct SetConstant<platform::CUDADeviceContext, int>; template struct SetConstant<platform::CUDADeviceContext, int>;
template struct SetConstant<platform::CUDADeviceContext, int16_t>;
template struct SetConstant<platform::CUDADeviceContext, int64_t>; template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>; template struct SetConstant<platform::CUDADeviceContext, bool>;
template struct SetConstant<platform::CUDADeviceContext, template struct SetConstant<platform::CUDADeviceContext,
......
...@@ -674,7 +674,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -674,7 +674,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1]. If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
If ``shape`` is an Tensor, it should be an 1-D Tensor with date type int32 or int64. If ``shape`` is an Tensor, it should be an 1-D Tensor with date type int32 or int64.
dtype(np.dtype|str): Data type of the output Tensor which can dtype(np.dtype|str): Data type of the output Tensor which can
be float16, float32, float64, uint8, int32, int64. be float16, float32, float64, uint8, int16, int32, int64.
value(bool|float|int|Tensor): The constant value used to initialize value(bool|float|int|Tensor): The constant value used to initialize
the Tensor to be created. If ``value`` is an Tensor, it should be an 1-D Tensor. the Tensor to be created. If ``value`` is an Tensor, it should be an 1-D Tensor.
force_cpu(bool, optional): data should be on CPU if it's true, default value is False. force_cpu(bool, optional): data should be on CPU if it's true, default value is False.
...@@ -712,7 +712,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -712,7 +712,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
attrs = {'force_cpu': force_cpu} attrs = {'force_cpu': force_cpu}
dtype = convert_dtype(dtype) dtype = convert_dtype(dtype)
if not isinstance(value, Variable): if not isinstance(value, Variable):
if dtype in ['uint8', 'int64', 'int32']: if dtype in ['uint8', 'int16', 'int32', 'int64']:
attrs['str_value'] = str(int(value)) attrs['str_value'] = str(int(value))
attrs['value'] = int(value) attrs['value'] = int(value)
else: else:
...@@ -725,7 +725,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -725,7 +725,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
out = _varbase_creator(dtype=dtype) out = _varbase_creator(dtype=dtype)
if isinstance(value, Variable): if isinstance(value, Variable):
if dtype in ['uint8', 'int64', 'int32']: if dtype in ['uint8', 'int16', 'int32', 'int64']:
attrs['str_value'] = str(int(value.numpy().item(0))) attrs['str_value'] = str(int(value.numpy().item(0)))
else: else:
attrs['str_value'] = str(float(value.numpy().item(0))) attrs['str_value'] = str(float(value.numpy().item(0)))
...@@ -745,10 +745,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -745,10 +745,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
inputs['ValueTensor'] = value inputs['ValueTensor'] = value
check_shape(shape) check_shape(shape)
check_dtype( check_dtype(dtype, 'dtype', [
dtype, 'dtype', 'bool', 'float16', 'float32', 'float64', 'uint8', 'int16', 'int32',
['bool', 'float16', 'float32', 'float64', 'uint8', 'int32', 'int64'], 'int64'
'fill_constant') ], 'fill_constant')
check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant') check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant')
if out is not None: if out is not None:
......
...@@ -358,13 +358,6 @@ class TestFillConstantOpError(unittest.TestCase): ...@@ -358,13 +358,6 @@ class TestFillConstantOpError(unittest.TestCase):
shape=[1], shape=[1],
value=5, value=5,
dtype='uint4') dtype='uint4')
self.assertRaises(
TypeError,
fluid.layers.fill_constant,
shape=[1],
value=5,
dtype='int16',
out=x1)
self.assertRaises( self.assertRaises(
TypeError, TypeError,
...@@ -375,7 +368,7 @@ class TestFillConstantOpError(unittest.TestCase): ...@@ -375,7 +368,7 @@ class TestFillConstantOpError(unittest.TestCase):
out=x1) out=x1)
# The argument dtype of fill_constant_op must be one of bool, float16, # The argument dtype of fill_constant_op must be one of bool, float16,
#float32, float64, uint8, int32 or int64 #float32, float64, uint8, int16, int32 or int64
x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32") x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32")
self.assertRaises( self.assertRaises(
......
...@@ -84,7 +84,7 @@ class TestFullOpError(unittest.TestCase): ...@@ -84,7 +84,7 @@ class TestFullOpError(unittest.TestCase):
TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint4') TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint4')
# The argument dtype of full must be one of bool, float16, # The argument dtype of full must be one of bool, float16,
#float32, float64, uint8, int32 or int64 #float32, float64, uint8, int16, int32 or int64
# The argument shape's type of full_op must be list, tuple or Variable. # The argument shape's type of full_op must be list, tuple or Variable.
def test_shape_type(): def test_shape_type():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册