From 3f64a2c357c5500bcb6075b8f7c84676ea19184d Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Mon, 24 Oct 2022 18:56:07 +0800
Subject: [PATCH] Polish slice code in fluid (#45746)

* support selected_rows kernel for multiply in dygraph
* delete useless code of slice in fluid
* fix compile bug
* move slice_array from fluid to phi
* fix strided_slice_op_npu
---
 paddle/fluid/operators/slice_op.cc            |  53 +----
 paddle/fluid/operators/slice_op.h             | 217 ------------------
 paddle/fluid/operators/slice_op_mlu.cc        |   3 +-
 paddle/fluid/operators/slice_op_npu.cc        |   3 +-
 paddle/fluid/operators/strided_slice_op.cc    |   1 -
 .../fluid/operators/strided_slice_op_mlu.cc   |   8 +-
 .../fluid/operators/strided_slice_op_npu.cc   |   8 +-
 paddle/phi/common/int_array.h                 |   2 +
 paddle/phi/kernels/cpu/slice_grad_kernel.cc   |  26 +++
 paddle/phi/kernels/cpu/slice_kernel.cc        |  28 +++
 .../phi/kernels/gpu/slice_grad_kernel.cu.cc   |  30 +++
 paddle/phi/kernels/gpu/slice_kernel.cu.cc     |  30 +++
 .../phi/kernels/impl/slice_grad_kernel_impl.h |  60 +++++
 paddle/phi/kernels/impl/slice_kernel_impl.h   |  59 +++++
 paddle/phi/kernels/slice_grad_kernel.h        |  16 ++
 paddle/phi/kernels/slice_kernel.h             |  14 ++
 paddle/phi/ops/compat/slice_sig.cc            |  47 +++-
 17 files changed, 329 insertions(+), 276 deletions(-)
 delete mode 100644 paddle/fluid/operators/slice_op.h

diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc
index 2b3be3d78b2..71da14eae7f 100644
--- a/paddle/fluid/operators/slice_op.cc
+++ b/paddle/fluid/operators/slice_op.cc
@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/slice_op.h"
-
 #include <algorithm>
 #include <memory>
 #include <string>
 #include <vector>

+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/kernels/funcs/slice_utils.h"

 namespace paddle {
@@ -456,53 +455,3 @@ REGISTER_OPERATOR(slice_grad,
                   ops::SliceDoubleOpGradMaker<paddle::imperative::OpBase>,
                   ops::SliceOpGradNoNeedBufferVarsInferer,
                   ops::SliceOpGradVarTypeInference);
-
-REGISTER_OP_CPU_KERNEL(
-    slice,
-    ops::SliceKernel<phi::CPUContext, bool>,
-    ops::SliceKernel<phi::CPUContext, uint8_t>,
-    ops::SliceKernel<phi::CPUContext, int>,
-    ops::SliceKernel<phi::CPUContext, int64_t>,
-    ops::SliceKernel<phi::CPUContext, float>,
-    ops::SliceKernel<phi::CPUContext, double>,
-    ops::SliceKernel<phi::CPUContext, paddle::platform::complex<float>>,
-    ops::SliceKernel<phi::CPUContext, paddle::platform::complex<double>>,
-    ops::SliceKernel<phi::CPUContext, paddle::platform::bfloat16>);
-
-REGISTER_OP_CPU_KERNEL(
-    slice_grad,
-    ops::SliceGradKernel<phi::CPUContext, bool>,
-    ops::SliceGradKernel<phi::CPUContext, uint8_t>,
-    ops::SliceGradKernel<phi::CPUContext, int>,
-    ops::SliceGradKernel<phi::CPUContext, int64_t>,
-    ops::SliceGradKernel<phi::CPUContext, float>,
-    ops::SliceGradKernel<phi::CPUContext, double>,
-    ops::SliceGradKernel<phi::CPUContext, paddle::platform::complex<float>>,
-    ops::SliceGradKernel<phi::CPUContext, paddle::platform::complex<double>>,
-    ops::SliceGradKernel<phi::CPUContext, paddle::platform::bfloat16>);
-
-REGISTER_OP_CUDA_KERNEL(
-    slice,
-    ops::SliceKernel<phi::GPUContext, bool>,
-    ops::SliceKernel<phi::GPUContext, uint8_t>,
-    ops::SliceKernel<phi::GPUContext, int>,
-    ops::SliceKernel<phi::GPUContext, int64_t>,
-    ops::SliceKernel<phi::GPUContext, float>,
-    ops::SliceKernel<phi::GPUContext, double>,
-    ops::SliceKernel<phi::GPUContext, paddle::platform::float16>,
-    ops::SliceKernel<phi::GPUContext, paddle::platform::bfloat16>,
-    ops::SliceKernel<phi::GPUContext, paddle::platform::complex<float>>,
-    ops::SliceKernel<phi::GPUContext, paddle::platform::complex<double>>);
-
-REGISTER_OP_CUDA_KERNEL(
-    slice_grad,
-    ops::SliceGradKernel<phi::GPUContext, bool>,
-    ops::SliceGradKernel<phi::GPUContext, uint8_t>,
-    ops::SliceGradKernel<phi::GPUContext, int>,
-    ops::SliceGradKernel<phi::GPUContext, int64_t>,
-    ops::SliceGradKernel<phi::GPUContext, float>,
-    ops::SliceGradKernel<phi::GPUContext, double>,
-    ops::SliceGradKernel<phi::GPUContext, paddle::platform::float16>,
-    ops::SliceGradKernel<phi::GPUContext, paddle::platform::bfloat16>,
-    ops::SliceGradKernel<phi::GPUContext, paddle::platform::complex<float>>,
-    ops::SliceGradKernel<phi::GPUContext, paddle::platform::complex<double>>);
diff --git a/paddle/fluid/operators/slice_op.h b/paddle/fluid/operators/slice_op.h
deleted file mode 100644
index 5efb0c38194..00000000000
--- a/paddle/fluid/operators/slice_op.h
+++ /dev/null
@@ -1,217 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <algorithm>
-#include <memory>
-#include <vector>
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/eigen/eigen_function.h"
-#include "paddle/fluid/operators/utils.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-
-namespace paddle {
-namespace operators {
-using Tensor = phi::DenseTensor;
-using Variable = framework::Variable;
-using LoDTensorArray = framework::LoDTensorArray;
-using DDim = framework::DDim;
-
-inline void DealTensorArray(const framework::ExecutionContext& ctx,
-                            const std::vector<int64_t>& starts,
-                            const std::vector<int64_t>& ends,
-                            bool out_is_array) {
-  auto in_array = ctx.Input<LoDTensorArray>("Input");
-  // If the input is LoDTensorArray, the rank of input is 1.
-  int64_t in_size = in_array->size();
-  int64_t start = starts[0] < 0 ? (starts[0] + in_size) : starts[0];
-  int64_t end = ends[0] < 0 ? (ends[0] + in_size) : ends[0];
-
-  start = std::max(start, static_cast<int64_t>(0));
-  end = std::max(end, static_cast<int64_t>(0));
-  end = std::min(end, in_size);
-
-  if (starts[0] == -1 && end == 0) {
-    end = start + 1;
-  }
-
-  PADDLE_ENFORCE_GT(end,
-                    start,
-                    platform::errors::InvalidArgument(
-                        "Attr(ends) should be greater than attr(starts) in "
-                        "slice op. But received end = %d, start = %d.",
-                        ends[0],
-                        starts[0]));
-  int64_t out_size = end - start;
-
-  if (out_is_array) {
-    auto out_array = ctx.Output<LoDTensorArray>("Out");
-    out_array->resize(out_size);
-
-    for (int i = 0; i < out_size; ++i) {
-      auto* out_tensor = &out_array->at(i);
-      auto in_tensor = in_array->at(i + start);
-      out_tensor->set_lod(in_tensor.lod());
-      if (in_tensor.memory_size() > 0) {
-        paddle::framework::TensorCopy(in_tensor, ctx.GetPlace(), out_tensor);
-      } else {
-        VLOG(10) << "WARNING: The input tensor 'x_tensor' holds no memory, so "
-                    "nothing has been written to output array["
-                 << i << "].";
-      }
-    }
-  } else {
-    auto out = ctx.Output<phi::DenseTensor>("Out");
-    auto in_tensor = in_array->at(start);
-    paddle::framework::TensorCopy(in_tensor, ctx.GetPlace(), out);
-  }
-}
-
-template <typename DeviceContext, typename T>
-class SliceKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const Variable* input_var = ctx.InputVar("Input");
-    Variable* out_var = ctx.OutputVar("Out");
-    bool input_is_array = input_var->IsType<LoDTensorArray>();
-    bool out_is_array = out_var->IsType<LoDTensorArray>();
-
-    auto axes_int = ctx.Attr<std::vector<int>>("axes");
-    auto starts_int = ctx.Attr<std::vector<int>>("starts");
-    auto ends_int = ctx.Attr<std::vector<int>>("ends");
-    std::vector<int64_t> axes(axes_int.begin(), axes_int.end());
-    std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
-    std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
-
-    auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
-    auto infer_flags = ctx.Attr<std::vector<int>>("infer_flags");
-
-    // Step 1: Get the accurate attribute value of starts and ends
-    auto starts_tensor_list =
-        ctx.MultiInput<phi::DenseTensor>("StartsTensorList");
-    if (ctx.HasInput("StartsTensor")) {
-      starts = GetDataFromTensor<int64_t>(
-          ctx.Input<phi::DenseTensor>("StartsTensor"));
-    } else if (starts_tensor_list.size() > 0) {
-      starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
-    }
-
-    auto ends_tensor_list = ctx.MultiInput<phi::DenseTensor>("EndsTensorList");
-    if (ctx.HasInput("EndsTensor")) {
-      ends =
-          GetDataFromTensor<int64_t>(ctx.Input<phi::DenseTensor>("EndsTensor"));
-    } else if (ends_tensor_list.size() > 0) {
-      ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
-    }
-
-    PADDLE_ENFORCE_EQ(
-        starts.size(),
-        axes.size(),
-        platform::errors::InvalidArgument(
-            "The size of starts must be equal to the size of axes."));
-    PADDLE_ENFORCE_EQ(
-        ends.size(),
-        axes.size(),
-        platform::errors::InvalidArgument(
-            "The size of ends must be equal to the size of axes."));
-
-    // Step 2: Compute output
-    if (input_is_array) {
-      DealTensorArray(ctx, starts, ends, out_is_array);
-      return;
-    }
-  }
-};
-
-template <typename DeviceContext, typename T>
-class SliceGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto axes = ctx.Attr<std::vector<int>>("axes");
-    auto starts_int = ctx.Attr<std::vector<int>>("starts");
-    auto ends_int = ctx.Attr<std::vector<int>>("ends");
-    std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
-    std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
-
-    // Get the accurate attribute value of starts and ends
-    auto starts_tensor_list =
-        ctx.MultiInput<phi::DenseTensor>("StartsTensorList");
-    if (ctx.HasInput("StartsTensor")) {
-      starts = GetDataFromTensor<int64_t>(
-          ctx.Input<phi::DenseTensor>("StartsTensor"));
-    } else if (starts_tensor_list.size() > 0) {
-      starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
-    }
-
-    auto ends_tensor_list = ctx.MultiInput<phi::DenseTensor>("EndsTensorList");
-    if (ctx.HasInput("EndsTensor")) {
-      ends =
-          GetDataFromTensor<int64_t>(ctx.Input<phi::DenseTensor>("EndsTensor"));
-    } else if (ends_tensor_list.size() > 0) {
-      ends = GetDataFromTensorList<int64_t>(ends_tensor_list);
-    }
-
-    Variable* d_input_var = ctx.OutputVar(framework::GradVarName("Input"));
-    const Variable* d_out_var = ctx.InputVar(framework::GradVarName("Out"));
-    bool d_input_is_array = d_input_var->IsType<LoDTensorArray>();
-    bool d_out_is_array = d_out_var->IsType<LoDTensorArray>();
-
-    if (d_input_is_array) {
-      auto* input_array = ctx.Input<LoDTensorArray>("Input");
-      auto* d_in_arr =
-          ctx.Output<LoDTensorArray>(framework::GradVarName("Input"));
-
-      int64_t d_in_size = input_array->size();
-      d_in_arr->resize(d_in_size);
-      // If the input is LoDTensorArray, the rank of input is 1.
-      // So only use the 0th element of starts.
-      int64_t start = starts[0] < 0 ? (starts[0] + d_in_size) : starts[0];
-      start = std::max(start, static_cast<int64_t>(0));
-      // set zero
-      platform::DeviceContextPool& pool =
-          platform::DeviceContextPool::Instance();
-      auto& dev_ctx = *pool.Get(ctx.GetPlace());
-      phi::funcs::SetConstant<DeviceContext, T> functor;
-      for (int i = 0; i < d_in_size; ++i) {
-        auto dim = input_array->at(i).dims();
-        d_in_arr->at(i).Resize(dim);
-        d_in_arr->at(i).mutable_data<T>(ctx.GetPlace());
-        functor(reinterpret_cast<const DeviceContext&>(dev_ctx),
-                &d_in_arr->at(i),
-                static_cast<T>(0));
-      }
-
-      if (d_out_is_array) {
-        auto* d_out_arr =
-            ctx.Input<LoDTensorArray>(framework::GradVarName("Out"));
-        int d_out_size = d_out_arr->size();
-        for (int i = 0; i < d_out_size; ++i) {
-          paddle::framework::TensorCopy(
-              d_out_arr->at(i), ctx.GetPlace(), &(d_in_arr->at(start + i)));
-        }
-      } else {
-        auto* d_out =
-            ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
-        paddle::framework::TensorCopy(
-            *d_out, ctx.GetPlace(), &(d_in_arr->at(start)));
-      }
-      return;
-    }
-  }
-
- private:
-};
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/slice_op_mlu.cc b/paddle/fluid/operators/slice_op_mlu.cc
index 60c86b1fcf5..1935e2d0c9b 100644
--- a/paddle/fluid/operators/slice_op_mlu.cc
+++ b/paddle/fluid/operators/slice_op_mlu.cc
@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
-#include "paddle/fluid/operators/slice_op.h"
+#include "paddle/fluid/operators/utils.h"
 #include "paddle/phi/kernels/funcs/slice_utils.h"

 namespace paddle {
diff --git a/paddle/fluid/operators/slice_op_npu.cc b/paddle/fluid/operators/slice_op_npu.cc
index 5ed606c7e00..13ad2635756 100644
--- a/paddle/fluid/operators/slice_op_npu.cc
+++ b/paddle/fluid/operators/slice_op_npu.cc
@@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/slice_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/utils.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
 #include "paddle/phi/kernels/funcs/slice_utils.h"

diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc
index ad75d23452c..a91b210f2dc 100644
--- a/paddle/fluid/operators/strided_slice_op.cc
+++ b/paddle/fluid/operators/strided_slice_op.cc
@@ -19,7 +19,6 @@ limitations under the License. */

 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/slice_op.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/kernels/funcs/strided_slice.h"
diff --git a/paddle/fluid/operators/strided_slice_op_mlu.cc b/paddle/fluid/operators/strided_slice_op_mlu.cc
index 806c8205d09..5800c167b01 100644
--- a/paddle/fluid/operators/strided_slice_op_mlu.cc
+++ b/paddle/fluid/operators/strided_slice_op_mlu.cc
@@ -12,13 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/mlu/mlu_baseop.h"
-#include "paddle/fluid/operators/slice_op.h"
+#include "paddle/fluid/operators/utils.h"
 #include "paddle/phi/kernels/funcs/strided_slice.h"

 namespace paddle {
 namespace operators {

+using Tensor = framework::Tensor;
+using Variable = framework::Variable;
+using LoDTensorArray = framework::LoDTensorArray;
+using DDim = framework::DDim;
+
 static void ProcessStridedSliceParams(
     const std::vector<int>& axes,
     const DDim& input_dims,
diff --git a/paddle/fluid/operators/strided_slice_op_npu.cc b/paddle/fluid/operators/strided_slice_op_npu.cc
index 9a1492fea1e..f613dc10540 100644
--- a/paddle/fluid/operators/strided_slice_op_npu.cc
+++ b/paddle/fluid/operators/strided_slice_op_npu.cc
@@ -12,13 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/slice_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/utils.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
 #include "paddle/phi/kernels/funcs/strided_slice.h"

 namespace paddle {
 namespace operators {

+using Tensor = phi::DenseTensor;
+using Variable = framework::Variable;
+using LoDTensorArray = framework::LoDTensorArray;
+using DDim = framework::DDim;
+
 template <typename DeviceContext, typename T>
 class StridedSliceNPUKernel : public framework::OpKernel<T> {
  public:
diff --git a/paddle/phi/common/int_array.h b/paddle/phi/common/int_array.h
index ca6f7fc17d2..d462d5de041 100644
--- a/paddle/phi/common/int_array.h
+++ b/paddle/phi/common/int_array.h
@@ -58,6 +58,8 @@ class IntArrayBase {

   size_t size() const { return array_.size(); }

+  int64_t operator[](int64_t i) const { return array_[i]; }
+
   const std::vector<int64_t>& GetData() const { return array_; }

  private:
diff --git a/paddle/phi/kernels/cpu/slice_grad_kernel.cc b/paddle/phi/kernels/cpu/slice_grad_kernel.cc
index f22e3634f15..54834cb6c5c 100644
--- a/paddle/phi/kernels/cpu/slice_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/slice_grad_kernel.cc
@@ -31,3 +31,29 @@ PD_REGISTER_KERNEL(slice_grad,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>,
                    phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(slice_array_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SliceArrayGradKernel,
+                   bool,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>,
+                   phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(slice_array_dense_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SliceArrayDenseGradKernel,
+                   bool,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/cpu/slice_kernel.cc b/paddle/phi/kernels/cpu/slice_kernel.cc
index ff9a5c1593f..89f00e71f3a 100644
--- a/paddle/phi/kernels/cpu/slice_kernel.cc
+++ b/paddle/phi/kernels/cpu/slice_kernel.cc
@@ -31,3 +31,31 @@ PD_REGISTER_KERNEL(slice,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>,
                    phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(slice_array,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SliceArrayKernel,
+                   bool,
+                   int,
+                   uint8_t,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>,
+                   phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(slice_array_dense,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SliceArrayDenseKernel,
+                   bool,
+                   int,
+                   uint8_t,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/slice_grad_kernel.cu.cc b/paddle/phi/kernels/gpu/slice_grad_kernel.cu.cc
index ec575ab952b..b7de6c9d941 100644
--- a/paddle/phi/kernels/gpu/slice_grad_kernel.cu.cc
+++ b/paddle/phi/kernels/gpu/slice_grad_kernel.cu.cc
@@ -32,3 +32,33 @@ PD_REGISTER_KERNEL(slice_grad,
                    phi::dtype::complex<double>,
                    phi::dtype::bfloat16,
                    phi::dtype::float16) {}
+
+PD_REGISTER_KERNEL(slice_array_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SliceArrayGradKernel,
+                   bool,
+                   int,
+                   uint8_t,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>,
+                   phi::dtype::bfloat16,
+                   phi::dtype::float16) {}
+
+PD_REGISTER_KERNEL(slice_array_dense_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SliceArrayDenseGradKernel,
+                   bool,
+                   int,
+                   uint8_t,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>,
+                   phi::dtype::bfloat16,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/gpu/slice_kernel.cu.cc b/paddle/phi/kernels/gpu/slice_kernel.cu.cc
index 5232ce35811..492dc82998b 100644
--- a/paddle/phi/kernels/gpu/slice_kernel.cu.cc
+++ b/paddle/phi/kernels/gpu/slice_kernel.cu.cc
@@ -32,3 +32,33 @@ PD_REGISTER_KERNEL(slice,
                    phi::dtype::complex<double>,
                    phi::dtype::bfloat16,
                    phi::dtype::float16) {}
+
+PD_REGISTER_KERNEL(slice_array,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SliceArrayKernel,
+                   bool,
+                   int,
+                   uint8_t,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>,
+                   phi::dtype::bfloat16,
+                   phi::dtype::float16) {}
+
+PD_REGISTER_KERNEL(slice_array_dense,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SliceArrayDenseKernel,
+                   bool,
+                   int,
+                   uint8_t,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>,
+                   phi::dtype::bfloat16,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/impl/slice_grad_kernel_impl.h b/paddle/phi/kernels/impl/slice_grad_kernel_impl.h
index 1a6d64ee58a..2fad8d7a59c 100644
--- a/paddle/phi/kernels/impl/slice_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/slice_grad_kernel_impl.h
@@ -14,8 +14,10 @@

 #pragma once

+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/slice_utils.h"
 #include "paddle/phi/kernels/slice_grad_kernel.h"

@@ -350,4 +352,62 @@ void SliceGradRawKernel(const Context& ctx,
   }
 }

+template <typename T, typename Context>
+void SliceArrayGradKernel(const Context& dev_ctx,
+                          const TensorArray& input,
+                          const TensorArray& out_grad,
+                          const IntArray& starts,
+                          const IntArray& ends,
+                          TensorArray* input_grad) {
+  int64_t d_in_size = input.size();
+  input_grad->resize(d_in_size);
+  // If the input is TensorArray, the rank of input is 1.
+  // So only use the 0th element of starts.
+  int64_t start = starts[0] < 0 ? (starts[0] + d_in_size) : starts[0];
+  start = std::max(start, static_cast<int64_t>(0));
+  // set zero
+  phi::funcs::SetConstant<Context, T> functor;
+  for (int i = 0; i < d_in_size; ++i) {
+    const auto& dim = input.at(i).dims();
+    auto* in_grad_tensor = &input_grad->at(i);
+    in_grad_tensor->Resize(dim);
+    dev_ctx.template Alloc<T>(in_grad_tensor);
+    functor(dev_ctx, in_grad_tensor, static_cast<T>(0));
+  }
+
+  int d_out_size = out_grad.size();
+  for (int i = 0; i < d_out_size; ++i) {
+    phi::Copy(dev_ctx,
+              out_grad[i],
+              dev_ctx.GetPlace(),
+              false,
+              &input_grad->at(start + i));
+  }
+}
+
+template <typename T, typename Context>
+void SliceArrayDenseGradKernel(const Context& dev_ctx,
+                               const TensorArray& input,
+                               const DenseTensor& out_grad,
+                               const IntArray& starts,
+                               TensorArray* input_grad) {
+  int64_t d_in_size = input.size();
+  input_grad->resize(d_in_size);
+  // If the input is TensorArray, the rank of input is 1.
+  // So only use the 0th element of starts.
+  int64_t start = starts[0] < 0 ? (starts[0] + d_in_size) : starts[0];
+  start = std::max(start, static_cast<int64_t>(0));
+  // set zero
+  phi::funcs::SetConstant<Context, T> functor;
+  for (int i = 0; i < d_in_size; ++i) {
+    const auto& dim = input.at(i).dims();
+    auto* in_grad_tensor = &input_grad->at(i);
+    in_grad_tensor->Resize(dim);
+    dev_ctx.template Alloc<T>(in_grad_tensor);
+    functor(dev_ctx, in_grad_tensor, static_cast<T>(0));
+  }
+  phi::Copy(
+      dev_ctx, out_grad, dev_ctx.GetPlace(), false, &input_grad->at(start));
+}
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/impl/slice_kernel_impl.h b/paddle/phi/kernels/impl/slice_kernel_impl.h
index b855ef43aa7..78ed41d9f07 100644
--- a/paddle/phi/kernels/impl/slice_kernel_impl.h
+++ b/paddle/phi/kernels/impl/slice_kernel_impl.h
@@ -14,9 +14,13 @@

 #pragma once

+#include <glog/logging.h>
+
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 #include "paddle/phi/kernels/funcs/slice_utils.h"
+#include "paddle/phi/kernels/slice_kernel.h"

 namespace phi {

@@ -151,4 +155,59 @@ void SliceRawKernel(const Context& ctx,
   }
 }

+template <typename T, typename Context>
+void SliceArrayKernel(const Context& dev_ctx,
+                      const TensorArray& input,
+                      const IntArray& starts,
+                      const IntArray& ends,
+                      TensorArray* out) {
+  int64_t in_size = input.size();
+  int64_t start = starts[0] < 0 ? (starts[0] + in_size) : starts[0];
+  int64_t end = ends[0] < 0 ? (ends[0] + in_size) : ends[0];
+
+  start = std::max(start, static_cast<int64_t>(0));
+  end = std::max(end, static_cast<int64_t>(0));
+  end = std::min(end, in_size);
+
+  if (starts[0] == -1 && end == 0) {
+    end = start + 1;
+  }
+
+  PADDLE_ENFORCE_GT(end,
+                    start,
+                    phi::errors::InvalidArgument(
+                        "Attr(ends) should be greater than attr(starts) in "
+                        "slice op. But received end = %d, start = %d.",
+                        ends[0],
+                        starts[0]));
+  int64_t out_size = end - start;
+
+  out->resize(out_size);
+  for (int i = 0; i < out_size; ++i) {
+    auto* out_tensor = &out->at(i);
+    const auto& in_tensor = input.at(i + start);
+    out_tensor->set_lod(in_tensor.lod());
+    if (in_tensor.memory_size() > 0) {
+      phi::Copy(dev_ctx, in_tensor, dev_ctx.GetPlace(), false, out_tensor);
+    } else {
+      VLOG(10) << "WARNING: The input tensor 'x_tensor' holds no memory, so "
+                  "nothing has been written to output array["
+               << i << "].";
+    }
+  }
+}
+
+template <typename T, typename Context>
+void SliceArrayDenseKernel(const Context& dev_ctx,
+                           const TensorArray& input,
+                           const IntArray& starts,
+                           DenseTensor* out) {
+  int64_t in_size = input.size();
+  int64_t start = starts[0] < 0 ? (starts[0] + in_size) : starts[0];
+  start = std::max(start, static_cast<int64_t>(0));
+
+  phi::Copy(dev_ctx, input[start], dev_ctx.GetPlace(), false, out);
+}
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/slice_grad_kernel.h b/paddle/phi/kernels/slice_grad_kernel.h
index a74b432c2b1..5c01631a93d 100644
--- a/paddle/phi/kernels/slice_grad_kernel.h
+++ b/paddle/phi/kernels/slice_grad_kernel.h
@@ -16,6 +16,7 @@

 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/tensor_array.h"

 namespace phi {

@@ -30,4 +31,19 @@ void SliceGradRawKernel(const Context& ctx,
                         const std::vector<int64_t>& decrease_axis,
                         DenseTensor* input_grad);

+template <typename T, typename Context>
+void SliceArrayGradKernel(const Context& dev_ctx,
+                          const TensorArray& input,
+                          const TensorArray& out_grad,
+                          const IntArray& starts,
+                          const IntArray& ends,
+                          TensorArray* input_grad);
+
+template <typename T, typename Context>
+void SliceArrayDenseGradKernel(const Context& dev_ctx,
+                               const TensorArray& input,
+                               const DenseTensor& out_grad,
+                               const IntArray& starts,
+                               TensorArray* input_grad);
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/slice_kernel.h b/paddle/phi/kernels/slice_kernel.h
index e01ff3d74fb..160fde880d7 100644
--- a/paddle/phi/kernels/slice_kernel.h
+++ b/paddle/phi/kernels/slice_kernel.h
@@ -16,6 +16,7 @@

 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/tensor_array.h"
 #include "paddle/phi/infermeta/unary.h"

 namespace phi {

@@ -30,6 +31,19 @@ void SliceRawKernel(const Context& ctx,
                     const std::vector<int64_t>& decrease_axis,
                     DenseTensor* out);

+template <typename T, typename Context>
+void SliceArrayKernel(const Context& dev_ctx,
+                      const TensorArray& input,
+                      const IntArray& starts,
+                      const IntArray& ends,
+                      TensorArray* out);
+
+template <typename T, typename Context>
+void SliceArrayDenseKernel(const Context& dev_ctx,
+                           const TensorArray& input,
+                           const IntArray& starts,
+                           DenseTensor* out);
+
 template <typename T, typename Context>
 DenseTensor SliceKernel(const Context& ctx,
                         const DenseTensor& input,
diff --git a/paddle/phi/ops/compat/slice_sig.cc b/paddle/phi/ops/compat/slice_sig.cc
index 607d0b31310..beb5e4c959a 100644
--- a/paddle/phi/ops/compat/slice_sig.cc
+++ b/paddle/phi/ops/compat/slice_sig.cc
@@ -19,7 +19,27 @@ namespace phi {
 KernelSignature SliceOpArgumentMapping(const ArgumentMappingContext& ctx) {
   // if input is Tensor Array
   if (ctx.IsDenseTensorVectorInput("Input")) {
-    return KernelSignature("unregistered", {}, {}, {});
+    const char* starts_name = "starts";
+    if (ctx.HasInput("StartsTensor")) {
+      starts_name = "StartsTensor";
+    } else if (ctx.InputSize("StartsTensorList") > 0) {
+      starts_name = "StartsTensorList";
+    }
+    const char* ends_name = "ends";
+    if (ctx.HasInput("EndsTensor")) {
+      ends_name = "EndsTensor";
+    } else if (ctx.InputSize("EndsTensorList") > 0) {
+      ends_name = "EndsTensorList";
+    }
+
+    if (paddle::any_cast<std::vector<int>>(ctx.Attr("decrease_axis")).size() >
+        0) {
+      return KernelSignature(
+          "slice_array_dense", {"Input"}, {starts_name}, {"Out"});
+    } else {
+      return KernelSignature(
+          "slice_array", {"Input"}, {starts_name, ends_name}, {"Out"});
+    }
   }

   if (ctx.HasInput("StartsTensor")) {
@@ -99,7 +119,30 @@ KernelSignature SliceOpArgumentMapping(const ArgumentMappingContext& ctx) {

 KernelSignature SliceGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
   if (ctx.IsDenseTensorVectorInput("Input")) {
-    return KernelSignature("unregistered", {}, {}, {});
+    const char* starts_name = "starts";
+    if (ctx.HasInput("StartsTensor")) {
+      starts_name = "StartsTensor";
+    } else if (ctx.InputSize("StartsTensorList") > 0) {
+      starts_name = "StartsTensorList";
"StartsTensorList"; + } + const char* ends_name = "ends"; + if (ctx.HasInput("EndsTensor")) { + ends_name = "EndsTensor"; + } else if (ctx.InputSize("EndsTensorList") > 0) { + ends_name = "EndsTensorList"; + } + if (paddle::any_cast>(ctx.Attr("decrease_axis")).size() > + 0) { + return KernelSignature("slice_array_dense_grad", + {"Input", "Out@GRAD"}, + {starts_name}, + {"Input@GRAD"}); + } else { + return KernelSignature("slice_array_grad", + {"Input", "Out@GRAD"}, + {starts_name, ends_name}, + {"Input@GRAD"}); + } } if (ctx.HasInput("StartsTensor")) { -- GitLab