From ccabafa6df9d98103f675bf4733039a8cfa66931 Mon Sep 17 00:00:00 2001 From: TeslaZhao Date: Tue, 8 Jun 2021 11:10:48 +0800 Subject: [PATCH] OP:strided_slice_op supports bool type inputs (#33373) (#33393) * Fix two english api documents, transpose and strided_slice * OP:strided_slice_op supports bool type inputs --- paddle/fluid/operators/strided_slice_op.cc | 2 + paddle/fluid/operators/strided_slice_op.cu | 4 +- python/paddle/fluid/layers/nn.py | 4 +- .../tests/unittests/test_strided_slice_op.py | 65 +++++++++++++++++++ 4 files changed, 72 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc index e49476e4dc7..effacc7591d 100644 --- a/paddle/fluid/operators/strided_slice_op.cc +++ b/paddle/fluid/operators/strided_slice_op.cc @@ -324,6 +324,7 @@ REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad, REGISTER_OP_CPU_KERNEL( strided_slice, + ops::StridedSliceKernel, ops::StridedSliceKernel, ops::StridedSliceKernel, ops::StridedSliceKernel, @@ -335,6 +336,7 @@ REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL( strided_slice_grad, + ops::StridedSliceGradKernel, ops::StridedSliceGradKernel, ops::StridedSliceGradKernel, ops::StridedSliceGradKernel, diff --git a/paddle/fluid/operators/strided_slice_op.cu b/paddle/fluid/operators/strided_slice_op.cu index b85403b1c5b..edf843bb3ee 100644 --- a/paddle/fluid/operators/strided_slice_op.cu +++ b/paddle/fluid/operators/strided_slice_op.cu @@ -19,6 +19,7 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( strided_slice, + ops::StridedSliceKernel, ops::StridedSliceKernel, ops::StridedSliceKernel, ops::StridedSliceKernel, @@ -30,7 +31,8 @@ REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL( strided_slice_grad, - ops::StridedSliceGradKernel, + ops::StridedSliceGradKernel, + ops::StridedSliceGradKernel, ops::StridedSliceGradKernel, ops::StridedSliceGradKernel, ops::StridedSliceGradKernel, diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9ac314528dc..2bac3289d1b 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -11075,7 +11075,7 @@ def strided_slice(input, axes, starts, ends, strides): Then: result = [ [2], ] Args: - input (Variable): An N-D ``Tensor`` or ``LoDTensor`` . The data type is ``float32``, ``float64``, ``int32`` or ``int64``. + input (Variable): An N-D ``Tensor`` or ``LoDTensor`` . The data type is ``bool``, ``float32``, ``float64``, ``int32`` or ``int64``. axes (list|tuple): The data type is ``int32`` . Axes that `starts` and `ends` apply to. It's optional. If it is not provides, it will be treated as :math:`[0,1,...,len(starts)-1]`. starts (list|tuple|Variable): The data type is ``int32`` . If ``starts`` is a list or tuple, the elements of @@ -11126,7 +11126,7 @@ def strided_slice(input, axes, starts, ends, strides): helper = LayerHelper('strided_slice', **locals()) check_variable_and_dtype(input, 'input', - ['float32', 'float64', 'int32', 'int64'], + ['bool', 'float32', 'float64', 'int32', 'int64'], 'strided_slice') check_type(axes, 'axes', (list, tuple), 'strided_slice') check_type(starts, 'starts', (list, tuple, Variable), 'strided_slice') diff --git a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py index 71550c8f247..ebf7c01e2ca 100644 --- a/python/paddle/fluid/tests/unittests/test_strided_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_strided_slice_op.py @@ -216,6 +216,71 @@ class TestStrideSliceOp13(TestStrideSliceOp): self.infer_flags = [1, 1, 1, 1, 1] +class TestStrideSliceOpBool(TestStrideSliceOp): + def test_check_grad(self): + pass + + +class TestStrideSliceOpBool1D(TestStrideSliceOpBool): + def initTestCase(self): + self.input = np.random.rand(100).astype("bool") + self.axes = [0] + self.starts = [3] + self.ends = [8] + self.strides = [1] + self.infer_flags = [1] + + +class TestStrideSliceOpBool2D(TestStrideSliceOpBool): + def initTestCase(self): + self.input = np.random.rand(10, 10).astype("bool") + self.axes = [0, 1] + self.starts = [1, 0] + self.ends = [2, 2] + self.strides = [1, 1] + self.infer_flags = [1, 1] + + +class TestStrideSliceOpBool3D(TestStrideSliceOpBool): + def initTestCase(self): + self.input = np.random.rand(3, 4, 10).astype("bool") + self.axes = [0, 1, 2] + self.starts = [0, -1, 0] + self.ends = [2, -3, 5] + self.strides = [1, -1, 1] + self.infer_flags = [1, 1, 1] + + +class TestStrideSliceOpBool4D(TestStrideSliceOpBool): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3, 4).astype("bool") + self.axes = [0, 1, 2, 3] + self.starts = [1, 0, 0, 0] + self.ends = [2, 2, 3, 4] + self.strides = [1, 1, 1, 2] + self.infer_flags = [1, 1, 1, 1] + + +class TestStrideSliceOpBool5D(TestStrideSliceOpBool): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3, 4, 5).astype("bool") + self.axes = [0, 1, 2, 3, 4] + self.starts = [1, 0, 0, 0, 0] + self.ends = [2, 2, 3, 4, 4] + self.strides = [1, 1, 1, 1, 1] + self.infer_flags = [1, 1, 1, 1] + + +class TestStrideSliceOpBool6D(TestStrideSliceOpBool): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3, 6, 7, 8).astype("bool") + self.axes = [0, 1, 2, 3, 4, 5] + self.starts = [1, 0, 0, 0, 1, 2] + self.ends = [2, 2, 3, 1, 2, 8] + self.strides = [1, 1, 1, 1, 1, 2] + self.infer_flags = [1, 1, 1, 1, 1] + + class TestStridedSliceOp_starts_ListTensor(OpTest): def setUp(self): self.op_type = "strided_slice" -- GitLab