未验证 提交 980227f9 编写于 作者: Z Zhang Zheng 提交者: GitHub

Support uint8_t for fill_constant_op (#31911)

上级 07741593
......@@ -149,6 +149,7 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>,
ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int64_t>,
ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>,
......
......@@ -17,6 +17,7 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>,
ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int64_t>,
ops::FillConstantKernel<int>,
ops::FillConstantKernel<bool>,
......
......@@ -51,6 +51,7 @@ template struct SetConstant<platform::CPUDeviceContext, platform::complex128>;
template struct SetConstant<platform::XPUDeviceContext, platform::float16>;
template struct SetConstant<platform::XPUDeviceContext, float>;
template struct SetConstant<platform::XPUDeviceContext, double>;
template struct SetConstant<platform::XPUDeviceContext, uint8_t>;
template struct SetConstant<platform::XPUDeviceContext, int>;
template struct SetConstant<platform::XPUDeviceContext, int64_t>;
template struct SetConstant<platform::XPUDeviceContext, bool>;
......
......@@ -35,6 +35,7 @@ using complex128 = paddle::platform::complex128;
template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, float>;
template struct SetConstant<platform::CUDADeviceContext, double>;
template struct SetConstant<platform::CUDADeviceContext, uint8_t>;
template struct SetConstant<platform::CUDADeviceContext, int>;
template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>;
......
......@@ -635,7 +635,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
If ``shape`` is an Tensor, it should be an 1-D Tensor with date type int32 or int64.
dtype(np.dtype|str): Data type of the output Tensor which can
be float16, float32, float64, int32, int64.
be float16, float32, float64, uint8, int32, int64.
value(bool|float|int|Tensor): The constant value used to initialize
the Tensor to be created. If ``value`` is an Tensor, it should be an 1-D Tensor.
force_cpu(bool, optional): data should be on CPU if it's true, default value is False.
......@@ -673,7 +673,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
attrs = {'force_cpu': force_cpu}
dtype = convert_dtype(dtype)
if not isinstance(value, Variable):
if dtype in ['int64', 'int32']:
if dtype in ['uint8', 'int64', 'int32']:
attrs['str_value'] = str(int(value))
attrs['value'] = int(value)
else:
......@@ -686,7 +686,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
out = _varbase_creator(dtype=dtype)
if isinstance(value, Variable):
if dtype in ['int64', 'int32']:
if dtype in ['uint8', 'int64', 'int32']:
attrs['str_value'] = str(int(value.numpy().item(0)))
else:
attrs['str_value'] = str(float(value.numpy().item(0)))
......@@ -706,8 +706,9 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
inputs['ValueTensor'] = value
check_shape(shape)
check_dtype(dtype, 'dtype',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
check_dtype(
dtype, 'dtype',
['bool', 'float16', 'float32', 'float64', 'uint8', 'int32', 'int64'],
'fill_constant')
check_type(shape, 'shape', (Variable, list, tuple), 'fill_constant')
......
......@@ -375,15 +375,9 @@ class TestFillConstantOpError(unittest.TestCase):
out=x1)
# The argument dtype of fill_constant_op must be one of bool, float16,
#float32, float64, int32 or int64
#float32, float64, uint8, int32 or int64
x2 = fluid.layers.data(name='x2', shape=[1], dtype="int32")
self.assertRaises(
TypeError,
fluid.layers.fill_constant,
shape=[1],
value=5,
dtype='uint8')
self.assertRaises(
TypeError,
fluid.layers.fill_constant,
......
......@@ -84,10 +84,7 @@ class TestFullOpError(unittest.TestCase):
TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint4')
# The argument dtype of full must be one of bool, float16,
#float32, float64, int32 or int64
self.assertRaises(
TypeError, paddle.full, shape=[1], fill_value=5, dtype='uint8')
#float32, float64, uint8, int32 or int64
# The argument shape's type of full_op must be list, tuple or Variable.
def test_shape_type():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册