Unverified commit 297290a8, authored by danleifeng, committed by GitHub

add uint8 type for flatten op (#32120)

* add uint8 type for flatten;test=develop
Parent 4935b8e7
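
This change registers uint8_t kernels for the flatten, flatten2, and flatten_contiguous_range operators (forward and grad, on both CPU and CUDA) and relaxes the Python-side dtype checks to match. As a quick illustration of what the patch enables, a minimal sketch (assuming a Paddle 2.x build that includes this commit, run in dynamic graph mode):

import numpy as np
import paddle

# Before this commit, uint8 inputs were rejected by the flatten dtype check.
x = paddle.to_tensor(np.zeros([4, 4, 3], dtype="uint8"))
out = paddle.flatten(x, start_axis=0, stop_axis=1)
print(out.shape)  # [16, 3]
print(out.dtype)  # paddle.uint8 -- the dtype is preserved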
@@ -429,6 +429,7 @@ REGISTER_OPERATOR(flatten_contiguous_range_grad,
 REGISTER_OP_CPU_KERNEL(
     flatten, ops::FlattenKernel<paddle::platform::CPUDeviceContext, float>,
     ops::FlattenKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::FlattenKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::FlattenKernel<paddle::platform::CPUDeviceContext, int>,
     ops::FlattenKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::FlattenKernel<paddle::platform::CPUDeviceContext, int64_t>);
@@ -436,12 +437,14 @@ REGISTER_OP_CPU_KERNEL(
     flatten_grad,
     ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     flatten2, ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, float>,
     ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, double>,
+    ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, int>,
     ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
@@ -449,6 +452,7 @@ REGISTER_OP_CPU_KERNEL(
     flatten2_grad,
     ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
@@ -458,6 +462,8 @@ REGISTER_OP_CPU_KERNEL(
                                       float>,
     ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
                                       double>,
+    ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
+                                      uint8_t>,
     ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext, int>,
     ops::FlattenContiguousRangeKernel<paddle::platform::CPUDeviceContext,
                                       int8_t>,
@@ -469,6 +475,8 @@ REGISTER_OP_CPU_KERNEL(
                                           float>,
     ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
                                           double>,
+    ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
+                                          uint8_t>,
     ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
                                           int>,
     ops::FlattenContiguousRangeGradKernel<paddle::platform::CPUDeviceContext,
...
@@ -19,6 +19,7 @@ namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     flatten, ops::FlattenKernel<paddle::platform::CUDADeviceContext, float>,
     ops::FlattenKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::FlattenKernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::FlattenKernel<paddle::platform::CUDADeviceContext, int>,
     ops::FlattenKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::FlattenKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -26,12 +27,14 @@ REGISTER_OP_CUDA_KERNEL(
     flatten_grad,
     ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     flatten2, ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, float>,
     ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, double>,
+    ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, int>,
     ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -39,6 +42,7 @@ REGISTER_OP_CUDA_KERNEL(
     flatten2_grad,
     ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
@@ -48,6 +52,8 @@ REGISTER_OP_CUDA_KERNEL(
                                       float>,
     ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
                                       double>,
+    ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
+                                      uint8_t>,
     ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext, int>,
     ops::FlattenContiguousRangeKernel<paddle::platform::CUDADeviceContext,
                                       int8_t>,
@@ -59,6 +65,8 @@ REGISTER_OP_CUDA_KERNEL(
                                           float>,
     ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
                                           double>,
+    ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
+                                          uint8_t>,
     ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
                                           int>,
     ops::FlattenContiguousRangeGradKernel<paddle::platform::CUDADeviceContext,
...
@@ -9940,7 +9940,7 @@ def flatten(x, axis=1, name=None):
     Args:
         x (Variable): A tensor of rank >= axis. A tensor with type float32,
-                      float64, int8, int32, int64.
+                      float64, int8, int32, int64, uint8.
         axis (int): Indicate up to which input dimensions (exclusive) should
                     be flattened to the outer dimension of the output.
                     The value for axis must be in the range [0, R], where R
@@ -9962,14 +9962,17 @@ def flatten(x, axis=1, name=None):
         .. code-block:: python

+            import paddle
             import paddle.fluid as fluid
+            paddle.enable_static()
             x = fluid.data(name="x", shape=[4, 4, 3], dtype="float32")
             # x shape is [4, 4, 3]
             out = fluid.layers.flatten(x=x, axis=2)
             # out shape is [16, 3]
     """
     check_variable_and_dtype(
-        x, 'x', ['float32', 'float64', 'int8', 'int32', 'int64'], 'flatten')
+        x, 'x', ['float32', 'float64', 'int8', 'int32', 'int64', 'uint8'],
+        'flatten')
     helper = LayerHelper('flatten', **locals())
     if not (isinstance(x, Variable)):
...
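
For the static-graph API updated above, a minimal sketch of how the relaxed dtype check plays out (assumptions: a post-patch Paddle 2.x build; shape and feed values are illustrative):

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
# With this patch, dtype="uint8" passes check_variable_and_dtype instead of
# raising a TypeError.
x = fluid.data(name="x", shape=[4, 4, 3], dtype="uint8")
out = fluid.layers.flatten(x=x, axis=2)  # out shape is [16, 3]

exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(fluid.default_main_program(),
               feed={"x": np.zeros([4, 4, 3], dtype="uint8")},
               fetch_list=[out])
print(res.shape)  # (16, 3)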
@@ -81,7 +81,7 @@ class TestFlatten2OpError(unittest.TestCase):
         self.assertRaises(TypeError, test_Variable)

         def test_type():
-            # dtype must be float32, float64, int8, int32, int64.
+            # dtype must be float32, float64, int8, int32, int64, uint8.
             x2 = fluid.layers.data(
                 name='x2', shape=[3, 2, 4, 5], dtype='float16')
             fluid.layers.flatten(x2, axis=1)
...
@@ -166,7 +166,7 @@ class TestFlatten2OpError(unittest.TestCase):
         self.assertRaises(ValueError, test_ValueError3)

         def test_type():
-            # dtype must be float32, float64, int8, int32, int64.
+            # dtype must be float32, float64, int8, int32, int64, uint8.
             x2 = np.arange(image_shape[0] * image_shape[1] * image_shape[2] *
                            image_shape[3]).reshape(image_shape) / 100.
             x2 = x2.astype('float16')
...
@@ -212,7 +212,7 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
     Args:
         x (Tensor): A tensor of number of dimentions >= axis. A tensor with data type float32,
-                    float64, int8, int32, int64.
+                    float64, int8, int32, int64, uint8.
         start_axis (int): the start axis to flatten
         stop_axis (int): the stop axis to flatten
         name(str, Optional): For details, please refer to :ref:`api_guide_Name`.
@@ -249,7 +249,8 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
         raise ValueError("The input x should be a Tensor")

     check_variable_and_dtype(
-        x, 'x', ['float32', 'float64', 'int8', 'int32', 'int64'], 'flatten')
+        x, 'x', ['float32', 'float64', 'int8', 'int32', 'int64', 'uint8'],
+        'flatten')
     helper = LayerHelper('flatten', **locals())
     x_dim = len(x.shape)
...
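
And for the imperative paddle.flatten touched in this hunk, a hedged end-to-end check that the new dtype round-trips correctly (a hypothetical test, not part of this diff; it relies only on flatten being a reshape):

import numpy as np
import paddle

paddle.disable_static()
x_np = np.random.randint(0, 256, size=(3, 2, 4, 5)).astype("uint8")
x = paddle.to_tensor(x_np)
# Flatten axes 1..2 into one dimension: (3, 2, 4, 5) -> (3, 8, 5).
out = paddle.flatten(x, start_axis=1, stop_axis=2)
# The op is a reshape, so values and dtype must survive unchanged.
np.testing.assert_array_equal(out.numpy(), x_np.reshape(3, 8, 5))
assert out.numpy().dtype == np.uint8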