Unverified commit 4b6f8099 authored by Zhang Zheng, committed by GitHub

Support int16_t in fill_constant_op (#35619)

Parent 75d5e3bf
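This commit registers int16 fill_constant kernels on CPU and CUDA, instantiates SetConstant for int16 on CPU, CUDA, and XPU, and whitelists 'int16' in the Python fill_constant API. A minimal usage sketch of what it enables (hypothetical session, assuming a Paddle build that includes this commit):

```python
import paddle

# Before this commit, 'int16' was rejected by fill_constant's dtype check.
x = paddle.full(shape=[2, 3], fill_value=5, dtype='int16')
print(x.numpy().dtype)  # int16
```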
@@ -149,8 +149,8 @@ REGISTER_OPERATOR(
 REGISTER_OP_CPU_KERNEL(
     fill_constant, ops::FillConstantKernel<float>,
     ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>,
-    ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<int>,
-    ops::FillConstantKernel<bool>,
+    ops::FillConstantKernel<int16_t>, ops::FillConstantKernel<int>,
+    ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<bool>,
     ops::FillConstantKernel<paddle::platform::float16>,
     ops::FillConstantKernel<paddle::platform::bfloat16>,
     ops::FillConstantKernel<paddle::platform::complex<float>>,
......
@@ -18,8 +18,8 @@ namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     fill_constant, ops::FillConstantKernel<float>,
     ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>,
-    ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<int>,
-    ops::FillConstantKernel<bool>,
+    ops::FillConstantKernel<int16_t>, ops::FillConstantKernel<int>,
+    ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<bool>,
     ops::FillConstantKernel<paddle::platform::float16>,
     ops::FillConstantKernel<paddle::platform::complex<float>>,
     ops::FillConstantKernel<paddle::platform::complex<double>>);
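The CUDA registration mirrors the CPU one, so the same fill works on GPU. A sketch (assumes a CUDA-enabled build that includes this commit and a visible GPU):

```python
import paddle

paddle.set_device('gpu')  # requires a CUDA build; falls back is not automatic
y = paddle.full(shape=[4], fill_value=-1, dtype='int16')
paddle.set_device('cpu')
```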
@@ -41,6 +41,7 @@ template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
 template struct SetConstant<platform::CPUDeviceContext, platform::bfloat16>;
 template struct SetConstant<platform::CPUDeviceContext, float>;
 template struct SetConstant<platform::CPUDeviceContext, double>;
+template struct SetConstant<platform::CPUDeviceContext, int16_t>;
 template struct SetConstant<platform::CPUDeviceContext, int>;
 template struct SetConstant<platform::CPUDeviceContext, int64_t>;
 template struct SetConstant<platform::CPUDeviceContext, bool>;
@@ -56,6 +57,7 @@ template struct SetConstant<platform::XPUDeviceContext, platform::bfloat16>;
 template struct SetConstant<platform::XPUDeviceContext, float>;
 template struct SetConstant<platform::XPUDeviceContext, double>;
 template struct SetConstant<platform::XPUDeviceContext, uint8_t>;
+template struct SetConstant<platform::XPUDeviceContext, int16_t>;
 template struct SetConstant<platform::XPUDeviceContext, int>;
 template struct SetConstant<platform::XPUDeviceContext, int64_t>;
 template struct SetConstant<platform::XPUDeviceContext, bool>;
......
@@ -35,6 +35,7 @@ template struct SetConstant<platform::CUDADeviceContext, float>;
 template struct SetConstant<platform::CUDADeviceContext, double>;
 template struct SetConstant<platform::CUDADeviceContext, uint8_t>;
 template struct SetConstant<platform::CUDADeviceContext, int>;
+template struct SetConstant<platform::CUDADeviceContext, int16_t>;
 template struct SetConstant<platform::CUDADeviceContext, int64_t>;
 template struct SetConstant<platform::CUDADeviceContext, bool>;
 template struct SetConstant<platform::CUDADeviceContext,
......
@@ -674,7 +674,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
             If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
             If ``shape`` is a Tensor, it should be a 1-D Tensor with data type int32 or int64.
         dtype(np.dtype|str): Data type of the output Tensor which can
-            be float16, float32, float64, uint8, int32, int64.
+            be float16, float32, float64, uint8, int16, int32, int64.
         value(bool|float|int|Tensor): The constant value used to initialize
             the Tensor to be created. If ``value`` is a Tensor, it should be a 1-D Tensor.
         force_cpu(bool, optional): data should be on CPU if it's true, default value is False.
@@ -712,7 +712,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
     attrs = {'force_cpu': force_cpu}
     dtype = convert_dtype(dtype)
     if not isinstance(value, Variable):
-        if dtype in ['uint8', 'int64', 'int32']:
+        if dtype in ['uint8', 'int16', 'int32', 'int64']:
             attrs['str_value'] = str(int(value))
             attrs['value'] = int(value)
         else:
@@ -725,7 +725,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
             out = _varbase_creator(dtype=dtype)
         if isinstance(value, Variable):
-            if dtype in ['uint8', 'int64', 'int32']:
+            if dtype in ['uint8', 'int16', 'int32', 'int64']:
                 attrs['str_value'] = str(int(value.numpy().item(0)))
             else:
                 attrs['str_value'] = str(float(value.numpy().item(0)))
@@ -745,10 +745,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
         inputs['ValueTensor'] = value
     check_shape(shape)
-    check_dtype(
-        dtype, 'dtype',
-        ['bool', 'float16', 'float32', 'float64', 'uint8', 'int32', 'int64'],
-        'fill_constant')
+    check_dtype(dtype, 'dtype', [
+        'bool', 'float16', 'float32', 'float64', 'uint8', 'int16', 'int32',
+        'int64'
+    ], 'fill_constant')
     check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant')
     if out is not None:
......
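The Python-side change is mechanical: 'int16' joins the integer-dtype list, so the fill value is coerced with int() before being stored in the op's value/str_value attributes, and 'int16' now passes check_dtype. A stand-alone illustration of that branch (hypothetical helper, not Paddle code):

```python
def _fill_value_attrs(dtype, value):
    # Mirrors the branch in the diff above: integer dtypes round-trip
    # through int(), everything else through float().
    attrs = {}
    if dtype in ['uint8', 'int16', 'int32', 'int64']:
        attrs['str_value'] = str(int(value))
        attrs['value'] = int(value)
    else:
        attrs['str_value'] = str(float(value))
        attrs['value'] = float(value)
    return attrs

print(_fill_value_attrs('int16', 5.0))  # {'str_value': '5', 'value': 5}
```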
@@ -358,13 +358,6 @@ class TestFillConstantOpError(unittest.TestCase):
                 shape=[1],
                 value=5,
                 dtype='uint4')
-            self.assertRaises(
-                TypeError,
-                fluid.layers.fill_constant,
-                shape=[1],
-                value=5,
-                dtype='int16',
-                out=x1)
             self.assertRaises(
                 TypeError,
@@ -375,7 +368,7 @@ class TestFillConstantOpError(unittest.TestCase):
                 out=x1)
             # The argument dtype of fill_constant_op must be one of bool, float16,
-            #float32, float64, uint8, int32 or int64
+            #float32, float64, uint8, int16, int32 or int64
             x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32")
             self.assertRaises(
......
@@ -84,7 +84,7 @@ class TestFullOpError(unittest.TestCase):
                 TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint4')
             # The argument dtype of full must be one of bool, float16,
-            #float32, float64, uint8, int32 or int64
+            #float32, float64, uint8, int16, int32 or int64
             # The argument shape's type of full_op must be list, tuple or Variable.
             def test_shape_type():
......
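The deleted assertion above existed because 'int16' used to raise TypeError; the remaining test edits only update the dtype lists in comments. A sketch of the positive test one might add instead (not part of this commit):

```python
import unittest

import numpy as np
import paddle


class TestFillConstantInt16(unittest.TestCase):
    def test_int16_fill(self):
        # int16 no longer raises TypeError and round-trips the value exactly.
        x = paddle.full(shape=[3], fill_value=7, dtype='int16')
        self.assertEqual(x.numpy().dtype, np.int16)
        self.assertTrue((x.numpy() == 7).all())


if __name__ == '__main__':
    unittest.main()
```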