From 7743cdf2d8ced48b3261926806c960ce64151b5f Mon Sep 17 00:00:00 2001
From: baoachun <962571062@qq.com>
Date: Wed, 1 Sep 2021 16:33:02 +0800
Subject: [PATCH] add strided_slice_grad op for npu (#35204)

* add strided_slice_grad op for npu
---
 .../fluid/operators/strided_slice_op_npu.cc   | 202 +++++++++++++++++-
 .../npu/test_strided_slice_op_npu.py          |  99 ++++++---
 2 files changed, 265 insertions(+), 36 deletions(-)
 mode change 100755 => 100644 paddle/fluid/operators/strided_slice_op_npu.cc

diff --git a/paddle/fluid/operators/strided_slice_op_npu.cc b/paddle/fluid/operators/strided_slice_op_npu.cc
old mode 100755
new mode 100644
index deafdc5633a..eb9377cc638
--- a/paddle/fluid/operators/strided_slice_op_npu.cc
+++ b/paddle/fluid/operators/strided_slice_op_npu.cc
@@ -226,14 +226,204 @@ class StridedSliceNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class StridedSliceGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const Variable* input_var = ctx.InputVar("Input");
+    bool is_tensor_array = input_var->IsType<framework::LoDTensorArray>();
+    PADDLE_ENFORCE_EQ(is_tensor_array, false,
+                      platform::errors::InvalidArgument(
+                          "Tensor array as input is not supported."));
+    int rank = ctx.Input<framework::Tensor>("Input")->dims().size();
+
+    switch (rank) {
+      case 1:
+        StridedSliceGradCompute<1>(ctx);
+        break;
+      case 2:
+        StridedSliceGradCompute<2>(ctx);
+        break;
+      case 3:
+        StridedSliceGradCompute<3>(ctx);
+        break;
+      case 4:
+        StridedSliceGradCompute<4>(ctx);
+        break;
+      case 5:
+        StridedSliceGradCompute<5>(ctx);
+        break;
+      case 6:
+        StridedSliceGradCompute<6>(ctx);
+        break;
+      default:
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "The rank of input is supported up to 6."));
+        break;
+    }
+  }
+
+ private:
+  template <size_t D>
+  void StridedSliceGradCompute(const framework::ExecutionContext& ctx) const {
+    auto place = ctx.GetPlace();
+    auto& dev_ctx =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>();
+
+    auto* input = ctx.Input<framework::Tensor>("Input");
+    auto input_dims = input->dims();
+    auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("Input"));
+    dx->mutable_data<T>(input_dims, place);
+
+    auto starts_int = ctx.Attr<std::vector<int>>("starts");
+    auto ends_int = ctx.Attr<std::vector<int>>("ends");
+    auto strides_int = ctx.Attr<std::vector<int>>("strides");
+
+    std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
+    std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
+    std::vector<int64_t> strides(strides_int.begin(), strides_int.end());
+
+    auto axes = ctx.Attr<std::vector<int>>("axes");
+    auto infer_flags = ctx.Attr<std::vector<int>>("infer_flags");
+    auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
+
+    auto list_new_ends_tensor =
+        ctx.MultiInput<framework::Tensor>("EndsTensorList");
+    auto list_new_starts_tensor =
+        ctx.MultiInput<framework::Tensor>("StartsTensorList");
+    auto list_new_strides_tensor =
+        ctx.MultiInput<framework::Tensor>("StridesTensorList");
+
+    if (list_new_starts_tensor.size() > 0) {
+      starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
+    } else if (ctx.HasInput("StartsTensor")) {
+      auto* starts_tensor = ctx.Input<framework::Tensor>("StartsTensor");
+      starts = GetDataFromTensor<int64_t>(starts_tensor);
+    }
+
+    if (list_new_ends_tensor.size() > 0) {
+      ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
+    } else if (ctx.HasInput("EndsTensor")) {
+      auto* ends_tensor = ctx.Input<framework::Tensor>("EndsTensor");
+      ends = GetDataFromTensor<int64_t>(ends_tensor);
+    }
+
+    if (list_new_strides_tensor.size() > 0) {
+      strides = GetDataFromTensorList<int64_t>(list_new_strides_tensor);
+    } else if (ctx.HasInput("StridesTensor")) {
+      auto* strides_tensor = ctx.Input<framework::Tensor>("StridesTensor");
+      strides = GetDataFromTensor<int64_t>(strides_tensor);
+    }
+
+    std::vector<int64_t> out_dims_vector(input_dims.size(), -1);
+    StridedSliceOutDims(starts, ends, strides, axes, infer_flags, input_dims,
+                        decrease_axis, out_dims_vector.data(), axes.size(),
+                        false);
+
+    std::vector<int> reverse_vector(starts.size(), 0);
+    StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
+                        reverse_vector.data(), input_dims, infer_flags,
+                        decrease_axis, starts.size());
+
+    std::vector<int64_t> starts_indices_vector(D, 0);
+    std::vector<int64_t> ends_indices_vector(out_dims_vector.begin(),
+                                             out_dims_vector.end());
+    std::vector<int64_t> strides_indices_vector(D, 1);
+
+    for (size_t axis = 0; axis < axes.size(); axis++) {
+      int axis_index = axes[axis];
+      starts_indices_vector[axis_index] = starts[axis];
+      ends_indices_vector[axis_index] = ends[axis];
+      strides_indices_vector[axis_index] = strides[axis];
+    }
+
+    Tensor starts_indices_tensor;
+    Tensor ends_indices_tensor;
+    Tensor strides_indices_tensor;
+
+    starts_indices_tensor.mutable_data<int64_t>({D}, place);
+    ends_indices_tensor.mutable_data<int64_t>({D}, place);
+    strides_indices_tensor.mutable_data<int64_t>({D}, place);
+
+    TensorFromVector(starts_indices_vector, dev_ctx, &starts_indices_tensor);
+    TensorFromVector(ends_indices_vector, dev_ctx, &ends_indices_tensor);
+    TensorFromVector(strides_indices_vector, dev_ctx, &strides_indices_tensor);
+
+    std::vector<int64_t> input_dims_vector;
+    for (int i = 0; i < input_dims.size(); i++) {
+      input_dims_vector.push_back(input_dims[i]);
+    }
+    Tensor input_dims_tensor;
+    TensorFromVector(input_dims_vector, dev_ctx, &input_dims_tensor);
+
+    bool need_reverse = false;
+    for (size_t axis = 0; axis < axes.size(); axis++) {
+      if (reverse_vector[axis] == 1) {
+        need_reverse = true;
+        break;
+      }
+    }
+
+    auto stream = dev_ctx.stream();
+    framework::NPUAttributeMap attr_input = {{"begin_mask", 0},
+                                             {"end_mask", 0},
+                                             {"ellipsis_mask", 0},
+                                             {"new_axis_mask", 0},
+                                             {"shrink_axis_mask", 0}};
+
+    if (need_reverse) {
+      Tensor reverse_axis;
+      std::vector<int> reverse_axis_vector;
+      for (size_t axis = 0; axis < axes.size(); axis++) {
+        if (reverse_vector[axis] == 1) {
+          reverse_axis_vector.push_back(axes[axis]);
+        }
+      }
+      reverse_axis.mutable_data<int>(
+          {static_cast<int>(reverse_axis_vector.size())}, place);
+      TensorFromVector(reverse_axis_vector, dev_ctx, &reverse_axis);
+
+      Tensor dout_tmp;
+      dout_tmp.mutable_data<T>(dout->dims(), place);
+      const auto& runner_reverse =
+          NpuOpRunner("ReverseV2", {*dout, reverse_axis}, {dout_tmp});
+      runner_reverse.Run(stream);
+
+      const auto& runner =
+          NpuOpRunner("StridedSliceGrad",
+                      {input_dims_tensor, starts_indices_tensor,
+                       ends_indices_tensor, strides_indices_tensor, dout_tmp},
+                      {*dx}, attr_input);
+      runner.Run(stream);
+    } else {
+      const auto& runner =
+          NpuOpRunner("StridedSliceGrad",
+                      {input_dims_tensor, starts_indices_tensor,
+                       ends_indices_tensor, strides_indices_tensor, *dout},
+                      {*dx}, attr_input);
+      runner.Run(stream);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_NPU_KERNEL(
+    strided_slice, ops::StridedSliceNPUKernel<plat::NPUDeviceContext, bool>,
+    ops::StridedSliceNPUKernel<plat::NPUDeviceContext, int>,
+    ops::StridedSliceNPUKernel<plat::NPUDeviceContext, int64_t>,
+    ops::StridedSliceNPUKernel<plat::NPUDeviceContext, float>,
+    ops::StridedSliceNPUKernel<plat::NPUDeviceContext, plat::float16>);
+
 REGISTER_OP_NPU_KERNEL(
-    strided_slice,
-    ops::StridedSliceNPUKernel<paddle::platform::NPUDeviceContext, bool>,
-    ops::StridedSliceNPUKernel<paddle::platform::NPUDeviceContext, int>,
-    ops::StridedSliceNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
-    ops::StridedSliceNPUKernel<paddle::platform::NPUDeviceContext, float>,
-    ops::StridedSliceNPUKernel<paddle::platform::NPUDeviceContext, paddle::platform::float16>);
+    strided_slice_grad,
+    ops::StridedSliceGradNPUKernel<plat::NPUDeviceContext, bool>,
+    ops::StridedSliceGradNPUKernel<plat::NPUDeviceContext, int>,
+    ops::StridedSliceGradNPUKernel<plat::NPUDeviceContext, int64_t>,
+    ops::StridedSliceGradNPUKernel<plat::NPUDeviceContext, float>,
+    ops::StridedSliceGradNPUKernel<plat::NPUDeviceContext, plat::float16>);
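The grad kernel above hands the actual scatter to the CANN "StridedSliceGrad" operator: it packs the input shape and the normalized begin/end/stride indices into tensors, and first reverses grad(Out) with "ReverseV2" on any axis the forward slice walked with a negative stride. As a reading aid only (not part of the patch), here is a minimal NumPy sketch of the contract the kernel implements; strided_slice_grad_reference is a hypothetical name, and negative strides and decrease_axis are left out for brevity:

    import numpy as np

    def strided_slice_grad_reference(input_shape, dout, axes, starts, ends,
                                     strides):
        # Gradient of a strided slice: scatter dout back into a zero tensor
        # of the input's shape, at exactly the positions the forward slice
        # read from.
        dx = np.zeros(input_shape, dtype=dout.dtype)
        index = [slice(None)] * len(input_shape)
        for axis, start, end, stride in zip(axes, starts, ends, strides):
            index[axis] = slice(start, end, stride)
        dx[tuple(index)] = dout
        return dx
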
diff --git a/python/paddle/fluid/tests/unittests/npu/test_strided_slice_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_strided_slice_op_npu.py
index 2f0fa697cb0..1260017da93 100755
--- a/python/paddle/fluid/tests/unittests/npu/test_strided_slice_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_strided_slice_op_npu.py
@@ -56,11 +56,11 @@ def strided_slice_native_forward(input, axes, starts, ends, strides):
     return result
 
 
-@skip_check_grad_ci(
-    reason='''forward only, it doesn't need to call check_grad.''')
 class TestStridedSliceOp(OpTest):
     def setUp(self):
         self.initTestCase()
+        self.set_npu()
+        self.place = paddle.NPUPlace(0)
         self.op_type = 'strided_slice'
         self.output = strided_slice_native_forward(
             self.input, self.axes, self.starts, self.ends, self.strides)
@@ -75,12 +75,17 @@ class TestStridedSliceOp(OpTest):
             'infer_flags': self.infer_flags
         }
 
+    def set_npu(self):
+        self.__class__.use_npu = True
+
     def test_check_output(self):
-        place = paddle.NPUPlace(0)
-        self.check_output_with_place(place)
+        self.check_output_with_place(self.place)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(self.place, ['Input'], 'Out')
 
     def initTestCase(self):
-        self.input = np.random.rand(10)
+        self.input = np.random.rand(100)
         self.axes = [0]
         self.starts = [2]
         self.ends = [7]
@@ -283,12 +288,12 @@ class TestStridedSliceOpBool6D(TestStridedSliceOpBool):
         self.infer_flags = [1, 1, 1, 1, 1]
 
 
-@skip_check_grad_ci(
-    reason='''forward only, it doesn't need to call check_grad.''')
 class TestStridedSliceOp_starts_ListTensor(OpTest):
     def setUp(self):
+        self.place = paddle.NPUPlace(0)
         self.op_type = "strided_slice"
         self.config()
+        self.set_npu()
 
         starts_tensor = []
         for index, ele in enumerate(self.starts):
@@ -305,6 +310,9 @@ class TestStridedSliceOp_starts_ListTensor(OpTest):
             'infer_flags': self.infer_flags
         }
 
+    def set_npu(self):
+        self.__class__.use_npu = True
+
     def config(self):
         self.input = np.random.random([3, 4, 5, 6]).astype("float64")
         self.starts = [1, 0, 2]
@@ -318,16 +326,18 @@ class TestStridedSliceOp_starts_ListTensor(OpTest):
         self.starts_infer = [1, 10, 2]
 
     def test_check_output(self):
-        place = paddle.NPUPlace(0)
-        self.check_output_with_place(place)
+        self.check_output_with_place(self.place)
+
+    def test_check_grad_normal(self):
+        self.check_grad_with_place(self.place, ['Input'], 'Out')
 
 
-@skip_check_grad_ci(
-    reason='''forward only, it doesn't need to call check_grad.''')
 class TestStridedSliceOp_ends_ListTensor(OpTest):
     def setUp(self):
+        self.place = paddle.NPUPlace(0)
         self.op_type = "strided_slice"
         self.config()
+        self.set_npu()
 
         ends_tensor = []
         for index, ele in enumerate(self.ends):
@@ -344,6 +354,9 @@ class TestStridedSliceOp_ends_ListTensor(OpTest):
             'infer_flags': self.infer_flags
         }
 
+    def set_npu(self):
+        self.__class__.use_npu = True
+
     def config(self):
         self.input = np.random.random([3, 4, 5, 6]).astype("float64")
         self.starts = [1, 0, 0]
@@ -357,16 +370,19 @@ class TestStridedSliceOp_ends_ListTensor(OpTest):
         self.ends_infer = [3, 1, 4]
 
     def test_check_output(self):
-        place = paddle.NPUPlace(0)
-        self.check_output_with_place(place)
+        self.check_output_with_place(self.place)
+
+    def test_check_grad_normal(self):
+        self.check_grad_with_place(self.place, ['Input'], 'Out')
 
 
-@skip_check_grad_ci(
-    reason='''forward only, it doesn't need to call check_grad.''')
 class TestStridedSliceOp_starts_Tensor(OpTest):
     def setUp(self):
+        self.place = paddle.NPUPlace(0)
         self.op_type = "strided_slice"
        self.config()
+        self.set_npu()
+
self.input, "StartsTensor": np.array( @@ -381,6 +397,9 @@ class TestStridedSliceOp_starts_Tensor(OpTest): 'infer_flags': self.infer_flags, } + def set_npu(self): + self.__class__.use_npu = True + def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.starts = [1, 0, 2] @@ -392,16 +411,19 @@ class TestStridedSliceOp_starts_Tensor(OpTest): self.input, self.axes, self.starts, self.ends, self.strides) def test_check_output(self): - place = paddle.NPUPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['Input'], 'Out') -@skip_check_grad_ci( - reason='''forward only, it doesn't need to call check_grad.''') class TestStridedSliceOp_ends_Tensor(OpTest): def setUp(self): + self.place = paddle.NPUPlace(0) self.op_type = "strided_slice" self.config() + self.set_npu() + self.inputs = { 'Input': self.input, "EndsTensor": np.array( @@ -416,6 +438,9 @@ class TestStridedSliceOp_ends_Tensor(OpTest): 'infer_flags': self.infer_flags, } + def set_npu(self): + self.__class__.use_npu = True + def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.starts = [1, 0, 2] @@ -427,20 +452,23 @@ class TestStridedSliceOp_ends_Tensor(OpTest): self.input, self.axes, self.starts, self.ends, self.strides) def test_check_output(self): - place = paddle.NPUPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['Input'], 'Out') -@skip_check_grad_ci( - reason='''forward only, it doesn't need to call check_grad.''') class TestStridedSliceOp_listTensor_Tensor(OpTest): def setUp(self): + self.place = paddle.NPUPlace(0) + self.op_type = "strided_slice" + self.set_npu() self.config() + ends_tensor = [] for index, ele in enumerate(self.ends): ends_tensor.append(("x" + str(index), np.ones( (1)).astype('int32') * ele)) - self.op_type = "strided_slice" self.inputs = { 'Input': self.input, @@ -457,6 +485,9 @@ class TestStridedSliceOp_listTensor_Tensor(OpTest): 'infer_flags': self.infer_flags, } + def set_npu(self): + self.__class__.use_npu = True + def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.starts = [1, 0, 2] @@ -468,16 +499,19 @@ class TestStridedSliceOp_listTensor_Tensor(OpTest): self.input, self.axes, self.starts, self.ends, self.strides) def test_check_output(self): - place = paddle.NPUPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + self.check_grad_with_place(self.place, ['Input'], 'Out') -@skip_check_grad_ci( - reason='''forward only, it doesn't need to call check_grad.''') class TestStridedSliceOp_strides_Tensor(OpTest): def setUp(self): + self.place = paddle.NPUPlace(0) self.op_type = "strided_slice" + self.set_npu() self.config() + self.inputs = { 'Input': self.input, "StridesTensor": np.array( @@ -492,6 +526,9 @@ class TestStridedSliceOp_strides_Tensor(OpTest): 'infer_flags': self.infer_flags, } + def set_npu(self): + self.__class__.use_npu = True + def config(self): self.input = np.random.random([3, 4, 5, 6]).astype("float64") self.starts = [1, -1, 2] @@ -503,8 +540,10 @@ class TestStridedSliceOp_strides_Tensor(OpTest): self.input, self.axes, self.starts, self.ends, self.strides) def test_check_output(self): - place = paddle.NPUPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(self.place) + + 
+    def test_check_grad_normal(self):
+        self.check_grad_with_place(self.place, ['Input'], 'Out')
 
 
 # Test python API
-- 
GitLab
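For context, a short dygraph snippet that exercises the new registration end to end. This is a sketch, assuming a Paddle build with NPU support; the 'npu' device string and Tensor.grad access follow the 2.x release line this patch targets and may differ per build:

    import numpy as np
    import paddle

    paddle.set_device('npu:0')  # assumes an NPU build; adjust for your setup
    x = paddle.to_tensor(
        np.arange(24, dtype='float32').reshape(2, 3, 4), stop_gradient=False)
    # Forward dispatches to the StridedSlice NPU kernel.
    y = paddle.strided_slice(
        x, axes=[1, 2], starts=[0, 1], ends=[3, 4], strides=[2, 2])
    y.sum().backward()  # backward now dispatches to StridedSliceGrad
    print(x.grad)  # ones at the sliced positions, zeros elsewhere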