未验证 提交 ccabafa6 编写于 作者: T TeslaZhao 提交者: GitHub

OP:strided_slice_op supports bool type inputs (#33373) (#33393)

* Fix two english api documents, transpose and strided_slice

* OP:strided_slice_op supports bool type inputs
上级 3c22b174
......@@ -324,6 +324,7 @@ REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad,
REGISTER_OP_CPU_KERNEL(
strided_slice,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, bool>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, float>,
......@@ -335,6 +336,7 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL(
strided_slice_grad,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -19,6 +19,7 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
strided_slice,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, bool>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, float>,
......@@ -30,7 +31,8 @@ REGISTER_OP_CUDA_KERNEL(
REGISTER_OP_CUDA_KERNEL(
strided_slice_grad,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, double>,
......
......@@ -11075,7 +11075,7 @@ def strided_slice(input, axes, starts, ends, strides):
Then:
result = [ [2], ]
Args:
input (Variable): An N-D ``Tensor`` or ``LoDTensor`` . The data type is ``float32``, ``float64``, ``int32`` or ``int64``.
input (Variable): An N-D ``Tensor`` or ``LoDTensor`` . The data type is ``bool``, ``float32``, ``float64``, ``int32`` or ``int64``.
axes (list|tuple): The data type is ``int32`` . Axes that `starts` and `ends` apply to.
It's optional. If it is not provides, it will be treated as :math:`[0,1,...,len(starts)-1]`.
starts (list|tuple|Variable): The data type is ``int32`` . If ``starts`` is a list or tuple, the elements of
......@@ -11126,7 +11126,7 @@ def strided_slice(input, axes, starts, ends, strides):
helper = LayerHelper('strided_slice', **locals())
check_variable_and_dtype(input, 'input',
['float32', 'float64', 'int32', 'int64'],
['bool', 'float32', 'float64', 'int32', 'int64'],
'strided_slice')
check_type(axes, 'axes', (list, tuple), 'strided_slice')
check_type(starts, 'starts', (list, tuple, Variable), 'strided_slice')
......
......@@ -216,6 +216,71 @@ class TestStrideSliceOp13(TestStrideSliceOp):
self.infer_flags = [1, 1, 1, 1, 1]
class TestStrideSliceOpBool(TestStrideSliceOp):
def test_check_grad(self):
pass
class TestStrideSliceOpBool1D(TestStrideSliceOpBool):
def initTestCase(self):
self.input = np.random.rand(100).astype("bool")
self.axes = [0]
self.starts = [3]
self.ends = [8]
self.strides = [1]
self.infer_flags = [1]
class TestStrideSliceOpBool2D(TestStrideSliceOpBool):
def initTestCase(self):
self.input = np.random.rand(10, 10).astype("bool")
self.axes = [0, 1]
self.starts = [1, 0]
self.ends = [2, 2]
self.strides = [1, 1]
self.infer_flags = [1, 1]
class TestStrideSliceOpBool3D(TestStrideSliceOpBool):
def initTestCase(self):
self.input = np.random.rand(3, 4, 10).astype("bool")
self.axes = [0, 1, 2]
self.starts = [0, -1, 0]
self.ends = [2, -3, 5]
self.strides = [1, -1, 1]
self.infer_flags = [1, 1, 1]
class TestStrideSliceOpBool4D(TestStrideSliceOpBool):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3, 4).astype("bool")
self.axes = [0, 1, 2, 3]
self.starts = [1, 0, 0, 0]
self.ends = [2, 2, 3, 4]
self.strides = [1, 1, 1, 2]
self.infer_flags = [1, 1, 1, 1]
class TestStrideSliceOpBool5D(TestStrideSliceOpBool):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3, 4, 5).astype("bool")
self.axes = [0, 1, 2, 3, 4]
self.starts = [1, 0, 0, 0, 0]
self.ends = [2, 2, 3, 4, 4]
self.strides = [1, 1, 1, 1, 1]
self.infer_flags = [1, 1, 1, 1]
class TestStrideSliceOpBool6D(TestStrideSliceOpBool):
def initTestCase(self):
self.input = np.random.rand(3, 3, 3, 6, 7, 8).astype("bool")
self.axes = [0, 1, 2, 3, 4, 5]
self.starts = [1, 0, 0, 0, 1, 2]
self.ends = [2, 2, 3, 1, 2, 8]
self.strides = [1, 1, 1, 1, 1, 2]
self.infer_flags = [1, 1, 1, 1, 1]
class TestStridedSliceOp_starts_ListTensor(OpTest):
def setUp(self):
self.op_type = "strided_slice"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册