From 5a4d21863b626fc540e2d0c4f32234836bc55010 Mon Sep 17 00:00:00 2001 From: ykkk2333 <77383312+ykkk2333@users.noreply.github.com> Date: Mon, 7 Nov 2022 15:00:34 +0800 Subject: [PATCH] add roll and roll_grad kernels and strided_slice and strided_slice_grad kernels, test=kunlun (#47368) * add stat tool * add roll and roll_grad kernels and strided_slice and strided_slice_grad kernels, test=kunlun --- .../fluid/platform/device/xpu/xpu2_op_list.h | 12 ++ .../phi/kernels/strided_slice_grad_kernel.cc | 10 + paddle/phi/kernels/strided_slice_kernel.cc | 10 + paddle/phi/kernels/xpu/roll_grad_kernel.cc | 73 +++++++ paddle/phi/kernels/xpu/roll_kernel.cc | 70 +++++++ .../kernels/xpu/stride_slice_grad_kernel.cc | 100 ++++++++++ paddle/phi/kernels/xpu/stride_slice_kernel.cc | 121 ++++++++++++ .../tests/unittests/xpu/test_roll_op_xpu.py | 92 +++++++++ .../xpu/test_strided_slice_op_xpu.py | 184 ++++++++++++++++++ 9 files changed, 672 insertions(+) create mode 100644 paddle/phi/kernels/xpu/roll_grad_kernel.cc create mode 100644 paddle/phi/kernels/xpu/roll_kernel.cc create mode 100644 paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc create mode 100644 paddle/phi/kernels/xpu/stride_slice_kernel.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_roll_op_xpu.py create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_strided_slice_op_xpu.py diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 692f0ab377..4089119097 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -472,6 +472,8 @@ XPUOpMap& get_kl2_ops() { {"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"roi_align_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"roll", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"roll_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"scale", 
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace()), @@ -567,6 +569,16 @@ XPUOpMap& get_kl2_ops() { {"stack_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace())})}, + {"strided_slice", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::INT16, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, + {"strided_slice_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::INT16, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, {"sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, diff --git a/paddle/phi/kernels/strided_slice_grad_kernel.cc b/paddle/phi/kernels/strided_slice_grad_kernel.cc index 8551fc6e84..af8994cd8c 100644 --- a/paddle/phi/kernels/strided_slice_grad_kernel.cc +++ b/paddle/phi/kernels/strided_slice_grad_kernel.cc @@ -68,3 +68,13 @@ PD_REGISTER_KERNEL(strided_slice_grad, phi::dtype::complex, phi::dtype::complex) {} #endif +#if defined(PADDLE_WITH_XPU) +PD_REGISTER_KERNEL(strided_slice_grad, + XPU, + ALL_LAYOUT, + phi::StridedSliceGradKernel, + int, + int16_t, + float, + phi::dtype::float16) {} +#endif diff --git a/paddle/phi/kernels/strided_slice_kernel.cc b/paddle/phi/kernels/strided_slice_kernel.cc index 037abf461a..3ceb805723 100644 --- a/paddle/phi/kernels/strided_slice_kernel.cc +++ b/paddle/phi/kernels/strided_slice_kernel.cc @@ -59,3 +59,13 @@ PD_REGISTER_KERNEL(strided_slice, phi::dtype::complex, phi::dtype::complex) {} #endif +#if defined(PADDLE_WITH_XPU) +PD_REGISTER_KERNEL(strided_slice, + XPU, + ALL_LAYOUT, + phi::StridedSliceKernel, + int, + int16_t, + float, + phi::dtype::float16) {} +#endif diff --git a/paddle/phi/kernels/xpu/roll_grad_kernel.cc b/paddle/phi/kernels/xpu/roll_grad_kernel.cc new file mode 100644 index 
0000000000..4d08a49172 --- /dev/null +++ b/paddle/phi/kernels/xpu/roll_grad_kernel.cc @@ -0,0 +1,73 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/roll_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void RollGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const IntArray& shifts, + const std::vector& axis, + DenseTensor* x_grad) { + using XPUType = typename XPUTypeTrait::Type; + auto shifts_data = shifts.GetData(); + dev_ctx.template Alloc(x_grad); + DDim input_dim = x.dims(); + std::vector xshape; + size_t nums = shifts_data.size(); + for (int i = 0; i < input_dim.size(); ++i) { + xshape.emplace_back(input_dim[i]); + } + + auto dims = axis; + + // axis = none, reshape to 1-D tensor + if (dims.size() == 0) { + dims.push_back(0l); + input_dim = phi::Dim<1>(x.numel()); + } + std::vector shifts_in; + std::vector axis_in; + + for (size_t i = 0; i < nums; ++i) { + int a = dims[i]; + if (a < 0) { + a += (input_dim.size()); + } + axis_in.emplace_back(a); + int sh = (0 - shifts_data[i]) % input_dim[a]; + if (sh < 0) { + sh += input_dim[a]; + } + shifts_in.emplace_back(sh); + } + + int r = xpu::roll(dev_ctx.x_context(), + reinterpret_cast(out_grad.data()), + reinterpret_cast(x_grad->data()), + xshape, + shifts_in, + axis_in); + 
PADDLE_ENFORCE_XDNN_SUCCESS(r, "roll"); +} + +} // namespace phi + +PD_REGISTER_KERNEL(roll_grad, XPU, ALL_LAYOUT, phi::RollGradKernel, float) {} diff --git a/paddle/phi/kernels/xpu/roll_kernel.cc b/paddle/phi/kernels/xpu/roll_kernel.cc new file mode 100644 index 0000000000..f68e378f35 --- /dev/null +++ b/paddle/phi/kernels/xpu/roll_kernel.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/roll_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void RollKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& shifts, + const std::vector& axis, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + auto shifts_data = shifts.GetData(); + dev_ctx.template Alloc(out); + DDim input_dim = x.dims(); + std::vector xshape; + size_t nums = shifts_data.size(); + for (int i = 0; i < input_dim.size(); ++i) { + xshape.emplace_back(input_dim[i]); + } + + auto dims = axis; + + // axis = none, reshape to 1-D tensor + if (dims.size() == 0) { + dims.push_back(0l); + input_dim = phi::Dim<1>(x.numel()); + } + std::vector shifts_in; + std::vector axis_in; + + for (size_t i = 0; i < nums; ++i) { + int a = dims[i]; + if (a < 0) { + a += (input_dim.size()); + } + axis_in.emplace_back(a); + int sh = shifts_data[i] % input_dim[a]; + if (sh < 0) { + sh += 
input_dim[a]; + } + shifts_in.emplace_back(sh); + } + int r = xpu::roll(dev_ctx.x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out->data()), + xshape, + shifts_in, + axis_in); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "roll"); +} + +} // namespace phi + +PD_REGISTER_KERNEL(roll, XPU, ALL_LAYOUT, phi::RollKernel, float) {} diff --git a/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc b/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc new file mode 100644 index 0000000000..70bd235688 --- /dev/null +++ b/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/strided_slice_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void StridedSliceRawGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const IntArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* x_grad) { + using XPUType = typename XPUTypeTrait::Type; + DDim in_dims = x.dims(); + dev_ctx.template Alloc(x_grad); + + auto starts_ = starts.GetData(); + auto ends_ = ends.GetData(); + auto strides_ = strides.GetData(); + + std::vector xshape; + std::vector starts_in(in_dims.size(), 0); + std::vector ends_in; + std::vector strides_in(in_dims.size(), 1); + + for (int i = 0; i < in_dims.size(); ++i) { + xshape.emplace_back(in_dims[i]); + ends_in.emplace_back(in_dims[i]); + } + int num = axes.size(); + + for (int i = 0; i < num; ++i) { + PADDLE_ENFORCE_EQ( + strides_[i] > 0, + true, + errors::InvalidArgument("Currently, XPU strided slice kernel does not", + "support reverse strided slice")); + int cur_axe = axes[i]; + int st = starts_[i]; + if (st > xshape[cur_axe]) { + st = xshape[cur_axe]; + } + if (st < 0) { + st += xshape[cur_axe]; + } + starts_in[cur_axe] = st; + + int end = ends_[i]; + if (end > xshape[cur_axe]) { + end = xshape[cur_axe]; + } + if (end < 0) { + end += xshape[cur_axe]; + } + + ends_in[cur_axe] = end; + strides_in[cur_axe] = strides_[i]; + } + + int r = xpu::strided_slice_grad( + dev_ctx.x_context(), + reinterpret_cast(out_grad.data()), + reinterpret_cast(x_grad->data()), + xshape, + starts_in, + ends_in, + strides_in); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice_grad"); +} + +} // namespace phi + +PD_REGISTER_KERNEL(strided_slice_raw_grad, + XPU, + ALL_LAYOUT, + phi::StridedSliceRawGradKernel, + int, + int16_t, + float, + phi::dtype::float16) {} diff --git 
a/paddle/phi/kernels/xpu/stride_slice_kernel.cc b/paddle/phi/kernels/xpu/stride_slice_kernel.cc new file mode 100644 index 0000000000..517445bb71 --- /dev/null +++ b/paddle/phi/kernels/xpu/stride_slice_kernel.cc @@ -0,0 +1,121 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/strided_slice_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/strided_slice.h" + +namespace phi { + +template +void StridedSliceRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axes, + const IntArray& starts, + const IntArray& ends, + const IntArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + DDim in_dims = x.dims(); + + auto starts_ = starts.GetData(); + auto ends_ = ends.GetData(); + auto strides_ = strides.GetData(); + + std::vector out_dims_vector(in_dims.size(), -1); + funcs::StridedSliceOutDims(starts_, + ends_, + strides_, + axes, + infer_flags, + in_dims, + decrease_axis, + out_dims_vector.data(), + axes.size(), + false); + DDim out_dims(phi::make_ddim(out_dims_vector)); + + out->Resize(out_dims); + dev_ctx.template Alloc(out); + + std::vector xshape; + std::vector starts_in(in_dims.size(), 0); + std::vector ends_in; + std::vector 
strides_in(in_dims.size(), 1); + + for (int i = 0; i < in_dims.size(); ++i) { + xshape.emplace_back(in_dims[i]); + ends_in.emplace_back(in_dims[i]); + } + + int num = axes.size(); + for (int i = 0; i < num; ++i) { + PADDLE_ENFORCE_EQ( + strides_[i] > 0, + true, + errors::InvalidArgument("Currently, XPU strided slice kernel does not ", + "support reverse strided slice.")); + int cur_axe = axes[i]; + int st = starts_[i]; + if (st > xshape[cur_axe]) { + st = xshape[cur_axe]; + } + if (st < 0) { + st += xshape[cur_axe]; + } + starts_in[cur_axe] = st; + + int end = ends_[i]; + if (end > xshape[cur_axe]) { + end = xshape[cur_axe]; + } + if (end < 0) { + end += xshape[cur_axe]; + } + + ends_in[cur_axe] = end; + PADDLE_ENFORCE_EQ( + st < end, + true, + errors::InvalidArgument("End index should be larger than", + "start Index, this OP does not support", + "reverse operator.")); + + strides_in[cur_axe] = strides_[i]; + } + + int r = xpu::strided_slice(dev_ctx.x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out->data()), + xshape, + starts_in, + ends_in, + strides_in); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_slice"); +} + +} // namespace phi +PD_REGISTER_KERNEL(strided_slice_raw, + XPU, + ALL_LAYOUT, + phi::StridedSliceRawKernel, + int, + int16_t, + float, + phi::dtype::float16) {} diff --git a/python/paddle/fluid/tests/unittests/xpu/test_roll_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_roll_op_xpu.py new file mode 100644 index 0000000000..4c64c6e2a3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_roll_op_xpu.py @@ -0,0 +1,92 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numpy as np
import sys
import unittest

sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
    create_test_class,
    get_xpu_op_support_types,
    XPUOpTestWrapper,
)

paddle.enable_static()


class XPUTestRollOp(XPUOpTestWrapper):
    """Test wrapper for the XPU `roll` op; cases are expanded per dtype."""

    def __init__(self):
        self.op_name = "roll"
        self.use_dynamic_create_class = False

    class TestXPURollOp(XPUOpTest):
        def setUp(self):
            self.op_type = "roll"
            self.dtype = self.in_type
            self.init_shapes()
            self.inputs = {
                'X': np.random.random(self.x_shape).astype(self.dtype)
            }
            self.attrs = {'shifts': self.shifts, 'axis': self.axis}
            # np.roll provides the reference result for the XPU kernel.
            self.outputs = {
                'Out': np.roll(
                    self.inputs['X'], self.attrs['shifts'], self.attrs['axis']
                )
            }

        def init_shapes(self):
            # Base case; subclasses override shape/shifts/axis.
            self.x_shape = (100, 4, 5)
            self.shifts = [101, -1]
            self.axis = [0, -2]

        def test_check_output(self):
            self.check_output_with_place(paddle.XPUPlace(0))

        def test_check_grad(self):
            self.check_grad_with_place(paddle.XPUPlace(0), ['X'], 'Out')

    class TestRollOpCase2(TestXPURollOp):
        def init_shapes(self):
            self.x_shape = (100, 10, 5)
            self.shifts = [8, -1]
            self.axis = [-1, -2]

    class TestRollOpCase3(TestXPURollOp):
        def init_shapes(self):
            self.x_shape = (100, 10, 5, 10, 15)
            self.shifts = [50, -1, 3]
            self.axis = [-1, -2, 1]

    class TestRollOpCase4(TestXPURollOp):
        def init_shapes(self):
            self.x_shape = (100, 10, 5, 10, 15)
            self.shifts = [8, -1]
            self.axis = [-1, -2]

    # BUGFIX: this class was also named TestRollOpCase4, shadowing the case
    # above so it never ran; renamed so both cases are exercised.
    class TestRollOpCase5(TestXPURollOp):
        def init_shapes(self):
            self.x_shape = (100, 10, 5, 10)
            self.shifts = [20, -1]
            self.axis = [0, -2]


support_types = get_xpu_op_support_types('roll')
for stype in support_types:
    create_test_class(globals(), XPUTestRollOp, stype)

if __name__ == "__main__":
    unittest.main()

# ===== b/python/paddle/fluid/tests/unittests/xpu/test_strided_slice_op_xpu.py (new file) =====
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numpy as np
import sys
import unittest

sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
    create_test_class,
    get_xpu_op_support_types,
    XPUOpTestWrapper,
)

paddle.enable_static()


def strided_slice_native_forward(input, axes, starts, ends, strides):
    """Reference strided-slice via native numpy indexing (rank 1..6)."""
    dim = input.ndim
    start = []
    end = []
    stride = []
    # Defaults: full range with stride 1 on every dimension.
    for i in range(dim):
        start.append(0)
        end.append(input.shape[i])
        stride.append(1)

    # Overwrite the dimensions named in `axes`.
    for i in range(len(axes)):
        start[axes[i]] = starts[i]
        end[axes[i]] = ends[i]
        stride[axes[i]] = strides[i]

    # Dispatch on rank because the slice expression length is rank-dependent.
    result = {
        1: lambda input, start, end, stride: input[
            start[0] : end[0] : stride[0]
        ],
        2: lambda input, start, end, stride: input[
            start[0] : end[0] : stride[0], start[1] : end[1] : stride[1]
        ],
        3: lambda input, start, end, stride: input[
            start[0] : end[0] : stride[0],
            start[1] : end[1] : stride[1],
            start[2] : end[2] : stride[2],
        ],
        4: lambda input, start, end, stride: input[
            start[0] : end[0] : stride[0],
            start[1] : end[1] : stride[1],
            start[2] : end[2] : stride[2],
            start[3] : end[3] : stride[3],
        ],
        5: lambda input, start, end, stride: input[
            start[0] : end[0] : stride[0],
            start[1] : end[1] : stride[1],
            start[2] : end[2] : stride[2],
            start[3] : end[3] : stride[3],
            start[4] : end[4] : stride[4],
        ],
        6: lambda input, start, end, stride: input[
            start[0] : end[0] : stride[0],
            start[1] : end[1] : stride[1],
            start[2] : end[2] : stride[2],
            start[3] : end[3] : stride[3],
            start[4] : end[4] : stride[4],
            start[5] : end[5] : stride[5],
        ],
    }[dim](input, start, end, stride)

    return result


class XPUTestStrideSliceOp(XPUOpTestWrapper):
    """Test wrapper for the XPU `strided_slice` op; expanded per dtype."""

    def __init__(self):
        self.op_name = 'strided_slice'
        self.use_dynamic_create_class = False

    class XPUTestStrideSliceOp(XPUOpTest):
        def setUp(self):
            self.op_type = 'strided_slice'
            self.dtype = self.in_type
            self.initTestCase()
            self.input = np.random.random(self.inshape).astype(self.dtype)
            self.python_api = paddle.strided_slice
            self.output = strided_slice_native_forward(
                self.input, self.axes, self.starts, self.ends, self.strides
            )

            self.inputs = {'Input': self.input}
            self.outputs = {'Out': self.output}
            self.attrs = {
                'axes': self.axes,
                'starts': self.starts,
                'ends': self.ends,
                'strides': self.strides,
                'infer_flags': self.infer_flags,
            }

        def test_check_output(self):
            self.check_output_with_place(paddle.XPUPlace(0))

        def test_check_grad(self):
            self.check_grad_with_place(paddle.XPUPlace(0), ['Input'], 'Out')

        def initTestCase(self):
            # Base case (1-D, negative start/end); subclasses override.
            self.inshape = 100
            self.axes = [0]
            self.starts = [-4]
            self.ends = [-1]
            self.strides = [1]
            self.infer_flags = [1]

    class XPUTestStrideSliceOp1(XPUTestStrideSliceOp):
        def initTestCase(self):
            self.inshape = 100
            self.axes = [0]
            self.starts = [3]
            self.ends = [8]
            self.strides = [1]
            self.infer_flags = [1]

    class XPUTestStrideSliceOp2(XPUTestStrideSliceOp):
        def initTestCase(self):
            self.inshape = (4, 8, 12)
            self.axes = [0, 1, 2]
            self.starts = [3, 4, 5]
            self.ends = [4, 5, 6]
            self.strides = [1, 1, 1]
            self.infer_flags = [1, 1, 1]

    class XPUTestStrideSliceOp3(XPUTestStrideSliceOp):
        def initTestCase(self):
            self.inshape = (4, 8, 12, 4, 40)
            self.axes = [0, 1, 2, 3, 4]
            self.starts = [3, 4, 5, 1, 10]
            self.ends = [4, 5, 6, 2, 30]
            self.strides = [1, 1, 1, 2, 2]
            self.infer_flags = [1, 1, 1, 1, 1]

    class XPUTestStrideSliceOp5(XPUTestStrideSliceOp):
        def initTestCase(self):
            self.inshape = (5, 5, 5)
            self.axes = [0, 1, 2]
            self.starts = [1, 0, 0]
            self.ends = [2, 1, 3]
            self.strides = [1, 1, 1]
            self.infer_flags = [1, 1, 1]

    class XPUTestStrideSliceOp7(XPUTestStrideSliceOp):
        def initTestCase(self):
            self.inshape = (5, 5, 5)
            self.axes = [0, 1, 2]
            self.starts = [1, 0, 0]
            self.ends = [2, 2, 3]
            self.strides = [1, 1, 1]
            self.infer_flags = [1, 1, 1]

    class XPUTestStrideSliceOp8(XPUTestStrideSliceOp):
        def initTestCase(self):
            self.inshape = (3, 3, 3, 6, 7, 8)
            self.axes = [0, 1, 2, 3, 4, 5]
            self.starts = [1, 0, 0, 0, 1, 2]
            self.ends = [2, 2, 3, 1, 2, 8]
            self.strides = [1, 1, 1, 1, 1, 2]
            # BUGFIX: there are 6 sliced axes but only 5 flags were listed.
            self.infer_flags = [1, 1, 1, 1, 1, 1]


support_types = get_xpu_op_support_types('strided_slice')
for stype in support_types:
    create_test_class(globals(), XPUTestStrideSliceOp, stype)

if __name__ == "__main__":
    unittest.main()
# -- GitLab