diff --git a/paddle/fluid/operators/strided_slice_op.cc b/paddle/fluid/operators/strided_slice_op.cc index bea0319744ca6eb87efb3350ce2ce0d13bef333c..0ff7d654fc29d1e739147a5fc37fe76c9fcf5e71 100644 --- a/paddle/fluid/operators/strided_slice_op.cc +++ b/paddle/fluid/operators/strided_slice_op.cc @@ -12,12 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/strided_slice_op.h" #include #include #include #include + +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/slice_op.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" +#include "paddle/phi/kernels/funcs/strided_slice.h" namespace paddle { namespace operators { @@ -28,149 +33,6 @@ class StridedSliceOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "StridedSlice"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "StridedSlice"); - auto input_var_type = ctx->GetInputsVarType("Input")[0]; - if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) { - if (ctx->IsRuntime()) { - // shape is determined by Runtime. - return; - } - } - auto in_dims = ctx->GetInputDim("Input"); - PADDLE_ENFORCE_LT( - in_dims.size(), 7, - platform::errors::InvalidArgument( - "The dimension of StridedSlice operator's input should be less " - "than 7, but received dimension is %d.", - in_dims.size())); - - auto starts_int = ctx->Attrs().Get>("starts"); - auto ends_int = ctx->Attrs().Get>("ends"); - auto strides_int = ctx->Attrs().Get>("strides"); - - std::vector starts(starts_int.begin(), starts_int.end()); - std::vector ends(ends_int.begin(), ends_int.end()); - std::vector strides(strides_int.begin(), strides_int.end()); - - auto axes = ctx->Attrs().Get>("axes"); - auto infer_flags = ctx->Attrs().Get>("infer_flags"); - auto decrease_axis = ctx->Attrs().Get>("decrease_axis"); - - auto starts_size = starts.size(); - auto ends_size = ends.size(); - auto strides_size = strides.size(); - - for (size_t i = 0; i < axes.size(); ++i) { - PADDLE_ENFORCE_GE(axes[i], 0, - platform::errors::InvalidArgument( - "The axis should be greater than or equal to 0." - "But received %d of axes[%d]", - axes[i], i)); - PADDLE_ENFORCE_LT( - axes[i], in_dims.size(), - platform::errors::InvalidArgument( - "The axes should be less than or equal to input tensor's rank." - "But received %d of axes[%d], input tensor shape [%d]", - axes[i], i, in_dims.size())); - } - - if (ctx->HasInputs("StartsTensorList")) { - auto StartsTensorList = ctx->Inputs("StartsTensorList"); - PADDLE_ENFORCE_GT( - StartsTensorList.size(), 0, - platform::errors::InvalidArgument( - "StridedSlice operator's StartsTensorList is empty.")); - starts_size = StartsTensorList.size(); - } - if (ctx->HasInputs("EndsTensorList")) { - auto EndsTensorList = ctx->Inputs("EndsTensorList"); - PADDLE_ENFORCE_GT( - EndsTensorList.size(), 0, - platform::errors::InvalidArgument( - "StridedSlice operator's EndsTensorList is empty.")); - ends_size = EndsTensorList.size(); - } - if (ctx->HasInputs("StridesTensorList")) { - auto StridesTensorList = ctx->Inputs("StridesTensorList"); - PADDLE_ENFORCE_GT( - StridesTensorList.size(), 0, - platform::errors::InvalidArgument( - "StridedSlice operator's StridesTensorList is empty.")); - strides_size = StridesTensorList.size(); - } - - auto tensor_input = false; - if (ctx->HasInput("EndsTensor") || ctx->HasInput("StartsTensor") || - ctx->HasInput("StridesTensor")) { - tensor_input = true; - } - if (!ctx->HasInput("EndsTensor")) { - PADDLE_ENFORCE_EQ( - ends_size, axes.size(), - platform::errors::InvalidArgument( - "The size of ends attribute in StridedSlice operator is not " - "equal to the size of axes attribute. The ends attribute's size " - "is %d, axes attribute's size is %d.", - ends_size, axes.size())); - } - if (!ctx->HasInput("StartsTensor")) { - PADDLE_ENFORCE_EQ( - starts_size, axes.size(), - platform::errors::InvalidArgument( - "The size of starts attribute in StridedSlice operator is not " - "equal to the size of axes attribute. The starts attribute's " - "size is %d, axes attribute's size is %d.", - starts_size, axes.size())); - } - if (!ctx->HasInput("StridesTensor")) { - PADDLE_ENFORCE_EQ( - strides_size, axes.size(), - platform::errors::InvalidArgument( - "The size of strides attribute in StridedSlice operator is not " - "equal to the size of axes attribute. The strides attribute's " - "size is %d, axes attribute's size is %d.", - strides_size, axes.size())); - } - // we need to analysis strided slice op is valid for - // the parameter that we get from python front - std::vector out_dims_vector(in_dims.size(), -1); - if (!tensor_input) { - StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims, - decrease_axis, out_dims_vector.data(), axes.size(), - true); - } - framework::DDim out_dims(phi::make_ddim(out_dims_vector)); - // generate new shape - if (decrease_axis.size() > 0) { - std::vector new_out_shape; - for (size_t i = 0; i < decrease_axis.size(); ++i) { - if (ctx->IsRuntime() && infer_flags[i] != -1) { - PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1, - platform::errors::InvalidArgument( - "the size of decrease dimension should be 1, " - "but received %d.", - out_dims[decrease_axis[i]])); - } - out_dims[decrease_axis[i]] = 0; - } - - for (int i = 0; i < out_dims.size(); ++i) { - if (out_dims[i] != 0) { - new_out_shape.push_back(out_dims[i]); - } - } - if (new_out_shape.size() == 0) { - new_out_shape.push_back(1); - } - - out_dims = phi::make_ddim(new_out_shape); - } - ctx->SetOutputDim("Out", out_dims); - ctx->ShareLoD("Input", /*->*/ "Out"); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -304,26 +166,6 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", - "StridedSliceGrad"); - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", - "Out@GRAD", "StridedSliceGrad"); - - auto input_var_type = ctx->GetInputsVarType("Input")[0]; - if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) { - if (ctx->IsRuntime()) { - // shape is determined by Runtime - return; - } - } - auto x_dims = ctx->GetInputDim("Input"); - auto x_grad_name = framework::GradVarName("Input"); - if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, x_dims); - } - } - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( @@ -384,35 +226,19 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer, } // namespace paddle namespace ops = paddle::operators; + +DECLARE_INFER_SHAPE_FUNCTOR(strided_slice, StridedSliceInferShape, + PD_INFER_META(phi::StridedSliceInferMeta)); + REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker, ops::StridedSliceOpGradMaker, ops::StridedSliceOpGradMaker, - ops::StridedSliceOpVarTypeInference); + ops::StridedSliceOpVarTypeInference, StridedSliceInferShape); + +DECLARE_INFER_SHAPE_FUNCTOR(strided_slice_grad, StridedSliceGradInferShape, + PD_INFER_META(phi::GeneralUnaryGradInferMeta)); REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad, ops::StridedSliceOpGradNoNeedBufferVarsInferer, - ops::StridedSliceGradOpVarTypeInference); - -REGISTER_OP_CPU_KERNEL( - strided_slice, - ops::StridedSliceKernel, - ops::StridedSliceKernel, - ops::StridedSliceKernel, - ops::StridedSliceKernel, - ops::StridedSliceKernel, - ops::StridedSliceKernel>, - ops::StridedSliceKernel>); - -REGISTER_OP_CPU_KERNEL( - strided_slice_grad, - ops::StridedSliceGradKernel, - ops::StridedSliceGradKernel, - ops::StridedSliceGradKernel, - ops::StridedSliceGradKernel, - ops::StridedSliceGradKernel, - ops::StridedSliceGradKernel>, - ops::StridedSliceGradKernel>); + ops::StridedSliceGradOpVarTypeInference, + StridedSliceGradInferShape); diff --git a/paddle/fluid/operators/strided_slice_op.cu b/paddle/fluid/operators/strided_slice_op.cu deleted file mode 100644 index f88605fbfc86dc30b16b4c0115eff2f6e9bbdc3b..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/strided_slice_op.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/strided_slice_op.h" -#include "paddle/fluid/platform/complex.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - strided_slice, - ops::StridedSliceKernel, - ops::StridedSliceKernel, - ops::StridedSliceKernel, - ops::StridedSliceKernel, - ops::StridedSliceKernel, - ops::StridedSliceKernel>, - ops::StridedSliceKernel>); - -REGISTER_OP_CUDA_KERNEL( - strided_slice_grad, - ops::StridedSliceGradKernel, - ops::StridedSliceGradKernel, - ops::StridedSliceGradKernel, - ops::StridedSliceGradKernel, - ops::StridedSliceGradKernel, - ops::StridedSliceGradKernel>, - ops::StridedSliceGradKernel>); diff --git a/paddle/fluid/operators/strided_slice_op.h b/paddle/fluid/operators/strided_slice_op.h deleted file mode 100644 index f28585821edeb7fe21766cd7a96dceae4db73230..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/strided_slice_op.h +++ /dev/null @@ -1,659 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/slice_op.h" -#include "paddle/phi/kernels/funcs/math_function.h" -namespace paddle { -namespace operators { - -static void StridedSliceOutDims( - const std::vector& starts, const std::vector& ends, - const std::vector& strides, const std::vector& axes, - const std::vector& infer_flags, const framework::DDim in_dims, - const std::vector& decrease_axis, int64_t* out_dims_vector, - const size_t size, bool infer_shape) { - for (int i = 0; i < in_dims.size(); i++) { - out_dims_vector[i] = in_dims[i]; - } - int64_t stride_index, start_index, end_index; - for (size_t i = 0; i < size; i++) { - int axes_index = axes[i]; - start_index = starts[i]; - end_index = ends[i]; - stride_index = strides[i]; - bool decrease_axis_affect = false; - if (start_index == -1 && end_index == 0 && infer_flags[i] == -1) { - auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]); - if (ret != decrease_axis.end()) { - decrease_axis_affect = true; - } - } - if (decrease_axis_affect) { - out_dims_vector[axes_index] = 1; - continue; - } - if (infer_shape && infer_flags[i] == -1) { - out_dims_vector[axes_index] = -1; - continue; - } - - PADDLE_ENFORCE_NE(stride_index, 0, - platform::errors::InvalidArgument( - "stride index in StridedSlice operator is 0.")); - int64_t axis_size = in_dims[axes_index]; - - if (axis_size < 0) { - continue; - } - - if (start_index < 0) { - start_index = start_index + axis_size; - } - if (end_index < 0) { - if (!(end_index == -1 && stride_index < 0)) { // skip None stop condition - end_index = end_index + axis_size; - } - } - - if (stride_index < 0) { - start_index = start_index + 1; - end_index = end_index + 1; - } - - bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) || - (stride_index > 0 && (start_index > end_index))); - PADDLE_ENFORCE_EQ(neg_dim_condition, false, - platform::errors::InvalidArgument( - "The start index and end index are invalid for their " - "corresponding stride.")); - - int64_t left = - std::max(static_cast(0), std::min(start_index, end_index)); - int64_t right = std::min(axis_size, std::max(start_index, end_index)); - int64_t step = std::abs(stride_index); - - auto out_dims_index = (std::abs(right - left) + step - 1) / step; - - out_dims_vector[axes_index] = out_dims_index; - } -} - -static void StridedSliceFunctor(int64_t* starts, int64_t* ends, - int64_t* strides, int* axes, int* reverse_axis, - const framework::DDim dims, - const std::vector& infer_flags, - const std::vector& decrease_axis, - const size_t size) { - for (size_t axis = 0; axis < size; axis++) { - int64_t axis_size = dims[axes[axis]]; - int axis_index = axis; - if (axis_size < 0) { - starts[axis_index] = 0; - ends[axis_index] = 1; - strides[axis_index] = 1; - } - bool decrease_axis_affect = false; - if (starts[axis_index] == -1 && ends[axis_index] == 0 && - infer_flags[axis_index] == -1) { - auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), - axes[axis_index]); - if (ret != decrease_axis.end()) { - decrease_axis_affect = true; - } - } - // stride must not be zero - if (starts[axis_index] < 0) { - starts[axis_index] = starts[axis_index] + axis_size; - starts[axis_index] = std::max(starts[axis_index], 0); - } - if (ends[axis_index] < 0) { - if (!(ends[axis_index] == -1 && - strides[axis_index] < 0)) { // skip None stop condition - ends[axis_index] = ends[axis_index] + axis_size; - if (ends[axis_index] < 0) { - ends[axis_index] = 0; - } - } - } - if (decrease_axis_affect) { - if (strides[axis_index] < 0) { - ends[axis_index] = starts[axis_index] - 1; - } else { - ends[axis_index] = starts[axis_index] + 1; - } - } - - if (strides[axis_index] < 0) { - reverse_axis[axis_index] = 1; - strides[axis_index] = -strides[axis_index]; - if (starts[axis_index] > ends[axis_index]) { - // swap the reverse - auto end_dim = axis_size - 1 < starts[axis_index] ? axis_size - 1 - : starts[axis_index]; - auto offset = (end_dim - ends[axis_index]) % strides[axis_index]; - offset = offset == 0 ? strides[axis_index] : offset; - - starts[axis_index] = starts[axis_index] + offset; - ends[axis_index] = ends[axis_index] + offset; - } - std::swap(starts[axis_index], ends[axis_index]); - } else { - reverse_axis[axis_index] = 0; - strides[axis_index] = strides[axis_index]; - } - } -} - -template -class StridedSliceKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const Variable* input_var = ctx.InputVar("Input"); - bool is_tensor_array = input_var->IsType(); - int rank = is_tensor_array - ? 1 - : ctx.Input("Input")->dims().size(); - switch (rank) { - case 1: - StridedSliceCompute<1>(ctx); - break; - case 2: - StridedSliceCompute<2>(ctx); - break; - case 3: - StridedSliceCompute<3>(ctx); - break; - case 4: - StridedSliceCompute<4>(ctx); - break; - case 5: - StridedSliceCompute<5>(ctx); - break; - case 6: - StridedSliceCompute<6>(ctx); - break; - } - } - - private: - template - void StridedSliceCompute(const framework::ExecutionContext& context) const { - auto& place = - *context.template device_context().eigen_device(); - - framework::DDim in_dims; - auto* input_var = context.InputVar("Input"); - - bool is_input_var_array = input_var->IsType(); - if (is_input_var_array) { - const int64_t size = input_var->Get().size(); - in_dims = phi::make_ddim({size}); - } else { - in_dims = context.Input("Input")->dims(); - } - - auto starts_int = context.Attr>("starts"); - auto ends_int = context.Attr>("ends"); - auto strides_int = context.Attr>("strides"); - - std::vector starts(starts_int.begin(), starts_int.end()); - std::vector ends(ends_int.begin(), ends_int.end()); - std::vector strides(strides_int.begin(), strides_int.end()); - - auto axes = context.Attr>("axes"); - auto infer_flags = context.Attr>("infer_flags"); - auto decrease_axis = context.Attr>("decrease_axis"); - - auto starts_indices = Eigen::DSizes(); - auto ends_indices = Eigen::DSizes(); - auto strides_indices = Eigen::DSizes(); - auto reverse_axis = Eigen::array(); - - auto list_new_ends_tensor = - context.MultiInput("EndsTensorList"); - auto list_new_starts_tensor = - context.MultiInput("StartsTensorList"); - auto list_new_strides_tensor = - context.MultiInput("StridesTensorList"); - - if (list_new_starts_tensor.size() > 0) { - starts = GetDataFromTensorList(list_new_starts_tensor); - } else if (context.HasInput("StartsTensor")) { - auto* starts_tensor = context.Input("StartsTensor"); - starts = GetDataFromTensor(starts_tensor); - } - - if (list_new_ends_tensor.size() > 0) { - ends = GetDataFromTensorList(list_new_ends_tensor); - } else if (context.HasInput("EndsTensor")) { - auto* ends_tensor = context.Input("EndsTensor"); - ends = GetDataFromTensor(ends_tensor); - } - - if (list_new_strides_tensor.size() > 0) { - strides = GetDataFromTensorList(list_new_strides_tensor); - } else if (context.HasInput("StridesTensor")) { - auto* strides_tensor = context.Input("StridesTensor"); - strides = GetDataFromTensor(strides_tensor); - } - - std::vector out_dims_vector(in_dims.size(), -1); - StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims, - decrease_axis, out_dims_vector.data(), axes.size(), - false); - framework::DDim out_dims(phi::make_ddim(out_dims_vector)); - - std::vector reverse_vector(starts.size(), 0); - StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), - reverse_vector.data(), in_dims, infer_flags, - decrease_axis, starts.size()); - - for (size_t axis = 0; axis < D; axis++) { - starts_indices[axis] = 0; - ends_indices[axis] = out_dims[axis]; - strides_indices[axis] = 1; - reverse_axis[axis] = false; - } - for (size_t axis = 0; axis < axes.size(); axis++) { - int axis_index = axes[axis]; - starts_indices[axis_index] = starts[axis]; - ends_indices[axis_index] = ends[axis]; - strides_indices[axis_index] = strides[axis]; - reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false; - } - - auto out_dims_origin = out_dims; - if (decrease_axis.size() > 0) { - std::vector new_out_shape; - for (size_t i = 0; i < decrease_axis.size(); ++i) { - PADDLE_ENFORCE_EQ( - out_dims[decrease_axis[i]], 1, - platform::errors::InvalidArgument( - "the size of decrease dimension should be 1, but received %d.", - out_dims[decrease_axis[i]])); - out_dims_origin[decrease_axis[i]] = 0; - } - - for (int i = 0; i < out_dims_origin.size(); ++i) { - if (out_dims_origin[i] != 0) { - new_out_shape.push_back(out_dims_origin[i]); - } - } - if (new_out_shape.size() == 0) { - new_out_shape.push_back(1); - } - out_dims_origin = phi::make_ddim(new_out_shape); - } - - bool need_reverse = false; - for (size_t axis = 0; axis < axes.size(); axis++) { - if (reverse_vector[axis] == 1) { - need_reverse = true; - break; - } - } - - if (is_input_var_array) { - PADDLE_ENFORCE_EQ( - starts_indices.size(), 1, - platform::errors::InvalidArgument( - "When the input of 'strided_slice_op' is `TensorArray`, the " - "dimension of start index should be 1, but received %d.", - starts_indices.size())); - - PADDLE_ENFORCE_EQ( - ends_indices.size(), 1, - platform::errors::InvalidArgument( - "When the input of 'strided_slice_op' is `TensorArray`, the " - "dimension of end index should be 1, but received %d.", - ends_indices.size())); - - PADDLE_ENFORCE_EQ( - strides_indices.size(), 1, - platform::errors::InvalidArgument( - "When the input of 'strided_slice_op' is `TensorArray`, the " - "dimension of stride should be 1, but received %d.", - strides_indices.size())); - - auto* output_var = context.OutputVar("Out"); - - PADDLE_ENFORCE_EQ( - output_var->IsType(), true, - platform::errors::InvalidArgument( - "When the input of `strided_slice_op` is `TensorArray`. The " - "output is excepted `TensorArray` , but received %s.", - framework::ToTypeName(output_var->Type()))); - - PADDLE_ENFORCE_EQ( - out_dims_origin.size(), 1, - platform::errors::InvalidArgument( - "When the input of 'strided_slice_op' is `TensorArray`, the " - "dimension of Output should be 1, but received %d", - out_dims_origin.size())); - - auto& in_array = input_var->Get(); - - auto* out_array = context.Output("Out"); - - out_array->resize(out_dims_origin[0]); - size_t const in_array_size = in_array.size(); - for (size_t i = 0; i < out_array->size(); i++) { - size_t in_offset = - (starts_indices[0] % in_array_size) + i * strides_indices[0]; - - int64_t out_offset = i; - if (need_reverse) { - out_offset = out_array->size() - i - 1; - } - - auto& in_tensor = in_array.at(in_offset); - PADDLE_ENFORCE_GT( - in_tensor.memory_size(), 0, - platform::errors::PreconditionNotMet( - "The input LoDTensorArray Input[%d] holds no memory.", - in_offset)); - auto* out_tensor = &out_array->at(out_offset); - - out_tensor->set_lod(in_tensor.lod()); - paddle::framework::TensorCopy(in_tensor, context.GetPlace(), - out_tensor); - } - - } else { - auto in = context.Input("Input"); - auto out = context.Output("Out"); - out->Resize(out_dims); - out->mutable_data(context.GetPlace()); - auto in_t = framework::EigenTensor::From(*in); - auto out_t = - framework::EigenTensor::From(*out, out_dims); - if (need_reverse) { - framework::Tensor tmp; - tmp.mutable_data(out_dims, context.GetPlace()); - auto tmp_t = framework::EigenTensor::From(tmp); - tmp_t.device(place) = - in_t.stridedSlice(starts_indices, ends_indices, strides_indices); - out_t.device(place) = tmp_t.reverse(reverse_axis); - } else { - out_t.device(place) = - in_t.stridedSlice(starts_indices, ends_indices, strides_indices); - } - - if (decrease_axis.size() > 0) { - out->Resize(out_dims_origin); - } - } - } -}; - -template -class StridedSliceGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const Variable* input_var = ctx.InputVar("Input"); - bool is_tensor_array = input_var->IsType(); - int rank = is_tensor_array - ? 1 - : ctx.Input("Input")->dims().size(); - switch (rank) { - case 1: - StridedSliceGradCompute<1>(ctx); - break; - case 2: - StridedSliceGradCompute<2>(ctx); - break; - case 3: - StridedSliceGradCompute<3>(ctx); - break; - case 4: - StridedSliceGradCompute<4>(ctx); - break; - case 5: - StridedSliceGradCompute<5>(ctx); - break; - case 6: - StridedSliceGradCompute<6>(ctx); - break; - } - } - - private: - template - void StridedSliceGradCompute( - const framework::ExecutionContext& context) const { - auto& place = - *context.template device_context().eigen_device(); - - auto& dev_ctx = context.template device_context(); - - framework::DDim out_dims; - auto* out_var = context.OutputVar(framework::GradVarName("Input")); - bool is_out_var_array = out_var->IsType(); - if (is_out_var_array) { - // Note(weixin):Since the shape of `framework::GradVarName("Input")` of - // StridedSliceGrad cannot be calculated by - // `framework::GradVarName("Output")`, the dim of "Input" is used to - // calculate the output shape. when set it to inplace OP, there may be - // some problems. - const int64_t size = - context.Input("Input")->size(); - - out_dims = phi::make_ddim({size}); - } else { - out_dims = - context.Output(framework::GradVarName("Input")) - ->dims(); - } - - auto starts_int = context.Attr>("starts"); - auto ends_int = context.Attr>("ends"); - auto strides_int = context.Attr>("strides"); - - std::vector starts(starts_int.begin(), starts_int.end()); - std::vector ends(ends_int.begin(), ends_int.end()); - std::vector strides(strides_int.begin(), strides_int.end()); - - auto axes = context.Attr>("axes"); - auto infer_flags = context.Attr>("infer_flags"); - auto decrease_axis = context.Attr>("decrease_axis"); - - auto list_new_ends_tensor = - context.MultiInput("EndsTensorList"); - auto list_new_starts_tensor = - context.MultiInput("StartsTensorList"); - auto list_new_strides_tensor = - context.MultiInput("StridesTensorList"); - - if (list_new_starts_tensor.size() > 0) { - starts = GetDataFromTensorList(list_new_starts_tensor); - } else if (context.HasInput("StartsTensor")) { - auto* starts_tensor = context.Input("StartsTensor"); - starts = GetDataFromTensor(starts_tensor); - } - - if (list_new_ends_tensor.size() > 0) { - ends = GetDataFromTensorList(list_new_ends_tensor); - } else if (context.HasInput("EndsTensor")) { - auto* ends_tensor = context.Input("EndsTensor"); - ends = GetDataFromTensor(ends_tensor); - } - - if (list_new_strides_tensor.size() > 0) { - strides = GetDataFromTensorList(list_new_strides_tensor); - } else if (context.HasInput("StridesTensor")) { - auto* strides_tensor = context.Input("StridesTensor"); - strides = GetDataFromTensor(strides_tensor); - } - - auto starts_indices = Eigen::DSizes(); - auto ends_indices = Eigen::DSizes(); - auto strides_indices = Eigen::DSizes(); - - auto reverse_axis = Eigen::array(); - std::vector reverse_vector(starts.size(), 0); - - StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), - reverse_vector.data(), out_dims, infer_flags, - decrease_axis, starts.size()); - - for (size_t axis = 0; axis < D; axis++) { - starts_indices[axis] = 0; - ends_indices[axis] = out_dims[axis]; - strides_indices[axis] = 1; - } - for (size_t axis = 0; axis < axes.size(); axis++) { - int axis_index = axes[axis]; - starts_indices[axis_index] = starts[axis]; - ends_indices[axis_index] = ends[axis]; - strides_indices[axis_index] = strides[axis]; - reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false; - } - - bool need_reverse = false; - for (size_t axis = 0; axis < axes.size(); axis++) { - if (reverse_vector[axis] == 1) { - need_reverse = true; - break; - } - } - - if (is_out_var_array) { - PADDLE_ENFORCE_EQ( - starts_indices.size(), 1, - platform::errors::InvalidArgument( - "When the input of 'strided_slice_grad_op' is `TensorArray`, the " - "dimension of start index should be 1, but received %d.", - starts_indices.size())); - PADDLE_ENFORCE_EQ( - ends_indices.size(), 1, - platform::errors::InvalidArgument( - "When the input of 'strided_slice_op' is `TensorArray`, the " - "dimension of end index should be 1, but received %d.", - ends_indices.size())); - PADDLE_ENFORCE_EQ( - strides_indices.size(), 1, - platform::errors::InvalidArgument( - "When the input of 'strided_slice_grad_op' is `TensorArray`, the " - "dimension of stride should be 1, but received %d.", - strides_indices.size())); - - auto* d_input_var = context.InputVar(framework::GradVarName("Out")); - - PADDLE_ENFORCE_EQ( - d_input_var->IsType(), true, - platform::errors::InvalidArgument( - "When the output of `strided_slice_grad_op` is " - "`TensorArray`, the input is excepted `TensorArray` , " - "but received %s.", - framework::ToTypeName(d_input_var->Type()))); - - PADDLE_ENFORCE_EQ( - out_dims.size(), 1, - platform::errors::InvalidArgument( - "When the output of `strided_slice_grad_op` is `TensorArray`, " - "the dimension of output should be 1, but received %d.", - out_dims.size())); - auto& d_in_array = d_input_var->Get(); - - auto* d_out_array = context.Output( - framework::GradVarName("Input")); - - d_out_array->resize(out_dims[0]); - auto const d_out_array_size = d_out_array->size(); - auto* input_tensor_array = - context.Input("Input"); - - for (size_t j = 0; j < d_out_array_size; j++) { - auto& dim = input_tensor_array->at(j).dims(); - auto* d_out_tensor = &d_out_array->at(j); - - int64_t sub = j - starts_indices[0]; - - int64_t in_offset = sub / strides_indices[0]; - - if (need_reverse) { - in_offset = d_in_array.size() - in_offset - 1; - } - - if ((sub % strides_indices[0] == 0) && (0 <= in_offset) && - (static_cast(in_offset) < d_in_array.size())) { - auto& in_tensor = d_in_array.at(in_offset); - PADDLE_ENFORCE_GT( - in_tensor.memory_size(), 0, - platform::errors::PreconditionNotMet( - "The input LoDTensorArray Input[%d] holds no memory.", - in_offset)); - - d_out_tensor->set_lod(in_tensor.lod()); - paddle::framework::TensorCopy(in_tensor, context.GetPlace(), - d_out_tensor); - - } else { - d_out_tensor->Resize(dim); - - if (!d_out_tensor->IsInitialized()) { - d_out_tensor->mutable_data(context.GetPlace()); - } - - phi::funcs::SetConstant set_zero; - set_zero(dev_ctx, d_out_tensor, static_cast(0)); - } - } - - } else { - auto* d_input = - context.Input(framework::GradVarName("Out")); - auto* d_out = - context.Output(framework::GradVarName("Input")); - - d_out->mutable_data(context.GetPlace()); - - phi::funcs::SetConstant set_zero; - set_zero(dev_ctx, d_out, static_cast(0)); - - auto in_dims = d_input->dims(); - - auto in_t = framework::EigenTensor::From(*d_input); - auto out_t = - framework::EigenTensor::From(*d_out, out_dims); - if (need_reverse) { - framework::Tensor reverse_input; - reverse_input.mutable_data(in_dims, context.GetPlace()); - auto reverse_in_t = - framework::EigenTensor::From(reverse_input); - - reverse_in_t.device(place) = in_t.reverse(reverse_axis); - out_t.stridedSlice(starts_indices, ends_indices, strides_indices) - .device(place) = reverse_in_t; - } else { - out_t.stridedSlice(starts_indices, ends_indices, strides_indices) - .device(place) = in_t; - } - } - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/strided_slice_op_npu.cc b/paddle/fluid/operators/strided_slice_op_npu.cc index 1413975bd8333e1fb22109d7318e71411e2dc670..b142b8f099b8956416467a1acbbd7a51452f8348 100644 --- a/paddle/fluid/operators/strided_slice_op_npu.cc +++ b/paddle/fluid/operators/strided_slice_op_npu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/strided_slice_op.h" +#include "paddle/phi/kernels/funcs/strided_slice.h" #include "paddle/fluid/operators/slice_op.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" @@ -112,16 +112,16 @@ class StridedSliceNPUKernel : public framework::OpKernel { // out dims calculation std::vector out_dims_vector(in_dims.size(), -1); - StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims, - decrease_axis, out_dims_vector.data(), axes.size(), - false); + phi::funcs::StridedSliceOutDims(starts, ends, strides, axes, infer_flags, + in_dims, decrease_axis, + out_dims_vector.data(), axes.size(), false); framework::DDim out_dims(phi::make_ddim(out_dims_vector)); // check whether need to reverse (false: stride > 0; true: stride < 0) std::vector reverse_vector(starts.size(), 0); - StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), - reverse_vector.data(), in_dims, infer_flags, - decrease_axis, starts.size()); + phi::funcs::StridedSliceFunctor(starts.data(), ends.data(), strides.data(), + axes.data(), reverse_vector.data(), in_dims, + infer_flags, decrease_axis, starts.size()); // construct the starts_indices, ends_indices and strides_indices tensor for // calling StridedSlice op @@ -317,14 +317,15 @@ class StridedSliceGradNPUKernel : public framework::OpKernel { } std::vector out_dims_vector(input_dims.size(), -1); - StridedSliceOutDims(starts, ends, strides, axes, infer_flags, input_dims, - decrease_axis, out_dims_vector.data(), axes.size(), - false); + phi::funcs::StridedSliceOutDims(starts, ends, strides, axes, infer_flags, + input_dims, decrease_axis, + out_dims_vector.data(), axes.size(), false); std::vector reverse_vector(starts.size(), 0); - StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), - reverse_vector.data(), input_dims, infer_flags, - decrease_axis, starts.size()); + phi::funcs::StridedSliceFunctor(starts.data(), ends.data(), strides.data(), + axes.data(), reverse_vector.data(), + input_dims, infer_flags, decrease_axis, + starts.size()); std::vector starts_indices_vector(D, 0); std::vector ends_indices_vector(out_dims_vector.begin(), diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 086c4d1ee7b48dbda2db088744744fc208ffc676..09fdc321f7081f25020e2d3d65ed60b15bbc490e 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/kernels/funcs/parse_qr_mode.h" #include "paddle/phi/kernels/funcs/pooling.h" +#include "paddle/phi/kernels/funcs/strided_slice.h" #include "paddle/phi/kernels/funcs/unfold_functor.h" #include "paddle/phi/kernels/funcs/unsqueeze.h" @@ -1708,6 +1709,136 @@ void SqueezeInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void StridedSliceInferMeta(const MetaTensor& x, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + MetaTensor* out, + MetaConfig config) { + auto in_dims = x.dims(); + PADDLE_ENFORCE_LT( + in_dims.size(), + 7, + errors::InvalidArgument( + "The dimension of StridedSlice operator's input should be less " + "than 7, but received dimension is %d.", + in_dims.size())); + + auto starts_ = starts.GetData(); + auto ends_ = ends.GetData(); + auto strides_ = strides.GetData(); + + auto starts_size = starts_.size(); + auto ends_size = ends_.size(); + auto strides_size = strides_.size(); + + for (size_t i = 0; i < axes.size(); ++i) { + PADDLE_ENFORCE_GE( + axes[i], + 0, + errors::InvalidArgument("The axis should be greater than or equal to 0." + "But received %d of axes[%d]", + axes[i], + i)); + PADDLE_ENFORCE_LT( + axes[i], + in_dims.size(), + errors::InvalidArgument( + "The axes should be less than or equal to input tensor's rank." + "But received %d of axes[%d], input tensor shape [%d]", + axes[i], + i, + in_dims.size())); + } + + auto tensor_input = false; + auto HasInput = [](const ScalarArray& arr) { return arr.FromTensor(); }; + if (HasInput(starts) || HasInput(ends) || HasInput(strides)) { + tensor_input = true; + } + if (!HasInput(ends)) { + PADDLE_ENFORCE_EQ( + ends_size, + axes.size(), + errors::InvalidArgument( + "The size of ends attribute in StridedSlice operator is not " + "equal to the size of axes attribute. The ends attribute's size " + "is %d, axes attribute's size is %d.", + ends_size, + axes.size())); + } + if (!HasInput(starts)) { + PADDLE_ENFORCE_EQ( + starts_size, + axes.size(), + errors::InvalidArgument( + "The size of starts attribute in StridedSlice operator is not " + "equal to the size of axes attribute. The starts attribute's " + "size is %d, axes attribute's size is %d.", + starts_size, + axes.size())); + } + if (!HasInput(strides)) { + PADDLE_ENFORCE_EQ( + strides_size, + axes.size(), + errors::InvalidArgument( + "The size of strides attribute in StridedSlice operator is not " + "equal to the size of axes attribute. The strides attribute's " + "size is %d, axes attribute's size is %d.", + strides_size, + axes.size())); + } + // we need to analysis strided slice op is valid for + // the parameter that we get from python front + std::vector out_dims_vector(in_dims.size(), -1); + if (!tensor_input || config.is_runtime) { + phi::funcs::StridedSliceOutDims(starts_, + ends_, + strides_, + axes, + infer_flags, + in_dims, + decrease_axis, + out_dims_vector.data(), + axes.size(), + true); + } + DDim out_dims(phi::make_ddim(out_dims_vector)); + // generate new shape + if (decrease_axis.size() > 0) { + std::vector new_out_shape; + for (size_t i = 0; i < decrease_axis.size(); ++i) { + if (config.is_runtime && infer_flags[i] != -1) { + PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], + 1, + errors::InvalidArgument( + "the size of decrease dimension should be 1, " + "but received %d.", + out_dims[decrease_axis[i]])); + } + out_dims[decrease_axis[i]] = 0; + } + + for (int i = 0; i < out_dims.size(); ++i) { + if (out_dims[i] != 0) { + new_out_shape.push_back(out_dims[i]); + } + } + if (new_out_shape.size() == 0) { + new_out_shape.push_back(1); + } + out_dims = phi::make_ddim(new_out_shape); + } + VLOG(1) << "out_dims: " << out_dims; + out->set_dims(out_dims); + out->share_lod(x); + out->set_dtype(x.dtype()); +} + /* Why not use SumRawInferMeta directly? Because we need make InferMetaFunction's args follow the design of api.yaml */ diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index fe11f7d44ab401e38d81a0637b5bc4ba1a6958bc..ac8d62db363a2a4e928b017369f1b0ae1fca14b5 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -267,6 +267,16 @@ void SqueezeInferMeta(const MetaTensor& x, MetaTensor* xshape, MetaTensor* out); +void StridedSliceInferMeta(const MetaTensor& x, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void SumInferMeta(const MetaTensor& x, const std::vector& axis, DataType dtype, diff --git a/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc b/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..cdc5534d63c085263450036cfcff073fb271909f --- /dev/null +++ b/paddle/phi/kernels/cpu/strided_slice_grad_kernel.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/strided_slice_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(strided_slice_grad, + CPU, + ALL_LAYOUT, + phi::StridedSliceGradKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(strided_slice_array_grad, + CPU, + ALL_LAYOUT, + phi::StridedSliceArrayGradKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/strided_slice_kernel.cc b/paddle/phi/kernels/cpu/strided_slice_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..f34a3301fcb42be52f707a84d51bd167ed4cde18 --- /dev/null +++ b/paddle/phi/kernels/cpu/strided_slice_kernel.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/strided_slice_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/strided_slice_kernel_impl.h" + +PD_REGISTER_KERNEL(strided_slice, + CPU, + ALL_LAYOUT, + phi::StridedSliceKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(strided_slice_array, + CPU, + ALL_LAYOUT, + phi::StridedSliceArrayKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/funcs/strided_slice.h b/paddle/phi/kernels/funcs/strided_slice.h new file mode 100644 index 0000000000000000000000000000000000000000..38a611ba26e229fb19e861adf77832a12ab4bc72 --- /dev/null +++ b/paddle/phi/kernels/funcs/strided_slice.h @@ -0,0 +1,659 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { +namespace funcs { +static void StridedSliceOutDims(const std::vector& starts, + const std::vector& ends, + const std::vector& strides, + const std::vector& axes, + const std::vector& infer_flags, + const DDim in_dims, + const std::vector& decrease_axis, + int64_t* out_dims_vector, + const size_t size, + bool infer_shape) { + for (int i = 0; i < in_dims.size(); i++) { + out_dims_vector[i] = in_dims[i]; + } + int64_t stride_index, start_index, end_index; + for (size_t i = 0; i < size; i++) { + int axes_index = axes[i]; + start_index = starts[i]; + end_index = ends[i]; + stride_index = strides[i]; + bool decrease_axis_affect = false; + if (start_index == -1 && end_index == 0 && infer_flags[i] == -1) { + auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]); + if (ret != decrease_axis.end()) { + decrease_axis_affect = true; + } + } + if (decrease_axis_affect) { + out_dims_vector[axes_index] = 1; + continue; + } + if (infer_shape && infer_flags[i] == -1) { + out_dims_vector[axes_index] = -1; + continue; + } + + PADDLE_ENFORCE_NE( + stride_index, + 0, + errors::InvalidArgument("stride index in StridedSlice operator is 0.")); + int64_t axis_size = in_dims[axes_index]; + + if (axis_size < 0) { + continue; + } + + if (start_index < 0) { + start_index = start_index + axis_size; + } + if (end_index < 0) { + if (!(end_index == -1 && stride_index < 0)) { // skip None stop condition + end_index = end_index + axis_size; + } + } + + if (stride_index < 0) { + start_index = start_index + 1; + end_index = end_index + 1; + } + + bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) || + (stride_index > 0 && (start_index > end_index))); + PADDLE_ENFORCE_EQ(neg_dim_condition, + false, + errors::InvalidArgument( + "The start index and end index are invalid for their " + "corresponding stride.")); + + int64_t left = + std::max(static_cast(0), std::min(start_index, end_index)); + int64_t right = std::min(axis_size, std::max(start_index, end_index)); + int64_t step = std::abs(stride_index); + + auto out_dims_index = (std::abs(right - left) + step - 1) / step; + + out_dims_vector[axes_index] = out_dims_index; + } +} + +static void StridedSliceFunctor(int64_t* starts, + int64_t* ends, + int64_t* strides, + const int* axes, + int* reverse_axis, + const DDim dims, + const std::vector& infer_flags, + const std::vector& decrease_axis, + const size_t size) { + for (size_t axis = 0; axis < size; axis++) { + int64_t axis_size = dims[axes[axis]]; + int axis_index = axis; + if (axis_size < 0) { + starts[axis_index] = 0; + ends[axis_index] = 1; + strides[axis_index] = 1; + } + bool decrease_axis_affect = false; + if (starts[axis_index] == -1 && ends[axis_index] == 0 && + infer_flags[axis_index] == -1) { + auto ret = std::find( + decrease_axis.begin(), decrease_axis.end(), axes[axis_index]); + if (ret != decrease_axis.end()) { + decrease_axis_affect = true; + } + } + // stride must not be zero + if (starts[axis_index] < 0) { + starts[axis_index] = starts[axis_index] + axis_size; + starts[axis_index] = std::max(starts[axis_index], 0); + } + if (ends[axis_index] < 0) { + if (!(ends[axis_index] == -1 && + strides[axis_index] < 0)) { // skip None stop condition + ends[axis_index] = ends[axis_index] + axis_size; + if (ends[axis_index] < 0) { + ends[axis_index] = 0; + } + } + } + if (decrease_axis_affect) { + if (strides[axis_index] < 0) { + ends[axis_index] = starts[axis_index] - 1; + } else { + ends[axis_index] = starts[axis_index] + 1; + } + } + + if (strides[axis_index] < 0) { + reverse_axis[axis_index] = 1; + strides[axis_index] = -strides[axis_index]; + if (starts[axis_index] > ends[axis_index]) { + // swap the reverse + auto end_dim = axis_size - 1 < starts[axis_index] ? axis_size - 1 + : starts[axis_index]; + auto offset = (end_dim - ends[axis_index]) % strides[axis_index]; + offset = offset == 0 ? strides[axis_index] : offset; + + starts[axis_index] = starts[axis_index] + offset; + ends[axis_index] = ends[axis_index] + offset; + } + std::swap(starts[axis_index], ends[axis_index]); + } else { + reverse_axis[axis_index] = 0; + strides[axis_index] = strides[axis_index]; + } + } +} + +template +void StridedSliceCompute(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* out) { + auto& place = *dev_ctx.eigen_device(); + DDim in_dims = x.dims(); + + auto starts_ = starts.GetData(); + auto ends_ = ends.GetData(); + auto strides_ = strides.GetData(); + + auto starts_indices = Eigen::DSizes(); + auto ends_indices = Eigen::DSizes(); + auto strides_indices = Eigen::DSizes(); + auto reverse_axis = Eigen::array(); + + std::vector out_dims_vector(in_dims.size(), -1); + StridedSliceOutDims(starts_, + ends_, + strides_, + axes, + infer_flags, + in_dims, + decrease_axis, + out_dims_vector.data(), + axes.size(), + false); + DDim out_dims(phi::make_ddim(out_dims_vector)); + + std::vector reverse_vector(starts_.size(), 0); + StridedSliceFunctor(starts_.data(), + ends_.data(), + strides_.data(), + axes.data(), + reverse_vector.data(), + in_dims, + infer_flags, + decrease_axis, + starts_.size()); + + for (size_t axis = 0; axis < D; axis++) { + starts_indices[axis] = 0; + ends_indices[axis] = out_dims[axis]; + strides_indices[axis] = 1; + reverse_axis[axis] = false; + } + for (size_t axis = 0; axis < axes.size(); axis++) { + int axis_index = axes[axis]; + starts_indices[axis_index] = starts_[axis]; + ends_indices[axis_index] = ends_[axis]; + strides_indices[axis_index] = strides_[axis]; + reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false; + } + + auto out_dims_origin = out_dims; + if (decrease_axis.size() > 0) { + std::vector new_out_shape; + for (size_t i = 0; i < decrease_axis.size(); ++i) { + PADDLE_ENFORCE_EQ( + out_dims[decrease_axis[i]], + 1, + errors::InvalidArgument( + "the size of decrease dimension should be 1, but received %d.", + out_dims[decrease_axis[i]])); + out_dims_origin[decrease_axis[i]] = 0; + } + + for (int i = 0; i < out_dims_origin.size(); ++i) { + if (out_dims_origin[i] != 0) { + new_out_shape.push_back(out_dims_origin[i]); + } + } + if (new_out_shape.size() == 0) { + new_out_shape.push_back(1); + } + out_dims_origin = phi::make_ddim(new_out_shape); + } + + bool need_reverse = false; + for (size_t axis = 0; axis < axes.size(); axis++) { + if (reverse_vector[axis] == 1) { + need_reverse = true; + break; + } + } + + out->Resize(out_dims); + dev_ctx.template Alloc(out); + auto in_t = EigenTensor::From(x); + auto out_t = EigenTensor::From( + *out, out_dims); + if (need_reverse) { + DenseTensor tmp; + tmp.Resize(out_dims); + dev_ctx.template Alloc(&tmp); + + auto tmp_t = + EigenTensor::From(tmp); + tmp_t.device(place) = + in_t.stridedSlice(starts_indices, ends_indices, strides_indices); + out_t.device(place) = tmp_t.reverse(reverse_axis); + } else { + out_t.device(place) = + in_t.stridedSlice(starts_indices, ends_indices, strides_indices); + } + + if (decrease_axis.size() > 0) { + out->Resize(out_dims_origin); + } +} + +template +void StridedSliceCompute(const Context& dev_ctx, + const std::vector& x, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + std::vector out) { + const int64_t size = x.size(); + auto in_dims = phi::make_ddim({size}); + + auto starts_ = starts.GetData(); + auto ends_ = ends.GetData(); + auto strides_ = strides.GetData(); + + auto starts_indices = Eigen::DSizes(); + auto ends_indices = Eigen::DSizes(); + auto strides_indices = Eigen::DSizes(); + auto reverse_axis = Eigen::array(); + + std::vector out_dims_vector(in_dims.size(), -1); + StridedSliceOutDims(starts_, + ends_, + strides_, + axes, + infer_flags, + in_dims, + decrease_axis, + out_dims_vector.data(), + axes.size(), + false); + DDim out_dims(phi::make_ddim(out_dims_vector)); + + std::vector reverse_vector(starts_.size(), 0); + StridedSliceFunctor(starts_.data(), + ends_.data(), + strides_.data(), + axes.data(), + reverse_vector.data(), + in_dims, + infer_flags, + decrease_axis, + starts_.size()); + + for (size_t axis = 0; axis < D; axis++) { + starts_indices[axis] = 0; + ends_indices[axis] = out_dims[axis]; + strides_indices[axis] = 1; + reverse_axis[axis] = false; + } + for (size_t axis = 0; axis < axes.size(); axis++) { + int axis_index = axes[axis]; + starts_indices[axis_index] = starts_[axis]; + ends_indices[axis_index] = ends_[axis]; + strides_indices[axis_index] = strides_[axis]; + reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false; + } + + auto out_dims_origin = out_dims; + if (decrease_axis.size() > 0) { + std::vector new_out_shape; + for (size_t i = 0; i < decrease_axis.size(); ++i) { + PADDLE_ENFORCE_EQ( + out_dims[decrease_axis[i]], + 1, + errors::InvalidArgument( + "the size of decrease dimension should be 1, but received %d.", + out_dims[decrease_axis[i]])); + out_dims_origin[decrease_axis[i]] = 0; + } + + for (int i = 0; i < out_dims_origin.size(); ++i) { + if (out_dims_origin[i] != 0) { + new_out_shape.push_back(out_dims_origin[i]); + } + } + if (new_out_shape.size() == 0) { + new_out_shape.push_back(1); + } + out_dims_origin = phi::make_ddim(new_out_shape); + } + + bool need_reverse = false; + for (size_t axis = 0; axis < axes.size(); axis++) { + if (reverse_vector[axis] == 1) { + need_reverse = true; + break; + } + } + + PADDLE_ENFORCE_EQ( + starts_indices.size(), + 1, + errors::InvalidArgument( + "When the input of 'strided_slice_op' is `TensorArray`, the " + "dimension of start index should be 1, but received %d.", + starts_indices.size())); + + PADDLE_ENFORCE_EQ( + ends_indices.size(), + 1, + errors::InvalidArgument( + "When the input of 'strided_slice_op' is `TensorArray`, the " + "dimension of end index should be 1, but received %d.", + ends_indices.size())); + + PADDLE_ENFORCE_EQ( + strides_indices.size(), + 1, + errors::InvalidArgument( + "When the input of 'strided_slice_op' is `TensorArray`, the " + "dimension of stride should be 1, but received %d.", + strides_indices.size())); + + PADDLE_ENFORCE_EQ( + out_dims_origin.size(), + 1, + errors::InvalidArgument( + "When the input of 'strided_slice_op' is `TensorArray`, the " + "dimension of Output should be 1, but received %d", + out_dims_origin.size())); + + out.resize(out_dims_origin[0]); + size_t const in_array_size = x.size(); + for (size_t i = 0; i < out.size(); i++) { + size_t in_offset = + (starts_indices[0] % in_array_size) + i * strides_indices[0]; + + int64_t out_offset = i; + if (need_reverse) { + out_offset = out.size() - i - 1; + } + + auto* in_tensor = x.at(in_offset); + PADDLE_ENFORCE_GT( + in_tensor->memory_size(), + 0, + errors::PreconditionNotMet( + "The input LoDTensorArray Input[%d] holds no memory.", in_offset)); + auto* out_tensor = out.at(out_offset); + out_tensor->Resize(in_tensor->dims()); + + phi::Copy( + dev_ctx, *in_tensor, dev_ctx.GetPlace(), false, out_tensor); + out_tensor->set_lod(in_tensor->lod()); + } +} + +template +void StridedSliceGradCompute(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* x_grad) { + auto& place = *dev_ctx.eigen_device(); + DDim out_dims = x.dims(); + + auto starts_ = starts.GetData(); + auto ends_ = ends.GetData(); + auto strides_ = strides.GetData(); + + auto starts_indices = Eigen::DSizes(); + auto ends_indices = Eigen::DSizes(); + auto strides_indices = Eigen::DSizes(); + + auto reverse_axis = Eigen::array(); + std::vector reverse_vector(starts_.size(), 0); + + StridedSliceFunctor(starts_.data(), + ends_.data(), + strides_.data(), + axes.data(), + reverse_vector.data(), + out_dims, + infer_flags, + decrease_axis, + starts_.size()); + + for (size_t axis = 0; axis < D; axis++) { + starts_indices[axis] = 0; + ends_indices[axis] = out_dims[axis]; + strides_indices[axis] = 1; + } + for (size_t axis = 0; axis < axes.size(); axis++) { + int axis_index = axes[axis]; + starts_indices[axis_index] = starts_[axis]; + ends_indices[axis_index] = ends_[axis]; + strides_indices[axis_index] = strides_[axis]; + reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false; + } + + bool need_reverse = false; + for (size_t axis = 0; axis < axes.size(); axis++) { + if (reverse_vector[axis] == 1) { + need_reverse = true; + break; + } + } + + dev_ctx.template Alloc(x_grad); + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, x_grad, static_cast(0)); + + auto out_grad_dims = out_grad.dims(); + + auto in_t = + EigenTensor::From(out_grad); + auto out_t = EigenTensor::From( + *x_grad, out_dims); + if (need_reverse) { + DenseTensor reverse_input; + reverse_input.Resize(out_grad_dims); + dev_ctx.template Alloc(&reverse_input); + + auto reverse_in_t = + EigenTensor::From( + reverse_input); + + reverse_in_t.device(place) = in_t.reverse(reverse_axis); + out_t.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(place) = reverse_in_t; + } else { + out_t.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(place) = in_t; + } +} + +template +void StridedSliceGradCompute(const Context& dev_ctx, + const std::vector& x, + const std::vector& out_grad, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + std::vector x_grad) { + // Note(weixin):Since the shape of `framework::GradVarName("Input")` of + // StridedSliceGrad cannot be calculated by + // `framework::GradVarName("Output")`, the dim of "Input" is used to + // calculate the output shape. when set it to inplace OP, there may be + // some problems. + const int64_t size = x.size(); + DDim out_dims = phi::make_ddim({size}); + + auto starts_ = starts.GetData(); + auto ends_ = ends.GetData(); + auto strides_ = strides.GetData(); + + auto starts_indices = Eigen::DSizes(); + auto ends_indices = Eigen::DSizes(); + auto strides_indices = Eigen::DSizes(); + + auto reverse_axis = Eigen::array(); + std::vector reverse_vector(starts_.size(), 0); + + StridedSliceFunctor(starts_.data(), + ends_.data(), + strides_.data(), + axes.data(), + reverse_vector.data(), + out_dims, + infer_flags, + decrease_axis, + starts_.size()); + + for (size_t axis = 0; axis < D; axis++) { + starts_indices[axis] = 0; + ends_indices[axis] = out_dims[axis]; + strides_indices[axis] = 1; + } + for (size_t axis = 0; axis < axes.size(); axis++) { + int axis_index = axes[axis]; + starts_indices[axis_index] = starts_[axis]; + ends_indices[axis_index] = ends_[axis]; + strides_indices[axis_index] = strides_[axis]; + reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false; + } + + bool need_reverse = false; + for (size_t axis = 0; axis < axes.size(); axis++) { + if (reverse_vector[axis] == 1) { + need_reverse = true; + break; + } + } + PADDLE_ENFORCE_EQ( + starts_indices.size(), + 1, + errors::InvalidArgument( + "When the input of 'strided_slice_grad_op' is `TensorArray`, the " + "dimension of start index should be 1, but received %d.", + starts_indices.size())); + PADDLE_ENFORCE_EQ( + ends_indices.size(), + 1, + errors::InvalidArgument( + "When the input of 'strided_slice_op' is `TensorArray`, the " + "dimension of end index should be 1, but received %d.", + ends_indices.size())); + PADDLE_ENFORCE_EQ( + strides_indices.size(), + 1, + errors::InvalidArgument( + "When the input of 'strided_slice_grad_op' is `TensorArray`, the " + "dimension of stride should be 1, but received %d.", + strides_indices.size())); + + PADDLE_ENFORCE_EQ( + out_dims.size(), + 1, + errors::InvalidArgument( + "When the output of `strided_slice_grad_op` is `TensorArray`, " + "the dimension of output should be 1, but received %d.", + out_dims.size())); + + auto const d_out_array_size = x_grad.size(); + + for (size_t j = 0; j < d_out_array_size; j++) { + auto& dim = x.at(j)->dims(); + auto* d_out_tensor = x_grad.at(j); + + int64_t sub = j - starts_indices[0]; + + int64_t in_offset = sub / strides_indices[0]; + + if (need_reverse) { + in_offset = out_grad.size() - in_offset - 1; + } + + if ((sub % strides_indices[0] == 0) && (0 <= in_offset) && + (static_cast(in_offset) < out_grad.size())) { + auto* in_tensor = out_grad.at(in_offset); + PADDLE_ENFORCE_GT( + in_tensor->memory_size(), + 0, + errors::PreconditionNotMet( + "The input LoDTensorArray Input[%d] holds no memory.", + in_offset)); + + phi::Copy( + dev_ctx, *in_tensor, dev_ctx.GetPlace(), false, d_out_tensor); + d_out_tensor->set_lod(in_tensor->lod()); + } else { + d_out_tensor->Resize(dim); + + if (!d_out_tensor->IsInitialized()) { + dev_ctx.template Alloc(d_out_tensor); + } + + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, d_out_tensor, static_cast(0)); + } + } +} + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu b/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..5f31d488533a6e082bea9809b7623243ceea5056 --- /dev/null +++ b/paddle/phi/kernels/gpu/strided_slice_grad_kernel.cu @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/strided_slice_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(strided_slice_grad, + GPU, + ALL_LAYOUT, + phi::StridedSliceGradKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(strided_slice_array_grad, + GPU, + ALL_LAYOUT, + phi::StridedSliceArrayGradKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/strided_slice_kernel.cu b/paddle/phi/kernels/gpu/strided_slice_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..ff10718edb323e482627666e58fadaf50a99e22b --- /dev/null +++ b/paddle/phi/kernels/gpu/strided_slice_kernel.cu @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/strided_slice_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/strided_slice_kernel_impl.h" + +PD_REGISTER_KERNEL(strided_slice, + GPU, + ALL_LAYOUT, + phi::StridedSliceKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(strided_slice_array, + GPU, + ALL_LAYOUT, + phi::StridedSliceArrayKernel, + bool, + int, + int64_t, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h b/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h index 4947170088cba9701ad1065098451b97139bfc95..0e39c0a726bf4c68fc298ea4af098e59bf3da1c1 100644 --- a/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/set_value_grad_kernel_impl.h @@ -22,8 +22,7 @@ #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/math_function.h" - -#include "paddle/fluid/operators/strided_slice_op.h" +#include "paddle/phi/kernels/funcs/strided_slice.h" namespace phi { @@ -73,29 +72,29 @@ void SetValueGradImpl(const Context& dev_ctx, std::vector starts_local = starts.GetData(); std::vector ends_local = ends.GetData(); std::vector steps_local = steps.GetData(); - paddle::operators::StridedSliceOutDims(starts_local, - ends_local, - steps_local, - axes_int32, - infer_flags, - in_dims, - decrease_axis_int32, - out_dims_vector.data(), - axes.size(), - false); + funcs::StridedSliceOutDims(starts_local, + ends_local, + steps_local, + axes_int32, + infer_flags, + in_dims, + decrease_axis_int32, + out_dims_vector.data(), + axes.size(), + false); DDim out_dims(phi::make_ddim(out_dims_vector)); std::vector reverse_vector(starts_local.size(), 0); - paddle::operators::StridedSliceFunctor(starts_local.data(), - ends_local.data(), - steps_local.data(), - axes_int32.data(), - reverse_vector.data(), - in_dims, - infer_flags, - decrease_axis_int32, - starts_local.size()); + funcs::StridedSliceFunctor(starts_local.data(), + ends_local.data(), + steps_local.data(), + axes_int32.data(), + reverse_vector.data(), + in_dims, + infer_flags, + decrease_axis_int32, + starts_local.size()); auto starts_indices = Eigen::DSizes(); auto ends_indices = Eigen::DSizes(); diff --git a/paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h b/paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..1d75b32a5f21db59923a8352d39ef09f9649d6b1 --- /dev/null +++ b/paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h @@ -0,0 +1,83 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/kernels/strided_slice_grad_kernel.h" + +#include "paddle/phi/kernels/funcs/strided_slice.h" + +namespace phi { + +template +void StridedSliceGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* x_grad) { + int rank = x.dims().size(); +#define SLICE_CASE(Rank) \ + case Rank: \ + funcs::StridedSliceGradCompute(dev_ctx, \ + x, \ + out_grad, \ + axes, \ + starts, \ + ends, \ + strides, \ + infer_flags, \ + decrease_axis, \ + x_grad); \ + break; + + switch (rank) { + SLICE_CASE(1) + SLICE_CASE(2) + SLICE_CASE(3) + SLICE_CASE(4) + SLICE_CASE(5) + SLICE_CASE(6) + } +#undef SLICE_CASE +} + +template +void StridedSliceArrayGradKernel( + const Context& dev_ctx, + const std::vector& x, + const std::vector& out_grad, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + std::vector x_grad) { + funcs::StridedSliceGradCompute(dev_ctx, + x, + out_grad, + axes, + starts, + ends, + strides, + infer_flags, + decrease_axis, + x_grad); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/strided_slice_kernel_impl.h b/paddle/phi/kernels/impl/strided_slice_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..f98ac1aedcf17ef63ec6e20d46f835d778318218 --- /dev/null +++ b/paddle/phi/kernels/impl/strided_slice_kernel_impl.h @@ -0,0 +1,71 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/kernels/strided_slice_kernel.h" + +#include "paddle/phi/kernels/funcs/strided_slice.h" + +namespace phi { + +template +void StridedSliceKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* out) { + int rank = x.dims().size(); +#define SLICE_CASE(Rank) \ + case Rank: \ + funcs::StridedSliceCompute(dev_ctx, \ + x, \ + axes, \ + starts, \ + ends, \ + strides, \ + infer_flags, \ + decrease_axis, \ + out); \ + break; + + switch (rank) { + SLICE_CASE(1) + SLICE_CASE(2) + SLICE_CASE(3) + SLICE_CASE(4) + SLICE_CASE(5) + SLICE_CASE(6) + } +#undef SLICE_CASE +} + +template +void StridedSliceArrayKernel(const Context& dev_ctx, + const std::vector& x, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + std::vector out) { + funcs::StridedSliceCompute( + dev_ctx, x, axes, starts, ends, strides, infer_flags, decrease_axis, out); +} + +} // namespace phi diff --git a/paddle/phi/kernels/strided_slice_grad_kernel.h b/paddle/phi/kernels/strided_slice_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..f753402e498338a5262d9e1bf3065fe0477fbb49 --- /dev/null +++ b/paddle/phi/kernels/strided_slice_grad_kernel.h @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar_array.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void StridedSliceGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* x_grad); + +template +void StridedSliceArrayGradKernel( + const Context& dev_ctx, + const std::vector& x, + const std::vector& out_grad, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + std::vector x_grad); +} // namespace phi diff --git a/paddle/phi/kernels/strided_slice_kernel.h b/paddle/phi/kernels/strided_slice_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..f23d1c04d5da37376a3fa949fb9d3f2c5a29de5d --- /dev/null +++ b/paddle/phi/kernels/strided_slice_kernel.h @@ -0,0 +1,43 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar_array.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void StridedSliceKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* out); + +template +void StridedSliceArrayKernel(const Context& dev_ctx, + const std::vector& x, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const ScalarArray& strides, + const std::vector& infer_flags, + const std::vector& decrease_axis, + std::vector out); +} // namespace phi diff --git a/paddle/phi/ops/compat/strided_slice_sig.cc b/paddle/phi/ops/compat/strided_slice_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..70ce2e3e07ce900289256dea2caec047c017fc7c --- /dev/null +++ b/paddle/phi/ops/compat/strided_slice_sig.cc @@ -0,0 +1,704 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/phi/core/compat/op_utils.h" +#include "paddle/utils/small_vector.h" + +namespace phi { + +KernelSignature StridedSliceOpArgumentMapping( + const ArgumentMappingContext& ctx) { + const auto& starts = paddle::any_cast>(ctx.Attr("starts")); + const auto& ends = paddle::any_cast>(ctx.Attr("ends")); + const auto& strides = paddle::any_cast>(ctx.Attr("strides")); + + bool use_attr_starts = !ctx.IsRuntime() && !starts.empty(); + bool use_attr_ends = !ctx.IsRuntime() && !ends.empty(); + bool use_attr_strides = !ctx.IsRuntime() && !strides.empty(); + + std::string starts_key = + ctx.HasInput("StartsTensor") + ? "StartsTensor" + : (ctx.InputSize("StartsTensorList") > 0 + ? (use_attr_starts ? "starts" : "StartsTensorList") + : "starts"); + std::string ends_key = + ctx.HasInput("EndsTensor") + ? "EndsTensor" + : (ctx.InputSize("EndsTensorList") > 0 + ? (use_attr_ends ? "ends" : "EndsTensorList") + : "ends"); + std::string strides_key = + ctx.HasInput("StridesTensor") + ? "StridesTensor" + : (ctx.InputSize("StridesTensorList") > 0 + ? (use_attr_strides ? "strides" : "StridesTensorList") + : "strides"); + + paddle::SmallVector inputs = {"Input"}; + paddle::SmallVector attrs = {"axes", + starts_key, + ends_key, + strides_key, + "infer_flags", + "decrease_axis"}; + paddle::SmallVector outputs = {"Out"}; + + std::string op_type; + if (ctx.IsDenseTensorVectorInput("Input")) { + op_type = "strided_slice_array"; + } else { + op_type = "strided_slice"; + } + // NOTE(dev): Use this to avoid regularization. + KernelSignature sig(op_type, inputs, attrs, outputs); + return sig; +} + +KernelSignature StridedSliceGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + const auto& starts = paddle::any_cast>(ctx.Attr("starts")); + const auto& ends = paddle::any_cast>(ctx.Attr("ends")); + const auto& strides = paddle::any_cast>(ctx.Attr("strides")); + + bool use_attr_starts = !ctx.IsRuntime() && !starts.empty(); + bool use_attr_ends = !ctx.IsRuntime() && !ends.empty(); + bool use_attr_strides = !ctx.IsRuntime() && !strides.empty(); + + std::string starts_key = + ctx.HasInput("StartsTensor") + ? "StartsTensor" + : (ctx.InputSize("StartsTensorList") > 0 + ? (use_attr_starts ? "starts" : "StartsTensorList") + : "starts"); + std::string ends_key = + ctx.HasInput("EndsTensor") + ? "EndsTensor" + : (ctx.InputSize("EndsTensorList") > 0 + ? (use_attr_ends ? "ends" : "EndsTensorList") + : "ends"); + std::string strides_key = + ctx.HasInput("StridesTensor") + ? "StridesTensor" + : (ctx.InputSize("StridesTensorList") > 0 + ? (use_attr_strides ? "strides" : "StridesTensorList") + : "strides"); + + paddle::SmallVector inputs = {"Input", GradVarName("Out")}; + paddle::SmallVector attrs = {"axes", + starts_key, + ends_key, + strides_key, + "infer_flags", + "decrease_axis"}; + paddle::SmallVector outputs = {GradVarName("Input")}; + + std::string op_type; + if (ctx.IsDenseTensorVectorInput("Input")) { + op_type = "strided_slice_array_grad"; + } else { + op_type = "strided_slice_grad"; + } + + // NOTE(dev): Use this to avoid regularization. + KernelSignature sig(op_type, inputs, attrs, outputs); + return sig; +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(strided_slice, phi::StridedSliceOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(strided_slice_grad, + phi::StridedSliceGradOpArgumentMapping); + +/* +****************************************************************** +NOTE: The following codes are for 'get_compat_kernel_signature.py' + DO NOT EDIT IT if you don't know the mechanism. +****************************************************************** + +############################ Forward ############################ + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensor", "EndsTensor", +"StartsTensor","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensor", "EndsTensor", +"StartsTensorList","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensor", "EndsTensor", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensor", "EndsTensorList", +"StartsTensor","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensor", "EndsTensorList", +"StartsTensorList","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensor", "ends", "StartsTensor","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensor", "ends", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensorList", "EndsTensor", +"StartsTensor","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensorList", "EndsTensor", +"StartsTensorList","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensorList", "EndsTensorList", +"StartsTensor","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensorList", "EndsTensorList", +"StartsTensorList","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensorList", "EndsTensorList", +"starts","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensorList", "ends", +"StartsTensorList","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "StartsTensorList", "ends", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "starts", "EndsTensor", "StartsTensor","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "starts", "EndsTensor", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "starts", "EndsTensorList", +"StartsTensorList","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "starts", "EndsTensorList", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "starts", "ends", "StartsTensor","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "starts", "ends", "StartsTensorList","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice}", {"Input"}, + {"axes", "starts", "ends", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensor", "EndsTensor", +"StartsTensor","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensor", "EndsTensor", +"StartsTensorList","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensor", "EndsTensor", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensor", "EndsTensorList", +"StartsTensor","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensor", "EndsTensorList", +"StartsTensorList","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensor", "ends", "StartsTensor","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensor", "ends", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensorList", "EndsTensor", +"StartsTensor","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensorList", "EndsTensor", +"StartsTensorList","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensorList", "EndsTensorList", +"StartsTensor","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensorList", "EndsTensorList", +"StartsTensorList","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensorList", "EndsTensorList", +"starts","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensorList", "ends", +"StartsTensorList","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "StartsTensorList", "ends", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "starts", "EndsTensor", "StartsTensor","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "starts", "EndsTensor", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "starts", "EndsTensorList", +"StartsTensorList","infer_flags", "decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "starts", "EndsTensorList", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "starts", "ends", "StartsTensor","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "starts", "ends", "StartsTensorList","infer_flags", +"decrease_axis"}, + {"Out"}); + +return KernelSignature("{strided_slice_array}", {"Input"}, + {"axes", "starts", "ends", "starts","infer_flags", +"decrease_axis"}, + {"Out"}); + +############################ Backward ############################ + + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensor", "EndsTensor", +"StartsTensor","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensor", "EndsTensor", +"StartsTensorList","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensor", "EndsTensor", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensor", "EndsTensorList", +"StartsTensor","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensor", "EndsTensorList", +"StartsTensorList","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensor", "ends", "StartsTensor","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensor", "ends", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensorList", "EndsTensor", +"StartsTensor","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensorList", "EndsTensor", +"StartsTensorList","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensorList", "EndsTensorList", +"StartsTensor","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensorList", "EndsTensorList", +"StartsTensorList","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensorList", "EndsTensorList", +"starts","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensorList", "ends", +"StartsTensorList","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "StartsTensorList", "ends", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "starts", "EndsTensor", "StartsTensor","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "starts", "EndsTensor", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "starts", "EndsTensorList", +"StartsTensorList","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "starts", "EndsTensorList", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "starts", "ends", "StartsTensor","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "starts", "ends", "StartsTensorList","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")}, + {"axes", "starts", "ends", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensor", "EndsTensor", +"StartsTensor","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensor", "EndsTensor", +"StartsTensorList","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensor", "EndsTensor", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensor", "EndsTensorList", +"StartsTensor","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensor", "EndsTensorList", +"StartsTensorList","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensor", "ends", "StartsTensor","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensor", "ends", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensorList", "EndsTensor", +"StartsTensor","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensorList", "EndsTensor", +"StartsTensorList","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensorList", "EndsTensorList", +"StartsTensor","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensorList", "EndsTensorList", +"StartsTensorList","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensorList", "EndsTensorList", +"starts","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensorList", "ends", +"StartsTensorList","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "StartsTensorList", "ends", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "starts", "EndsTensor", "StartsTensor","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "starts", "EndsTensor", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "starts", "EndsTensorList", +"StartsTensorList","infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "starts", "EndsTensorList", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "starts", "ends", "StartsTensor","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "starts", "ends", "StartsTensorList","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); + +return KernelSignature("{strided_slice_array_grad}", {"Input", +GradVarName("Out")}, + {"axes", "starts", "ends", "starts","infer_flags", +"decrease_axis"}, + {GradVarName("Input")}); +*/