diff --git a/paddle/fluid/operators/strided_slice_op_npu.cc b/paddle/fluid/operators/strided_slice_op_npu.cc new file mode 100755 index 0000000000000000000000000000000000000000..deafdc5633a15d2edfeea6d0e31b3de64b922fb6 --- /dev/null +++ b/paddle/fluid/operators/strided_slice_op_npu.cc @@ -0,0 +1,239 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/strided_slice_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" +#include "paddle/fluid/operators/slice_op.h" + +namespace paddle { +namespace operators { + +template +class StridedSliceNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Variable* input_var = ctx.InputVar("Input"); + bool is_tensor_array = input_var->IsType(); + PADDLE_ENFORCE_EQ(is_tensor_array, false, + platform::errors::InvalidArgument( + "Tensor array as input is not supported.")); + int rank = ctx.Input("Input")->dims().size(); + switch (rank) { + case 1: + StridedSliceCompute<1>(ctx); + break; + case 2: + StridedSliceCompute<2>(ctx); + break; + case 3: + StridedSliceCompute<3>(ctx); + break; + case 4: + StridedSliceCompute<4>(ctx); + break; + case 5: + StridedSliceCompute<5>(ctx); + break; + case 6: + StridedSliceCompute<6>(ctx); + break; + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "The rank of input is supported up to 6.")); + break; + } + } + + private: + template + void StridedSliceCompute(const framework::ExecutionContext& ctx) const { + auto place = ctx.GetPlace(); + auto stream = + ctx.template device_context() + .stream(); + + auto in = ctx.Input("Input"); + auto out = ctx.Output("Out"); + auto in_dims = in->dims(); + + // list + auto starts_int = ctx.Attr>("starts"); + auto ends_int = ctx.Attr>("ends"); + auto strides_int = ctx.Attr>("strides"); + + std::vector starts(starts_int.begin(), starts_int.end()); + std::vector ends(ends_int.begin(), ends_int.end()); + std::vector strides(strides_int.begin(), strides_int.end()); + + auto axes = ctx.Attr>("axes"); + auto infer_flags = ctx.Attr>("infer_flags"); + auto decrease_axis = ctx.Attr>("decrease_axis"); + + // vector> + auto list_new_ends_tensor = + ctx.MultiInput("EndsTensorList"); + auto list_new_starts_tensor = + ctx.MultiInput("StartsTensorList"); + auto list_new_strides_tensor = + ctx.MultiInput("StridesTensorList"); + + // Tensor + if (list_new_starts_tensor.size() > 0) { + starts = GetDataFromTensorList(list_new_starts_tensor); + } else if (ctx.HasInput("StartsTensor")) { + auto* starts_tensor = ctx.Input("StartsTensor"); + starts = GetDataFromTensor(starts_tensor); + } + + if (list_new_ends_tensor.size() > 0) { + ends = GetDataFromTensorList(list_new_ends_tensor); + } else if (ctx.HasInput("EndsTensor")) { + auto* ends_tensor = ctx.Input("EndsTensor"); + ends = GetDataFromTensor(ends_tensor); + } + + if (list_new_strides_tensor.size() > 0) { + strides = GetDataFromTensorList(list_new_strides_tensor); + } else if (ctx.HasInput("StridesTensor")) { + auto* strides_tensor = ctx.Input("StridesTensor"); + strides = GetDataFromTensor(strides_tensor); + } + + // out dims calculation + std::vector out_dims_vector(in_dims.size(), -1); + StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims, + decrease_axis, out_dims_vector.data(), axes.size(), + false); + framework::DDim out_dims(framework::make_ddim(out_dims_vector)); + + // check whether need to reverse (false: stride > 0; true: stride < 0) + std::vector reverse_vector(starts.size(), 0); + StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), + reverse_vector.data(), in_dims, infer_flags, + decrease_axis, starts.size()); + + // construct the starts_indices, ends_indices and strides_indices tensor for + // calling StridedSlice op + std::vector starts_indices_vector(D, 0); + std::vector ends_indices_vector(out_dims_vector.begin(), + out_dims_vector.end()); + std::vector strides_indices_vector(D, 1); + + for (size_t axis = 0; axis < axes.size(); axis++) { + int axis_index = axes[axis]; + starts_indices_vector[axis_index] = starts[axis]; + ends_indices_vector[axis_index] = ends[axis]; + strides_indices_vector[axis_index] = strides[axis]; + } + + Tensor starts_indices_tensor; + Tensor ends_indices_tensor; + Tensor strides_indices_tensor; + + starts_indices_tensor.mutable_data({D}, place); + ends_indices_tensor.mutable_data({D}, place); + strides_indices_tensor.mutable_data({D}, place); + + TensorFromVector(starts_indices_vector, ctx.device_context(), + &starts_indices_tensor); + TensorFromVector(ends_indices_vector, ctx.device_context(), + &ends_indices_tensor); + TensorFromVector(strides_indices_vector, ctx.device_context(), + &strides_indices_tensor); + + auto out_dims_origin = out_dims; + if (decrease_axis.size() > 0) { + std::vector new_out_shape; + for (size_t i = 0; i < decrease_axis.size(); ++i) { + PADDLE_ENFORCE_EQ( + out_dims[decrease_axis[i]], 1, + platform::errors::InvalidArgument( + "the size of decrease dimension should be 1, but received %d.", + out_dims[decrease_axis[i]])); + out_dims_origin[decrease_axis[i]] = 0; + } + + for (int i = 0; i < out_dims_origin.size(); ++i) { + if (out_dims_origin[i] != 0) { + new_out_shape.push_back(out_dims_origin[i]); + } + } + if (new_out_shape.size() == 0) { + new_out_shape.push_back(1); + } + out_dims_origin = framework::make_ddim(new_out_shape); + } + + bool need_reverse = false; + for (size_t axis = 0; axis < axes.size(); axis++) { + if (reverse_vector[axis] == 1) { + need_reverse = true; + break; + } + } + + out->Resize(out_dims); + out->mutable_data(place); + + const auto& runner = NpuOpRunner( + "StridedSlice", {*in, starts_indices_tensor, ends_indices_tensor, + strides_indices_tensor}, + {*out}, {{"begin_mask", 0}, + {"end_mask", 0}, + {"ellipsis_mask", 0}, + {"new_axis_mask", 0}, + {"shrink_axis_mask", 0}}); + runner.Run(stream); + + if (need_reverse) { + Tensor out_tmp; + out_tmp.mutable_data(out_dims, place); + TensorCopy(*out, place, + ctx.template device_context(), + &out_tmp); + + Tensor reverse_axis; + std::vector reverse_axis_vector; + for (size_t axis = 0; axis < axes.size(); axis++) { + if (reverse_vector[axis] == 1) { + reverse_axis_vector.push_back(axes[axis]); + } + } + reverse_axis.mutable_data( + {static_cast(reverse_axis_vector.size())}, place); + TensorFromVector(reverse_axis_vector, ctx.device_context(), + &reverse_axis); + + const auto& runner_reverse = + NpuOpRunner("ReverseV2", {out_tmp, reverse_axis}, {*out}); + runner_reverse.Run(stream); + } + + if (decrease_axis.size() > 0) { + out->Resize(out_dims_origin); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_NPU_KERNEL( + strided_slice, + ops::StridedSliceNPUKernel, + ops::StridedSliceNPUKernel, + ops::StridedSliceNPUKernel, + ops::StridedSliceNPUKernel, + ops::StridedSliceNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_strided_slice_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_strided_slice_op_npu.py new file mode 100755 index 0000000000000000000000000000000000000000..2f0fa697cb0d9887fe70349b881c825c7fe10773 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_strided_slice_op_npu.py @@ -0,0 +1,583 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import numpy as np +sys.path.append("..") +from op_test import OpTest, skip_check_grad_ci +import unittest +import paddle.fluid as fluid +import paddle + +paddle.enable_static() + + +def strided_slice_native_forward(input, axes, starts, ends, strides): + dim = input.ndim + start = [] + end = [] + stride = [] + for i in range(dim): + start.append(0) + end.append(input.shape[i]) + stride.append(1) + + for i in range(len(axes)): + start[axes[i]] = starts[i] + end[axes[i]] = ends[i] + stride[axes[i]] = strides[i] + + result = { + 1: lambda input, start, end, stride: input[start[0]:end[0]:stride[0]], + 2: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \ + start[1]:end[1]:stride[1]], + 3: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \ + start[1]:end[1]:stride[1], start[2]:end[2]:stride[2]], + 4: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \ + start[1]:end[1]:stride[1], start[2]:end[2]:stride[2], start[3]:end[3]:stride[3]], + 5: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \ + start[1]:end[1]:stride[1], start[2]:end[2]:stride[2], start[3]:end[3]:stride[3], start[4]:end[4]:stride[4]], + 6: lambda input, start, end, stride: input[start[0]:end[0]:stride[0], \ + start[1]:end[1]:stride[1], start[2]:end[2]:stride[2], start[3]:end[3]:stride[3], \ + start[4]:end[4]:stride[4], start[5]:end[5]:stride[5]] + }[dim](input, start, end, stride) + + return result + + +@skip_check_grad_ci( + reason='''forward only, it doesn't need to call check_grad.''') +class TestStridedSliceOp(OpTest): + def setUp(self): + self.initTestCase() + self.op_type = 'strided_slice' + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides) + + self.inputs = {'Input': self.input} + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts, + 'ends': self.ends, + 'strides': self.strides, + 'infer_flags': self.infer_flags + } + + def test_check_output(self): + place = paddle.NPUPlace(0) + self.check_output_with_place(place) + + def initTestCase(self): + self.input = np.random.rand(10) + self.axes = [0] + self.starts = [2] + self.ends = [7] + self.strides = [1] + self.infer_flags = [1] + + +class TestStridedSliceOp1(TestStridedSliceOp): + def initTestCase(self): + self.input = np.random.rand(100) + self.axes = [0] + self.starts = [3] + self.ends = [8] + self.strides = [1] + self.infer_flags = [1] + + +class TestStridedSliceOp2(TestStridedSliceOp): + def initTestCase(self): + self.input = np.random.rand(100) + self.axes = [0] + self.starts = [5] + self.ends = [0] + self.strides = [-1] + self.infer_flags = [1] + + +class TestStridedSliceOp3(TestStridedSliceOp): + def initTestCase(self): + self.input = np.random.rand(100) + self.axes = [0] + self.starts = [-1] + self.ends = [-3] + self.strides = [-1] + self.infer_flags = [1] + + +class TestStridedSliceOp4(TestStridedSliceOp): + def initTestCase(self): + self.input = np.random.rand(3, 4, 10) + self.axes = [0, 1, 2] + self.starts = [0, -1, 0] + self.ends = [2, -3, 5] + self.strides = [1, -1, 1] + self.infer_flags = [1, 1, 1] + + +class TestStridedSliceOp5(TestStridedSliceOp): + def initTestCase(self): + self.input = np.random.rand(5, 5, 5) + self.axes = [0, 1, 2] + self.starts = [1, 0, 0] + self.ends = [2, 1, 3] + self.strides = [1, 1, 1] + self.infer_flags = [1, 1, 1] + + +class TestStridedSliceOp6(TestStridedSliceOp): + def initTestCase(self): + self.input = np.random.rand(5, 5, 5) + self.axes = [0, 1, 2] + self.starts = [1, -1, 0] + self.ends = [2, -3, 3] + self.strides = [1, -1, 1] + self.infer_flags = [1, 1, 1] + + +class TestStridedSliceOp7(TestStridedSliceOp): + def initTestCase(self): + self.input = np.random.rand(5, 5, 5) + self.axes = [0, 1, 2] + self.starts = [1, 0, 0] + self.ends = [2, 2, 3] + self.strides = [1, 1, 1] + self.infer_flags = [1, 1, 1] + + +class TestStridedSliceOp8(TestStridedSliceOp): + def initTestCase(self): + self.input = np.random.rand(1, 100, 1) + self.axes = [1] + self.starts = [1] + self.ends = [2] + self.strides = [1] + self.infer_flags = [1] + + +class TestStridedSliceOp9(TestStridedSliceOp): + def initTestCase(self): + self.input = np.random.rand(1, 100, 1) + self.axes = [1] + self.starts = [-1] + self.ends = [-2] + self.strides = [-1] + self.infer_flags = [1] + + +class TestStridedSliceOp10(TestStridedSliceOp): + def initTestCase(self): + self.input = np.random.rand(10, 10) + self.axes = [0, 1] + self.starts = [1, 0] + self.ends = [2, 2] + self.strides = [1, 1] + self.infer_flags = [1, 1] + + +class TestStridedSliceOp11(TestStridedSliceOp): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3, 4) + self.axes = [0, 1, 2, 3] + self.starts = [1, 0, 0, 0] + self.ends = [2, 2, 3, 4] + self.strides = [1, 1, 1, 2] + self.infer_flags = [1, 1, 1, 1] + + +class TestStridedSliceOp12(TestStridedSliceOp): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3, 4, 5) + self.axes = [0, 1, 2, 3, 4] + self.starts = [1, 0, 0, 0, 0] + self.ends = [2, 2, 3, 4, 4] + self.strides = [1, 1, 1, 1, 1] + self.infer_flags = [1, 1, 1, 1] + + +class TestStridedSliceOp13(TestStridedSliceOp): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3, 6, 7, 8) + self.axes = [0, 1, 2, 3, 4, 5] + self.starts = [1, 0, 0, 0, 1, 2] + self.ends = [2, 2, 3, 1, 2, 8] + self.strides = [1, 1, 1, 1, 1, 2] + self.infer_flags = [1, 1, 1, 1, 1] + + +class TestStridedSliceOpBool(TestStridedSliceOp): + def test_check_grad(self): + pass + + +class TestStridedSliceOpBool1D(TestStridedSliceOpBool): + def initTestCase(self): + self.input = np.random.rand(100).astype("bool") + self.axes = [0] + self.starts = [3] + self.ends = [8] + self.strides = [1] + self.infer_flags = [1] + + +class TestStridedSliceOpBool2D(TestStridedSliceOpBool): + def initTestCase(self): + self.input = np.random.rand(10, 10).astype("bool") + self.axes = [0, 1] + self.starts = [1, 0] + self.ends = [2, 2] + self.strides = [1, 1] + self.infer_flags = [1, 1] + + +class TestStridedSliceOpBool3D(TestStridedSliceOpBool): + def initTestCase(self): + self.input = np.random.rand(3, 4, 10).astype("bool") + self.axes = [0, 1, 2] + self.starts = [0, -1, 0] + self.ends = [2, -3, 5] + self.strides = [1, -1, 1] + self.infer_flags = [1, 1, 1] + + +class TestStridedSliceOpBool4D(TestStridedSliceOpBool): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3, 4).astype("bool") + self.axes = [0, 1, 2, 3] + self.starts = [1, 0, 0, 0] + self.ends = [2, 2, 3, 4] + self.strides = [1, 1, 1, 2] + self.infer_flags = [1, 1, 1, 1] + + +class TestStridedSliceOpBool5D(TestStridedSliceOpBool): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3, 4, 5).astype("bool") + self.axes = [0, 1, 2, 3, 4] + self.starts = [1, 0, 0, 0, 0] + self.ends = [2, 2, 3, 4, 4] + self.strides = [1, 1, 1, 1, 1] + self.infer_flags = [1, 1, 1, 1] + + +class TestStridedSliceOpBool6D(TestStridedSliceOpBool): + def initTestCase(self): + self.input = np.random.rand(3, 3, 3, 6, 7, 8).astype("bool") + self.axes = [0, 1, 2, 3, 4, 5] + self.starts = [1, 0, 0, 0, 1, 2] + self.ends = [2, 2, 3, 1, 2, 8] + self.strides = [1, 1, 1, 1, 1, 2] + self.infer_flags = [1, 1, 1, 1, 1] + + +@skip_check_grad_ci( + reason='''forward only, it doesn't need to call check_grad.''') +class TestStridedSliceOp_starts_ListTensor(OpTest): + def setUp(self): + self.op_type = "strided_slice" + self.config() + + starts_tensor = [] + for index, ele in enumerate(self.starts): + starts_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + + self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor} + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts_infer, + 'ends': self.ends, + 'strides': self.strides, + 'infer_flags': self.infer_flags + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, 0, 2] + self.ends = [3, 3, 4] + self.axes = [0, 1, 2] + self.strides = [1, 1, 1] + self.infer_flags = [1, -1, 1] + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides) + + self.starts_infer = [1, 10, 2] + + def test_check_output(self): + place = paddle.NPUPlace(0) + self.check_output_with_place(place) + + +@skip_check_grad_ci( + reason='''forward only, it doesn't need to call check_grad.''') +class TestStridedSliceOp_ends_ListTensor(OpTest): + def setUp(self): + self.op_type = "strided_slice" + self.config() + + ends_tensor = [] + for index, ele in enumerate(self.ends): + ends_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + + self.inputs = {'Input': self.input, 'EndsTensorList': ends_tensor} + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts, + 'ends': self.ends_infer, + 'strides': self.strides, + 'infer_flags': self.infer_flags + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, 0, 0] + self.ends = [3, 3, 4] + self.axes = [0, 1, 2] + self.strides = [1, 1, 2] + self.infer_flags = [1, -1, 1] + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides) + + self.ends_infer = [3, 1, 4] + + def test_check_output(self): + place = paddle.NPUPlace(0) + self.check_output_with_place(place) + + +@skip_check_grad_ci( + reason='''forward only, it doesn't need to call check_grad.''') +class TestStridedSliceOp_starts_Tensor(OpTest): + def setUp(self): + self.op_type = "strided_slice" + self.config() + self.inputs = { + 'Input': self.input, + "StartsTensor": np.array( + self.starts, dtype="int32") + } + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + #'starts': self.starts, + 'ends': self.ends, + 'strides': self.strides, + 'infer_flags': self.infer_flags, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, 0, 2] + self.ends = [2, 3, 4] + self.axes = [0, 1, 2] + self.strides = [1, 1, 1] + self.infer_flags = [-1, -1, -1] + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides) + + def test_check_output(self): + place = paddle.NPUPlace(0) + self.check_output_with_place(place) + + +@skip_check_grad_ci( + reason='''forward only, it doesn't need to call check_grad.''') +class TestStridedSliceOp_ends_Tensor(OpTest): + def setUp(self): + self.op_type = "strided_slice" + self.config() + self.inputs = { + 'Input': self.input, + "EndsTensor": np.array( + self.ends, dtype="int32") + } + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts, + #'ends': self.ends, + 'strides': self.strides, + 'infer_flags': self.infer_flags, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, 0, 2] + self.ends = [2, 3, 4] + self.axes = [0, 1, 2] + self.strides = [1, 1, 1] + self.infer_flags = [-1, -1, -1] + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides) + + def test_check_output(self): + place = paddle.NPUPlace(0) + self.check_output_with_place(place) + + +@skip_check_grad_ci( + reason='''forward only, it doesn't need to call check_grad.''') +class TestStridedSliceOp_listTensor_Tensor(OpTest): + def setUp(self): + self.config() + ends_tensor = [] + for index, ele in enumerate(self.ends): + ends_tensor.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + self.op_type = "strided_slice" + + self.inputs = { + 'Input': self.input, + "StartsTensor": np.array( + self.starts, dtype="int32"), + "EndsTensorList": ends_tensor + } + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + #'starts': self.starts, + #'ends': self.ends, + 'strides': self.strides, + 'infer_flags': self.infer_flags, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, 0, 2] + self.ends = [2, 3, 4] + self.axes = [0, 1, 2] + self.strides = [1, 1, 1] + self.infer_flags = [-1, -1, -1] + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides) + + def test_check_output(self): + place = paddle.NPUPlace(0) + self.check_output_with_place(place) + + +@skip_check_grad_ci( + reason='''forward only, it doesn't need to call check_grad.''') +class TestStridedSliceOp_strides_Tensor(OpTest): + def setUp(self): + self.op_type = "strided_slice" + self.config() + self.inputs = { + 'Input': self.input, + "StridesTensor": np.array( + self.strides, dtype="int32") + } + self.outputs = {'Out': self.output} + self.attrs = { + 'axes': self.axes, + 'starts': self.starts, + 'ends': self.ends, + #'strides': self.strides, + 'infer_flags': self.infer_flags, + } + + def config(self): + self.input = np.random.random([3, 4, 5, 6]).astype("float64") + self.starts = [1, -1, 2] + self.ends = [2, 0, 4] + self.axes = [0, 1, 2] + self.strides = [1, -1, 1] + self.infer_flags = [-1, -1, -1] + self.output = strided_slice_native_forward( + self.input, self.axes, self.starts, self.ends, self.strides) + + def test_check_output(self): + place = paddle.NPUPlace(0) + self.check_output_with_place(place) + + + # Test python API +class TestStridedSliceAPI(unittest.TestCase): + def test_1(self): + input = np.random.random([3, 4, 5, 6]).astype("float64") + minus_1 = fluid.layers.fill_constant([1], "int32", -1) + minus_3 = fluid.layers.fill_constant([1], "int32", -3) + starts = fluid.layers.data( + name='starts', shape=[3], dtype='int32', append_batch_size=False) + ends = fluid.layers.data( + name='ends', shape=[3], dtype='int32', append_batch_size=False) + strides = fluid.layers.data( + name='strides', shape=[3], dtype='int32', append_batch_size=False) + + x = fluid.layers.data( + name="x", + shape=[3, 4, 5, 6], + append_batch_size=False, + dtype="float64") + out_1 = fluid.layers.strided_slice( + x, + axes=[0, 1, 2], + starts=[-3, 0, 2], + ends=[3, 100, -1], + strides=[1, 1, 1]) + out_2 = fluid.layers.strided_slice( + x, + axes=[0, 1, 3], + starts=[minus_3, 0, 2], + ends=[3, 100, -1], + strides=[1, 1, 1]) + out_3 = fluid.layers.strided_slice( + x, + axes=[0, 1, 3], + starts=[minus_3, 0, 2], + ends=[3, 100, minus_1], + strides=[1, 1, 1]) + out_4 = fluid.layers.strided_slice( + x, axes=[0, 1, 2], starts=starts, ends=ends, strides=strides) + + out_5 = x[-3:3, 0:100:2, -1:2:-1] + out_6 = x[minus_3:3:1, 0:100:2, :, minus_1:2:minus_1] + out_7 = x[minus_1, 0:100:2, :, -1:2:-1] + + exe = fluid.Executor(place=paddle.NPUPlace(0)) + res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run( + fluid.default_main_program(), + feed={ + "x": input, + 'starts': np.array([-3, 0, 2]).astype("int32"), + 'ends': np.array([3, 2147483648, -1]).astype("int64"), + 'strides': np.array([1, 1, 1]).astype("int32") + }, + fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7]) + assert np.array_equal(res_1, input[-3:3, 0:100, 2:-1, :]) + assert np.array_equal(res_2, input[-3:3, 0:100, :, 2:-1]) + assert np.array_equal(res_3, input[-3:3, 0:100, :, 2:-1]) + assert np.array_equal(res_4, input[-3:3, 0:100, 2:-1, :]) + assert np.array_equal(res_5, input[-3:3, 0:100:2, -1:2:-1, :]) + assert np.array_equal(res_6, input[-3:3, 0:100:2, :, -1:2:-1]) + assert np.array_equal(res_7, input[-1, 0:100:2, :, -1:2:-1]) + + def test_dygraph_op(self): + x = paddle.zeros(shape=[3, 4, 5, 6], dtype="float32") + axes = [1, 2, 3] + starts = [-3, 0, 2] + ends = [3, 2, 4] + strides_1 = [1, 1, 1] + sliced_1 = paddle.strided_slice( + x, axes=axes, starts=starts, ends=ends, strides=strides_1) + assert sliced_1.shape == (3, 2, 2, 2) + + +if __name__ == "__main__": + unittest.main()