未验证 提交 c33b4f95 编写于 作者: A Aurelius84 提交者: GitHub

[Phi] Migrate strided_slice into Phi (#40708)

* [Phi] Migrate strided_slice into Phi

* [Phi] Migrate strided_slice into Phi

* fix compilation problem
上级 fd0c0e3c
......@@ -12,12 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/slice_op.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
namespace paddle {
namespace operators {
......@@ -28,149 +33,6 @@ class StridedSliceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "StridedSlice");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "StridedSlice");
auto input_var_type = ctx->GetInputsVarType("Input")[0];
if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
if (ctx->IsRuntime()) {
// shape is determined by Runtime.
return;
}
}
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_LT(
in_dims.size(), 7,
platform::errors::InvalidArgument(
"The dimension of StridedSlice operator's input should be less "
"than 7, but received dimension is %d.",
in_dims.size()));
auto starts_int = ctx->Attrs().Get<std::vector<int>>("starts");
auto ends_int = ctx->Attrs().Get<std::vector<int>>("ends");
auto strides_int = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
std::vector<int64_t> strides(strides_int.begin(), strides_int.end());
auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis");
auto starts_size = starts.size();
auto ends_size = ends.size();
auto strides_size = strides.size();
for (size_t i = 0; i < axes.size(); ++i) {
PADDLE_ENFORCE_GE(axes[i], 0,
platform::errors::InvalidArgument(
"The axis should be greater than or equal to 0."
"But received %d of axes[%d]",
axes[i], i));
PADDLE_ENFORCE_LT(
axes[i], in_dims.size(),
platform::errors::InvalidArgument(
"The axes should be less than or equal to input tensor's rank."
"But received %d of axes[%d], input tensor shape [%d]",
axes[i], i, in_dims.size()));
}
if (ctx->HasInputs("StartsTensorList")) {
auto StartsTensorList = ctx->Inputs("StartsTensorList");
PADDLE_ENFORCE_GT(
StartsTensorList.size(), 0,
platform::errors::InvalidArgument(
"StridedSlice operator's StartsTensorList is empty."));
starts_size = StartsTensorList.size();
}
if (ctx->HasInputs("EndsTensorList")) {
auto EndsTensorList = ctx->Inputs("EndsTensorList");
PADDLE_ENFORCE_GT(
EndsTensorList.size(), 0,
platform::errors::InvalidArgument(
"StridedSlice operator's EndsTensorList is empty."));
ends_size = EndsTensorList.size();
}
if (ctx->HasInputs("StridesTensorList")) {
auto StridesTensorList = ctx->Inputs("StridesTensorList");
PADDLE_ENFORCE_GT(
StridesTensorList.size(), 0,
platform::errors::InvalidArgument(
"StridedSlice operator's StridesTensorList is empty."));
strides_size = StridesTensorList.size();
}
auto tensor_input = false;
if (ctx->HasInput("EndsTensor") || ctx->HasInput("StartsTensor") ||
ctx->HasInput("StridesTensor")) {
tensor_input = true;
}
if (!ctx->HasInput("EndsTensor")) {
PADDLE_ENFORCE_EQ(
ends_size, axes.size(),
platform::errors::InvalidArgument(
"The size of ends attribute in StridedSlice operator is not "
"equal to the size of axes attribute. The ends attribute's size "
"is %d, axes attribute's size is %d.",
ends_size, axes.size()));
}
if (!ctx->HasInput("StartsTensor")) {
PADDLE_ENFORCE_EQ(
starts_size, axes.size(),
platform::errors::InvalidArgument(
"The size of starts attribute in StridedSlice operator is not "
"equal to the size of axes attribute. The starts attribute's "
"size is %d, axes attribute's size is %d.",
starts_size, axes.size()));
}
if (!ctx->HasInput("StridesTensor")) {
PADDLE_ENFORCE_EQ(
strides_size, axes.size(),
platform::errors::InvalidArgument(
"The size of strides attribute in StridedSlice operator is not "
"equal to the size of axes attribute. The strides attribute's "
"size is %d, axes attribute's size is %d.",
strides_size, axes.size()));
}
// we need to analysis strided slice op is valid for
// the parameter that we get from python front
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
if (!tensor_input) {
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
decrease_axis, out_dims_vector.data(), axes.size(),
true);
}
framework::DDim out_dims(phi::make_ddim(out_dims_vector));
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
if (ctx->IsRuntime() && infer_flags[i] != -1) {
PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
platform::errors::InvalidArgument(
"the size of decrease dimension should be 1, "
"but received %d.",
out_dims[decrease_axis[i]]));
}
out_dims[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims.size(); ++i) {
if (out_dims[i] != 0) {
new_out_shape.push_back(out_dims[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims = phi::make_ddim(new_out_shape);
}
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("Input", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
......@@ -304,26 +166,6 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
"StridedSliceGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "StridedSliceGrad");
auto input_var_type = ctx->GetInputsVarType("Input")[0];
if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
if (ctx->IsRuntime()) {
// shape is determined by Runtime
return;
}
}
auto x_dims = ctx->GetInputDim("Input");
auto x_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
......@@ -384,35 +226,19 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(strided_slice, StridedSliceInferShape,
PD_INFER_META(phi::StridedSliceInferMeta));
REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker,
ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>,
ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>,
ops::StridedSliceOpVarTypeInference);
ops::StridedSliceOpVarTypeInference, StridedSliceInferShape);
DECLARE_INFER_SHAPE_FUNCTOR(strided_slice_grad, StridedSliceGradInferShape,
PD_INFER_META(phi::GeneralUnaryGradInferMeta));
REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad,
ops::StridedSliceOpGradNoNeedBufferVarsInferer,
ops::StridedSliceGradOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(
strided_slice,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, bool>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, float>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, double>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
strided_slice_grad,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
ops::StridedSliceGradOpVarTypeInference,
StridedSliceGradInferShape);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h"
#include "paddle/fluid/platform/complex.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
strided_slice,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, bool>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
strided_slice_grad,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cstdlib>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/slice_op.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
static void StridedSliceOutDims(
const std::vector<int64_t>& starts, const std::vector<int64_t>& ends,
const std::vector<int64_t>& strides, const std::vector<int>& axes,
const std::vector<int>& infer_flags, const framework::DDim in_dims,
const std::vector<int>& decrease_axis, int64_t* out_dims_vector,
const size_t size, bool infer_shape) {
for (int i = 0; i < in_dims.size(); i++) {
out_dims_vector[i] = in_dims[i];
}
int64_t stride_index, start_index, end_index;
for (size_t i = 0; i < size; i++) {
int axes_index = axes[i];
start_index = starts[i];
end_index = ends[i];
stride_index = strides[i];
bool decrease_axis_affect = false;
if (start_index == -1 && end_index == 0 && infer_flags[i] == -1) {
auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
decrease_axis_affect = true;
}
}
if (decrease_axis_affect) {
out_dims_vector[axes_index] = 1;
continue;
}
if (infer_shape && infer_flags[i] == -1) {
out_dims_vector[axes_index] = -1;
continue;
}
PADDLE_ENFORCE_NE(stride_index, 0,
platform::errors::InvalidArgument(
"stride index in StridedSlice operator is 0."));
int64_t axis_size = in_dims[axes_index];
if (axis_size < 0) {
continue;
}
if (start_index < 0) {
start_index = start_index + axis_size;
}
if (end_index < 0) {
if (!(end_index == -1 && stride_index < 0)) { // skip None stop condition
end_index = end_index + axis_size;
}
}
if (stride_index < 0) {
start_index = start_index + 1;
end_index = end_index + 1;
}
bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) ||
(stride_index > 0 && (start_index > end_index)));
PADDLE_ENFORCE_EQ(neg_dim_condition, false,
platform::errors::InvalidArgument(
"The start index and end index are invalid for their "
"corresponding stride."));
int64_t left =
std::max(static_cast<int64_t>(0), std::min(start_index, end_index));
int64_t right = std::min(axis_size, std::max(start_index, end_index));
int64_t step = std::abs(stride_index);
auto out_dims_index = (std::abs(right - left) + step - 1) / step;
out_dims_vector[axes_index] = out_dims_index;
}
}
static void StridedSliceFunctor(int64_t* starts, int64_t* ends,
int64_t* strides, int* axes, int* reverse_axis,
const framework::DDim dims,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
const size_t size) {
for (size_t axis = 0; axis < size; axis++) {
int64_t axis_size = dims[axes[axis]];
int axis_index = axis;
if (axis_size < 0) {
starts[axis_index] = 0;
ends[axis_index] = 1;
strides[axis_index] = 1;
}
bool decrease_axis_affect = false;
if (starts[axis_index] == -1 && ends[axis_index] == 0 &&
infer_flags[axis_index] == -1) {
auto ret = std::find(decrease_axis.begin(), decrease_axis.end(),
axes[axis_index]);
if (ret != decrease_axis.end()) {
decrease_axis_affect = true;
}
}
// stride must not be zero
if (starts[axis_index] < 0) {
starts[axis_index] = starts[axis_index] + axis_size;
starts[axis_index] = std::max<int64_t>(starts[axis_index], 0);
}
if (ends[axis_index] < 0) {
if (!(ends[axis_index] == -1 &&
strides[axis_index] < 0)) { // skip None stop condition
ends[axis_index] = ends[axis_index] + axis_size;
if (ends[axis_index] < 0) {
ends[axis_index] = 0;
}
}
}
if (decrease_axis_affect) {
if (strides[axis_index] < 0) {
ends[axis_index] = starts[axis_index] - 1;
} else {
ends[axis_index] = starts[axis_index] + 1;
}
}
if (strides[axis_index] < 0) {
reverse_axis[axis_index] = 1;
strides[axis_index] = -strides[axis_index];
if (starts[axis_index] > ends[axis_index]) {
// swap the reverse
auto end_dim = axis_size - 1 < starts[axis_index] ? axis_size - 1
: starts[axis_index];
auto offset = (end_dim - ends[axis_index]) % strides[axis_index];
offset = offset == 0 ? strides[axis_index] : offset;
starts[axis_index] = starts[axis_index] + offset;
ends[axis_index] = ends[axis_index] + offset;
}
std::swap(starts[axis_index], ends[axis_index]);
} else {
reverse_axis[axis_index] = 0;
strides[axis_index] = strides[axis_index];
}
}
}
template <typename DeviceContext, typename T>
class StridedSliceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Variable* input_var = ctx.InputVar("Input");
bool is_tensor_array = input_var->IsType<LoDTensorArray>();
int rank = is_tensor_array
? 1
: ctx.Input<framework::Tensor>("Input")->dims().size();
switch (rank) {
case 1:
StridedSliceCompute<1>(ctx);
break;
case 2:
StridedSliceCompute<2>(ctx);
break;
case 3:
StridedSliceCompute<3>(ctx);
break;
case 4:
StridedSliceCompute<4>(ctx);
break;
case 5:
StridedSliceCompute<5>(ctx);
break;
case 6:
StridedSliceCompute<6>(ctx);
break;
}
}
private:
template <size_t D>
void StridedSliceCompute(const framework::ExecutionContext& context) const {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
framework::DDim in_dims;
auto* input_var = context.InputVar("Input");
bool is_input_var_array = input_var->IsType<LoDTensorArray>();
if (is_input_var_array) {
const int64_t size = input_var->Get<framework::LoDTensorArray>().size();
in_dims = phi::make_ddim({size});
} else {
in_dims = context.Input<framework::Tensor>("Input")->dims();
}
auto starts_int = context.Attr<std::vector<int>>("starts");
auto ends_int = context.Attr<std::vector<int>>("ends");
auto strides_int = context.Attr<std::vector<int>>("strides");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
std::vector<int64_t> strides(strides_int.begin(), strides_int.end());
auto axes = context.Attr<std::vector<int>>("axes");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto reverse_axis = Eigen::array<bool, D>();
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
auto list_new_strides_tensor =
context.MultiInput<framework::Tensor>("StridesTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
} else if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = GetDataFromTensor<int64_t>(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
} else if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = GetDataFromTensor<int64_t>(ends_tensor);
}
if (list_new_strides_tensor.size() > 0) {
strides = GetDataFromTensorList<int64_t>(list_new_strides_tensor);
} else if (context.HasInput("StridesTensor")) {
auto* strides_tensor = context.Input<framework::Tensor>("StridesTensor");
strides = GetDataFromTensor<int64_t>(strides_tensor);
}
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
decrease_axis, out_dims_vector.data(), axes.size(),
false);
framework::DDim out_dims(phi::make_ddim(out_dims_vector));
std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), in_dims, infer_flags,
decrease_axis, starts.size());
for (size_t axis = 0; axis < D; axis++) {
starts_indices[axis] = 0;
ends_indices[axis] = out_dims[axis];
strides_indices[axis] = 1;
reverse_axis[axis] = false;
}
for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis];
starts_indices[axis_index] = starts[axis];
ends_indices[axis_index] = ends[axis];
strides_indices[axis_index] = strides[axis];
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
}
auto out_dims_origin = out_dims;
if (decrease_axis.size() > 0) {
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
PADDLE_ENFORCE_EQ(
out_dims[decrease_axis[i]], 1,
platform::errors::InvalidArgument(
"the size of decrease dimension should be 1, but received %d.",
out_dims[decrease_axis[i]]));
out_dims_origin[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims_origin.size(); ++i) {
if (out_dims_origin[i] != 0) {
new_out_shape.push_back(out_dims_origin[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims_origin = phi::make_ddim(new_out_shape);
}
bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
if (is_input_var_array) {
PADDLE_ENFORCE_EQ(
starts_indices.size(), 1,
platform::errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of start index should be 1, but received %d.",
starts_indices.size()));
PADDLE_ENFORCE_EQ(
ends_indices.size(), 1,
platform::errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of end index should be 1, but received %d.",
ends_indices.size()));
PADDLE_ENFORCE_EQ(
strides_indices.size(), 1,
platform::errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of stride should be 1, but received %d.",
strides_indices.size()));
auto* output_var = context.OutputVar("Out");
PADDLE_ENFORCE_EQ(
output_var->IsType<LoDTensorArray>(), true,
platform::errors::InvalidArgument(
"When the input of `strided_slice_op` is `TensorArray`. The "
"output is excepted `TensorArray` , but received %s.",
framework::ToTypeName(output_var->Type())));
PADDLE_ENFORCE_EQ(
out_dims_origin.size(), 1,
platform::errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of Output should be 1, but received %d",
out_dims_origin.size()));
auto& in_array = input_var->Get<framework::LoDTensorArray>();
auto* out_array = context.Output<framework::LoDTensorArray>("Out");
out_array->resize(out_dims_origin[0]);
size_t const in_array_size = in_array.size();
for (size_t i = 0; i < out_array->size(); i++) {
size_t in_offset =
(starts_indices[0] % in_array_size) + i * strides_indices[0];
int64_t out_offset = i;
if (need_reverse) {
out_offset = out_array->size() - i - 1;
}
auto& in_tensor = in_array.at(in_offset);
PADDLE_ENFORCE_GT(
in_tensor.memory_size(), 0,
platform::errors::PreconditionNotMet(
"The input LoDTensorArray Input[%d] holds no memory.",
in_offset));
auto* out_tensor = &out_array->at(out_offset);
out_tensor->set_lod(in_tensor.lod());
paddle::framework::TensorCopy(in_tensor, context.GetPlace(),
out_tensor);
}
} else {
auto in = context.Input<framework::Tensor>("Input");
auto out = context.Output<framework::Tensor>("Out");
out->Resize(out_dims);
out->mutable_data<T>(context.GetPlace());
auto in_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(*in);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(*out, out_dims);
if (need_reverse) {
framework::Tensor tmp;
tmp.mutable_data<T>(out_dims, context.GetPlace());
auto tmp_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(tmp);
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
out_t.device(place) = tmp_t.reverse(reverse_axis);
} else {
out_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
}
if (decrease_axis.size() > 0) {
out->Resize(out_dims_origin);
}
}
}
};
template <typename DeviceContext, typename T>
class StridedSliceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Variable* input_var = ctx.InputVar("Input");
bool is_tensor_array = input_var->IsType<LoDTensorArray>();
int rank = is_tensor_array
? 1
: ctx.Input<framework::Tensor>("Input")->dims().size();
switch (rank) {
case 1:
StridedSliceGradCompute<1>(ctx);
break;
case 2:
StridedSliceGradCompute<2>(ctx);
break;
case 3:
StridedSliceGradCompute<3>(ctx);
break;
case 4:
StridedSliceGradCompute<4>(ctx);
break;
case 5:
StridedSliceGradCompute<5>(ctx);
break;
case 6:
StridedSliceGradCompute<6>(ctx);
break;
}
}
private:
template <size_t D>
void StridedSliceGradCompute(
const framework::ExecutionContext& context) const {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto& dev_ctx = context.template device_context<DeviceContext>();
framework::DDim out_dims;
auto* out_var = context.OutputVar(framework::GradVarName("Input"));
bool is_out_var_array = out_var->IsType<LoDTensorArray>();
if (is_out_var_array) {
// Note(weixin):Since the shape of `framework::GradVarName("Input")` of
// StridedSliceGrad cannot be calculated by
// `framework::GradVarName("Output")`, the dim of "Input" is used to
// calculate the output shape. when set it to inplace OP, there may be
// some problems.
const int64_t size =
context.Input<framework::LoDTensorArray>("Input")->size();
out_dims = phi::make_ddim({size});
} else {
out_dims =
context.Output<framework::Tensor>(framework::GradVarName("Input"))
->dims();
}
auto starts_int = context.Attr<std::vector<int>>("starts");
auto ends_int = context.Attr<std::vector<int>>("ends");
auto strides_int = context.Attr<std::vector<int>>("strides");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
std::vector<int64_t> strides(strides_int.begin(), strides_int.end());
auto axes = context.Attr<std::vector<int>>("axes");
auto infer_flags = context.Attr<std::vector<int>>("infer_flags");
auto decrease_axis = context.Attr<std::vector<int>>("decrease_axis");
auto list_new_ends_tensor =
context.MultiInput<framework::Tensor>("EndsTensorList");
auto list_new_starts_tensor =
context.MultiInput<framework::Tensor>("StartsTensorList");
auto list_new_strides_tensor =
context.MultiInput<framework::Tensor>("StridesTensorList");
if (list_new_starts_tensor.size() > 0) {
starts = GetDataFromTensorList<int64_t>(list_new_starts_tensor);
} else if (context.HasInput("StartsTensor")) {
auto* starts_tensor = context.Input<framework::Tensor>("StartsTensor");
starts = GetDataFromTensor<int64_t>(starts_tensor);
}
if (list_new_ends_tensor.size() > 0) {
ends = GetDataFromTensorList<int64_t>(list_new_ends_tensor);
} else if (context.HasInput("EndsTensor")) {
auto* ends_tensor = context.Input<framework::Tensor>("EndsTensor");
ends = GetDataFromTensor<int64_t>(ends_tensor);
}
if (list_new_strides_tensor.size() > 0) {
strides = GetDataFromTensorList<int64_t>(list_new_strides_tensor);
} else if (context.HasInput("StridesTensor")) {
auto* strides_tensor = context.Input<framework::Tensor>("StridesTensor");
strides = GetDataFromTensor<int64_t>(strides_tensor);
}
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto reverse_axis = Eigen::array<bool, D>();
std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), out_dims, infer_flags,
decrease_axis, starts.size());
for (size_t axis = 0; axis < D; axis++) {
starts_indices[axis] = 0;
ends_indices[axis] = out_dims[axis];
strides_indices[axis] = 1;
}
for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis];
starts_indices[axis_index] = starts[axis];
ends_indices[axis_index] = ends[axis];
strides_indices[axis_index] = strides[axis];
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
}
bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
if (is_out_var_array) {
PADDLE_ENFORCE_EQ(
starts_indices.size(), 1,
platform::errors::InvalidArgument(
"When the input of 'strided_slice_grad_op' is `TensorArray`, the "
"dimension of start index should be 1, but received %d.",
starts_indices.size()));
PADDLE_ENFORCE_EQ(
ends_indices.size(), 1,
platform::errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of end index should be 1, but received %d.",
ends_indices.size()));
PADDLE_ENFORCE_EQ(
strides_indices.size(), 1,
platform::errors::InvalidArgument(
"When the input of 'strided_slice_grad_op' is `TensorArray`, the "
"dimension of stride should be 1, but received %d.",
strides_indices.size()));
auto* d_input_var = context.InputVar(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(
d_input_var->IsType<LoDTensorArray>(), true,
platform::errors::InvalidArgument(
"When the output of `strided_slice_grad_op` is "
"`TensorArray`, the input is excepted `TensorArray` , "
"but received %s.",
framework::ToTypeName(d_input_var->Type())));
PADDLE_ENFORCE_EQ(
out_dims.size(), 1,
platform::errors::InvalidArgument(
"When the output of `strided_slice_grad_op` is `TensorArray`, "
"the dimension of output should be 1, but received %d.",
out_dims.size()));
auto& d_in_array = d_input_var->Get<framework::LoDTensorArray>();
auto* d_out_array = context.Output<framework::LoDTensorArray>(
framework::GradVarName("Input"));
d_out_array->resize(out_dims[0]);
auto const d_out_array_size = d_out_array->size();
auto* input_tensor_array =
context.Input<framework::LoDTensorArray>("Input");
for (size_t j = 0; j < d_out_array_size; j++) {
auto& dim = input_tensor_array->at(j).dims();
auto* d_out_tensor = &d_out_array->at(j);
int64_t sub = j - starts_indices[0];
int64_t in_offset = sub / strides_indices[0];
if (need_reverse) {
in_offset = d_in_array.size() - in_offset - 1;
}
if ((sub % strides_indices[0] == 0) && (0 <= in_offset) &&
(static_cast<size_t>(in_offset) < d_in_array.size())) {
auto& in_tensor = d_in_array.at(in_offset);
PADDLE_ENFORCE_GT(
in_tensor.memory_size(), 0,
platform::errors::PreconditionNotMet(
"The input LoDTensorArray Input[%d] holds no memory.",
in_offset));
d_out_tensor->set_lod(in_tensor.lod());
paddle::framework::TensorCopy(in_tensor, context.GetPlace(),
d_out_tensor);
} else {
d_out_tensor->Resize(dim);
if (!d_out_tensor->IsInitialized()) {
d_out_tensor->mutable_data<T>(context.GetPlace());
}
phi::funcs::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, d_out_tensor, static_cast<T>(0));
}
}
} else {
auto* d_input =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_out =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
d_out->mutable_data<T>(context.GetPlace());
phi::funcs::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, d_out, static_cast<T>(0));
auto in_dims = d_input->dims();
auto in_t = framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(*d_input);
auto out_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(*d_out, out_dims);
if (need_reverse) {
framework::Tensor reverse_input;
reverse_input.mutable_data<T>(in_dims, context.GetPlace());
auto reverse_in_t =
framework::EigenTensor<T, D, Eigen::RowMajor,
Eigen::DenseIndex>::From(reverse_input);
reverse_in_t.device(place) = in_t.reverse(reverse_axis);
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(place) = reverse_in_t;
} else {
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(place) = in_t;
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
#include "paddle/fluid/operators/slice_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
......@@ -112,16 +112,16 @@ class StridedSliceNPUKernel : public framework::OpKernel<T> {
// out dims calculation
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
decrease_axis, out_dims_vector.data(), axes.size(),
false);
phi::funcs::StridedSliceOutDims(starts, ends, strides, axes, infer_flags,
in_dims, decrease_axis,
out_dims_vector.data(), axes.size(), false);
framework::DDim out_dims(phi::make_ddim(out_dims_vector));
// check whether need to reverse (false: stride > 0; true: stride < 0)
std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), in_dims, infer_flags,
decrease_axis, starts.size());
phi::funcs::StridedSliceFunctor(starts.data(), ends.data(), strides.data(),
axes.data(), reverse_vector.data(), in_dims,
infer_flags, decrease_axis, starts.size());
// construct the starts_indices, ends_indices and strides_indices tensor for
// calling StridedSlice op
......@@ -317,14 +317,15 @@ class StridedSliceGradNPUKernel : public framework::OpKernel<T> {
}
std::vector<int64_t> out_dims_vector(input_dims.size(), -1);
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, input_dims,
decrease_axis, out_dims_vector.data(), axes.size(),
false);
phi::funcs::StridedSliceOutDims(starts, ends, strides, axes, infer_flags,
input_dims, decrease_axis,
out_dims_vector.data(), axes.size(), false);
std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(),
reverse_vector.data(), input_dims, infer_flags,
decrease_axis, starts.size());
phi::funcs::StridedSliceFunctor(starts.data(), ends.data(), strides.data(),
axes.data(), reverse_vector.data(),
input_dims, infer_flags, decrease_axis,
starts.size());
std::vector<int64_t> starts_indices_vector(D, 0);
std::vector<int64_t> ends_indices_vector(out_dims_vector.begin(),
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
......@@ -1708,6 +1709,136 @@ void SqueezeInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void StridedSliceInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
MetaTensor* out,
MetaConfig config) {
auto in_dims = x.dims();
PADDLE_ENFORCE_LT(
in_dims.size(),
7,
errors::InvalidArgument(
"The dimension of StridedSlice operator's input should be less "
"than 7, but received dimension is %d.",
in_dims.size()));
auto starts_ = starts.GetData();
auto ends_ = ends.GetData();
auto strides_ = strides.GetData();
auto starts_size = starts_.size();
auto ends_size = ends_.size();
auto strides_size = strides_.size();
for (size_t i = 0; i < axes.size(); ++i) {
PADDLE_ENFORCE_GE(
axes[i],
0,
errors::InvalidArgument("The axis should be greater than or equal to 0."
"But received %d of axes[%d]",
axes[i],
i));
PADDLE_ENFORCE_LT(
axes[i],
in_dims.size(),
errors::InvalidArgument(
"The axes should be less than or equal to input tensor's rank."
"But received %d of axes[%d], input tensor shape [%d]",
axes[i],
i,
in_dims.size()));
}
auto tensor_input = false;
auto HasInput = [](const ScalarArray& arr) { return arr.FromTensor(); };
if (HasInput(starts) || HasInput(ends) || HasInput(strides)) {
tensor_input = true;
}
if (!HasInput(ends)) {
PADDLE_ENFORCE_EQ(
ends_size,
axes.size(),
errors::InvalidArgument(
"The size of ends attribute in StridedSlice operator is not "
"equal to the size of axes attribute. The ends attribute's size "
"is %d, axes attribute's size is %d.",
ends_size,
axes.size()));
}
if (!HasInput(starts)) {
PADDLE_ENFORCE_EQ(
starts_size,
axes.size(),
errors::InvalidArgument(
"The size of starts attribute in StridedSlice operator is not "
"equal to the size of axes attribute. The starts attribute's "
"size is %d, axes attribute's size is %d.",
starts_size,
axes.size()));
}
if (!HasInput(strides)) {
PADDLE_ENFORCE_EQ(
strides_size,
axes.size(),
errors::InvalidArgument(
"The size of strides attribute in StridedSlice operator is not "
"equal to the size of axes attribute. The strides attribute's "
"size is %d, axes attribute's size is %d.",
strides_size,
axes.size()));
}
// we need to analysis strided slice op is valid for
// the parameter that we get from python front
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
if (!tensor_input || config.is_runtime) {
phi::funcs::StridedSliceOutDims(starts_,
ends_,
strides_,
axes,
infer_flags,
in_dims,
decrease_axis,
out_dims_vector.data(),
axes.size(),
true);
}
DDim out_dims(phi::make_ddim(out_dims_vector));
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
if (config.is_runtime && infer_flags[i] != -1) {
PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]],
1,
errors::InvalidArgument(
"the size of decrease dimension should be 1, "
"but received %d.",
out_dims[decrease_axis[i]]));
}
out_dims[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims.size(); ++i) {
if (out_dims[i] != 0) {
new_out_shape.push_back(out_dims[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims = phi::make_ddim(new_out_shape);
}
VLOG(1) << "out_dims: " << out_dims;
out->set_dims(out_dims);
out->share_lod(x);
out->set_dtype(x.dtype());
}
/* Why not use SumRawInferMeta directly?
Because we need make InferMetaFunction's args follow the design of api.yaml
*/
......
......@@ -267,6 +267,16 @@ void SqueezeInferMeta(const MetaTensor& x,
MetaTensor* xshape,
MetaTensor* out);
void StridedSliceInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
MetaTensor* out,
MetaConfig config = MetaConfig());
void SumInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis,
DataType dtype,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/strided_slice_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice_grad,
CPU,
ALL_LAYOUT,
phi::StridedSliceGradKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(strided_slice_array_grad,
CPU,
ALL_LAYOUT,
phi::StridedSliceArrayGradKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/strided_slice_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice,
CPU,
ALL_LAYOUT,
phi::StridedSliceKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(strided_slice_array,
CPU,
ALL_LAYOUT,
phi::StridedSliceArrayKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
namespace funcs {
static void StridedSliceOutDims(const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& strides,
const std::vector<int>& axes,
const std::vector<int>& infer_flags,
const DDim in_dims,
const std::vector<int>& decrease_axis,
int64_t* out_dims_vector,
const size_t size,
bool infer_shape) {
for (int i = 0; i < in_dims.size(); i++) {
out_dims_vector[i] = in_dims[i];
}
int64_t stride_index, start_index, end_index;
for (size_t i = 0; i < size; i++) {
int axes_index = axes[i];
start_index = starts[i];
end_index = ends[i];
stride_index = strides[i];
bool decrease_axis_affect = false;
if (start_index == -1 && end_index == 0 && infer_flags[i] == -1) {
auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
decrease_axis_affect = true;
}
}
if (decrease_axis_affect) {
out_dims_vector[axes_index] = 1;
continue;
}
if (infer_shape && infer_flags[i] == -1) {
out_dims_vector[axes_index] = -1;
continue;
}
PADDLE_ENFORCE_NE(
stride_index,
0,
errors::InvalidArgument("stride index in StridedSlice operator is 0."));
int64_t axis_size = in_dims[axes_index];
if (axis_size < 0) {
continue;
}
if (start_index < 0) {
start_index = start_index + axis_size;
}
if (end_index < 0) {
if (!(end_index == -1 && stride_index < 0)) { // skip None stop condition
end_index = end_index + axis_size;
}
}
if (stride_index < 0) {
start_index = start_index + 1;
end_index = end_index + 1;
}
bool neg_dim_condition = ((stride_index < 0 && (start_index < end_index)) ||
(stride_index > 0 && (start_index > end_index)));
PADDLE_ENFORCE_EQ(neg_dim_condition,
false,
errors::InvalidArgument(
"The start index and end index are invalid for their "
"corresponding stride."));
int64_t left =
std::max(static_cast<int64_t>(0), std::min(start_index, end_index));
int64_t right = std::min(axis_size, std::max(start_index, end_index));
int64_t step = std::abs(stride_index);
auto out_dims_index = (std::abs(right - left) + step - 1) / step;
out_dims_vector[axes_index] = out_dims_index;
}
}
static void StridedSliceFunctor(int64_t* starts,
int64_t* ends,
int64_t* strides,
const int* axes,
int* reverse_axis,
const DDim dims,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
const size_t size) {
for (size_t axis = 0; axis < size; axis++) {
int64_t axis_size = dims[axes[axis]];
int axis_index = axis;
if (axis_size < 0) {
starts[axis_index] = 0;
ends[axis_index] = 1;
strides[axis_index] = 1;
}
bool decrease_axis_affect = false;
if (starts[axis_index] == -1 && ends[axis_index] == 0 &&
infer_flags[axis_index] == -1) {
auto ret = std::find(
decrease_axis.begin(), decrease_axis.end(), axes[axis_index]);
if (ret != decrease_axis.end()) {
decrease_axis_affect = true;
}
}
// stride must not be zero
if (starts[axis_index] < 0) {
starts[axis_index] = starts[axis_index] + axis_size;
starts[axis_index] = std::max<int64_t>(starts[axis_index], 0);
}
if (ends[axis_index] < 0) {
if (!(ends[axis_index] == -1 &&
strides[axis_index] < 0)) { // skip None stop condition
ends[axis_index] = ends[axis_index] + axis_size;
if (ends[axis_index] < 0) {
ends[axis_index] = 0;
}
}
}
if (decrease_axis_affect) {
if (strides[axis_index] < 0) {
ends[axis_index] = starts[axis_index] - 1;
} else {
ends[axis_index] = starts[axis_index] + 1;
}
}
if (strides[axis_index] < 0) {
reverse_axis[axis_index] = 1;
strides[axis_index] = -strides[axis_index];
if (starts[axis_index] > ends[axis_index]) {
// swap the reverse
auto end_dim = axis_size - 1 < starts[axis_index] ? axis_size - 1
: starts[axis_index];
auto offset = (end_dim - ends[axis_index]) % strides[axis_index];
offset = offset == 0 ? strides[axis_index] : offset;
starts[axis_index] = starts[axis_index] + offset;
ends[axis_index] = ends[axis_index] + offset;
}
std::swap(starts[axis_index], ends[axis_index]);
} else {
reverse_axis[axis_index] = 0;
strides[axis_index] = strides[axis_index];
}
}
}
template <typename Context, typename T, size_t D>
void StridedSliceCompute(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* out) {
auto& place = *dev_ctx.eigen_device();
DDim in_dims = x.dims();
auto starts_ = starts.GetData();
auto ends_ = ends.GetData();
auto strides_ = strides.GetData();
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto reverse_axis = Eigen::array<bool, D>();
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
StridedSliceOutDims(starts_,
ends_,
strides_,
axes,
infer_flags,
in_dims,
decrease_axis,
out_dims_vector.data(),
axes.size(),
false);
DDim out_dims(phi::make_ddim(out_dims_vector));
std::vector<int> reverse_vector(starts_.size(), 0);
StridedSliceFunctor(starts_.data(),
ends_.data(),
strides_.data(),
axes.data(),
reverse_vector.data(),
in_dims,
infer_flags,
decrease_axis,
starts_.size());
for (size_t axis = 0; axis < D; axis++) {
starts_indices[axis] = 0;
ends_indices[axis] = out_dims[axis];
strides_indices[axis] = 1;
reverse_axis[axis] = false;
}
for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis];
starts_indices[axis_index] = starts_[axis];
ends_indices[axis_index] = ends_[axis];
strides_indices[axis_index] = strides_[axis];
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
}
auto out_dims_origin = out_dims;
if (decrease_axis.size() > 0) {
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
PADDLE_ENFORCE_EQ(
out_dims[decrease_axis[i]],
1,
errors::InvalidArgument(
"the size of decrease dimension should be 1, but received %d.",
out_dims[decrease_axis[i]]));
out_dims_origin[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims_origin.size(); ++i) {
if (out_dims_origin[i] != 0) {
new_out_shape.push_back(out_dims_origin[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims_origin = phi::make_ddim(new_out_shape);
}
bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
out->Resize(out_dims);
dev_ctx.template Alloc<T>(out);
auto in_t = EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(x);
auto out_t = EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*out, out_dims);
if (need_reverse) {
DenseTensor tmp;
tmp.Resize(out_dims);
dev_ctx.template Alloc<T>(&tmp);
auto tmp_t =
EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(tmp);
tmp_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
out_t.device(place) = tmp_t.reverse(reverse_axis);
} else {
out_t.device(place) =
in_t.stridedSlice(starts_indices, ends_indices, strides_indices);
}
if (decrease_axis.size() > 0) {
out->Resize(out_dims_origin);
}
}
template <typename Context, typename T, size_t D>
void StridedSliceCompute(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
std::vector<DenseTensor*> out) {
const int64_t size = x.size();
auto in_dims = phi::make_ddim({size});
auto starts_ = starts.GetData();
auto ends_ = ends.GetData();
auto strides_ = strides.GetData();
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto reverse_axis = Eigen::array<bool, D>();
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
StridedSliceOutDims(starts_,
ends_,
strides_,
axes,
infer_flags,
in_dims,
decrease_axis,
out_dims_vector.data(),
axes.size(),
false);
DDim out_dims(phi::make_ddim(out_dims_vector));
std::vector<int> reverse_vector(starts_.size(), 0);
StridedSliceFunctor(starts_.data(),
ends_.data(),
strides_.data(),
axes.data(),
reverse_vector.data(),
in_dims,
infer_flags,
decrease_axis,
starts_.size());
for (size_t axis = 0; axis < D; axis++) {
starts_indices[axis] = 0;
ends_indices[axis] = out_dims[axis];
strides_indices[axis] = 1;
reverse_axis[axis] = false;
}
for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis];
starts_indices[axis_index] = starts_[axis];
ends_indices[axis_index] = ends_[axis];
strides_indices[axis_index] = strides_[axis];
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
}
auto out_dims_origin = out_dims;
if (decrease_axis.size() > 0) {
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
PADDLE_ENFORCE_EQ(
out_dims[decrease_axis[i]],
1,
errors::InvalidArgument(
"the size of decrease dimension should be 1, but received %d.",
out_dims[decrease_axis[i]]));
out_dims_origin[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims_origin.size(); ++i) {
if (out_dims_origin[i] != 0) {
new_out_shape.push_back(out_dims_origin[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims_origin = phi::make_ddim(new_out_shape);
}
bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
PADDLE_ENFORCE_EQ(
starts_indices.size(),
1,
errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of start index should be 1, but received %d.",
starts_indices.size()));
PADDLE_ENFORCE_EQ(
ends_indices.size(),
1,
errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of end index should be 1, but received %d.",
ends_indices.size()));
PADDLE_ENFORCE_EQ(
strides_indices.size(),
1,
errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of stride should be 1, but received %d.",
strides_indices.size()));
PADDLE_ENFORCE_EQ(
out_dims_origin.size(),
1,
errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of Output should be 1, but received %d",
out_dims_origin.size()));
out.resize(out_dims_origin[0]);
size_t const in_array_size = x.size();
for (size_t i = 0; i < out.size(); i++) {
size_t in_offset =
(starts_indices[0] % in_array_size) + i * strides_indices[0];
int64_t out_offset = i;
if (need_reverse) {
out_offset = out.size() - i - 1;
}
auto* in_tensor = x.at(in_offset);
PADDLE_ENFORCE_GT(
in_tensor->memory_size(),
0,
errors::PreconditionNotMet(
"The input LoDTensorArray Input[%d] holds no memory.", in_offset));
auto* out_tensor = out.at(out_offset);
out_tensor->Resize(in_tensor->dims());
phi::Copy<Context>(
dev_ctx, *in_tensor, dev_ctx.GetPlace(), false, out_tensor);
out_tensor->set_lod(in_tensor->lod());
}
}
template <typename Context, typename T, size_t D>
void StridedSliceGradCompute(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* x_grad) {
auto& place = *dev_ctx.eigen_device();
DDim out_dims = x.dims();
auto starts_ = starts.GetData();
auto ends_ = ends.GetData();
auto strides_ = strides.GetData();
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto reverse_axis = Eigen::array<bool, D>();
std::vector<int> reverse_vector(starts_.size(), 0);
StridedSliceFunctor(starts_.data(),
ends_.data(),
strides_.data(),
axes.data(),
reverse_vector.data(),
out_dims,
infer_flags,
decrease_axis,
starts_.size());
for (size_t axis = 0; axis < D; axis++) {
starts_indices[axis] = 0;
ends_indices[axis] = out_dims[axis];
strides_indices[axis] = 1;
}
for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis];
starts_indices[axis_index] = starts_[axis];
ends_indices[axis_index] = ends_[axis];
strides_indices[axis_index] = strides_[axis];
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
}
bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, x_grad, static_cast<T>(0));
auto out_grad_dims = out_grad.dims();
auto in_t =
EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(out_grad);
auto out_t = EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*x_grad, out_dims);
if (need_reverse) {
DenseTensor reverse_input;
reverse_input.Resize(out_grad_dims);
dev_ctx.template Alloc<T>(&reverse_input);
auto reverse_in_t =
EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
reverse_input);
reverse_in_t.device(place) = in_t.reverse(reverse_axis);
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(place) = reverse_in_t;
} else {
out_t.stridedSlice(starts_indices, ends_indices, strides_indices)
.device(place) = in_t;
}
}
template <typename Context, typename T, size_t D>
void StridedSliceGradCompute(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<const DenseTensor*>& out_grad,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
std::vector<DenseTensor*> x_grad) {
// Note(weixin):Since the shape of `framework::GradVarName("Input")` of
// StridedSliceGrad cannot be calculated by
// `framework::GradVarName("Output")`, the dim of "Input" is used to
// calculate the output shape. when set it to inplace OP, there may be
// some problems.
const int64_t size = x.size();
DDim out_dims = phi::make_ddim({size});
auto starts_ = starts.GetData();
auto ends_ = ends.GetData();
auto strides_ = strides.GetData();
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto strides_indices = Eigen::DSizes<Eigen::DenseIndex, D>();
auto reverse_axis = Eigen::array<bool, D>();
std::vector<int> reverse_vector(starts_.size(), 0);
StridedSliceFunctor(starts_.data(),
ends_.data(),
strides_.data(),
axes.data(),
reverse_vector.data(),
out_dims,
infer_flags,
decrease_axis,
starts_.size());
for (size_t axis = 0; axis < D; axis++) {
starts_indices[axis] = 0;
ends_indices[axis] = out_dims[axis];
strides_indices[axis] = 1;
}
for (size_t axis = 0; axis < axes.size(); axis++) {
int axis_index = axes[axis];
starts_indices[axis_index] = starts_[axis];
ends_indices[axis_index] = ends_[axis];
strides_indices[axis_index] = strides_[axis];
reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false;
}
bool need_reverse = false;
for (size_t axis = 0; axis < axes.size(); axis++) {
if (reverse_vector[axis] == 1) {
need_reverse = true;
break;
}
}
PADDLE_ENFORCE_EQ(
starts_indices.size(),
1,
errors::InvalidArgument(
"When the input of 'strided_slice_grad_op' is `TensorArray`, the "
"dimension of start index should be 1, but received %d.",
starts_indices.size()));
PADDLE_ENFORCE_EQ(
ends_indices.size(),
1,
errors::InvalidArgument(
"When the input of 'strided_slice_op' is `TensorArray`, the "
"dimension of end index should be 1, but received %d.",
ends_indices.size()));
PADDLE_ENFORCE_EQ(
strides_indices.size(),
1,
errors::InvalidArgument(
"When the input of 'strided_slice_grad_op' is `TensorArray`, the "
"dimension of stride should be 1, but received %d.",
strides_indices.size()));
PADDLE_ENFORCE_EQ(
out_dims.size(),
1,
errors::InvalidArgument(
"When the output of `strided_slice_grad_op` is `TensorArray`, "
"the dimension of output should be 1, but received %d.",
out_dims.size()));
auto const d_out_array_size = x_grad.size();
for (size_t j = 0; j < d_out_array_size; j++) {
auto& dim = x.at(j)->dims();
auto* d_out_tensor = x_grad.at(j);
int64_t sub = j - starts_indices[0];
int64_t in_offset = sub / strides_indices[0];
if (need_reverse) {
in_offset = out_grad.size() - in_offset - 1;
}
if ((sub % strides_indices[0] == 0) && (0 <= in_offset) &&
(static_cast<size_t>(in_offset) < out_grad.size())) {
auto* in_tensor = out_grad.at(in_offset);
PADDLE_ENFORCE_GT(
in_tensor->memory_size(),
0,
errors::PreconditionNotMet(
"The input LoDTensorArray Input[%d] holds no memory.",
in_offset));
phi::Copy<Context>(
dev_ctx, *in_tensor, dev_ctx.GetPlace(), false, d_out_tensor);
d_out_tensor->set_lod(in_tensor->lod());
} else {
d_out_tensor->Resize(dim);
if (!d_out_tensor->IsInitialized()) {
dev_ctx.template Alloc<T>(d_out_tensor);
}
phi::funcs::SetConstant<Context, T> set_zero;
set_zero(dev_ctx, d_out_tensor, static_cast<T>(0));
}
}
}
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/strided_slice_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice_grad,
GPU,
ALL_LAYOUT,
phi::StridedSliceGradKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(strided_slice_array_grad,
GPU,
ALL_LAYOUT,
phi::StridedSliceArrayGradKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/strided_slice_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice,
GPU,
ALL_LAYOUT,
phi::StridedSliceKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(strided_slice_array,
GPU,
ALL_LAYOUT,
phi::StridedSliceArrayKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -22,8 +22,7 @@
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/operators/strided_slice_op.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
namespace phi {
......@@ -73,29 +72,29 @@ void SetValueGradImpl(const Context& dev_ctx,
std::vector<int64_t> starts_local = starts.GetData();
std::vector<int64_t> ends_local = ends.GetData();
std::vector<int64_t> steps_local = steps.GetData();
paddle::operators::StridedSliceOutDims(starts_local,
ends_local,
steps_local,
axes_int32,
infer_flags,
in_dims,
decrease_axis_int32,
out_dims_vector.data(),
axes.size(),
false);
funcs::StridedSliceOutDims(starts_local,
ends_local,
steps_local,
axes_int32,
infer_flags,
in_dims,
decrease_axis_int32,
out_dims_vector.data(),
axes.size(),
false);
DDim out_dims(phi::make_ddim(out_dims_vector));
std::vector<int> reverse_vector(starts_local.size(), 0);
paddle::operators::StridedSliceFunctor(starts_local.data(),
ends_local.data(),
steps_local.data(),
axes_int32.data(),
reverse_vector.data(),
in_dims,
infer_flags,
decrease_axis_int32,
starts_local.size());
funcs::StridedSliceFunctor(starts_local.data(),
ends_local.data(),
steps_local.data(),
axes_int32.data(),
reverse_vector.data(),
in_dims,
infer_flags,
decrease_axis_int32,
starts_local.size());
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/strided_slice_grad_kernel.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
namespace phi {
template <typename T, typename Context>
void StridedSliceGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* x_grad) {
int rank = x.dims().size();
#define SLICE_CASE(Rank) \
case Rank: \
funcs::StridedSliceGradCompute<Context, T, Rank>(dev_ctx, \
x, \
out_grad, \
axes, \
starts, \
ends, \
strides, \
infer_flags, \
decrease_axis, \
x_grad); \
break;
switch (rank) {
SLICE_CASE(1)
SLICE_CASE(2)
SLICE_CASE(3)
SLICE_CASE(4)
SLICE_CASE(5)
SLICE_CASE(6)
}
#undef SLICE_CASE
}
template <typename T, typename Context>
void StridedSliceArrayGradKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<const DenseTensor*>& out_grad,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
std::vector<DenseTensor*> x_grad) {
funcs::StridedSliceGradCompute<Context, T, 1>(dev_ctx,
x,
out_grad,
axes,
starts,
ends,
strides,
infer_flags,
decrease_axis,
x_grad);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/strided_slice_kernel.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
namespace phi {
template <typename T, typename Context>
void StridedSliceKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* out) {
int rank = x.dims().size();
#define SLICE_CASE(Rank) \
case Rank: \
funcs::StridedSliceCompute<Context, T, Rank>(dev_ctx, \
x, \
axes, \
starts, \
ends, \
strides, \
infer_flags, \
decrease_axis, \
out); \
break;
switch (rank) {
SLICE_CASE(1)
SLICE_CASE(2)
SLICE_CASE(3)
SLICE_CASE(4)
SLICE_CASE(5)
SLICE_CASE(6)
}
#undef SLICE_CASE
}
template <typename T, typename Context>
void StridedSliceArrayKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
std::vector<DenseTensor*> out) {
funcs::StridedSliceCompute<Context, T, 1>(
dev_ctx, x, axes, starts, ends, strides, infer_flags, decrease_axis, out);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void StridedSliceGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* x_grad);
template <typename T, typename Context>
void StridedSliceArrayGradKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<const DenseTensor*>& out_grad,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
std::vector<DenseTensor*> x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void StridedSliceKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* out);
template <typename T, typename Context>
void StridedSliceArrayKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
std::vector<DenseTensor*> out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h"
namespace phi {
KernelSignature StridedSliceOpArgumentMapping(
const ArgumentMappingContext& ctx) {
const auto& starts = paddle::any_cast<std::vector<int>>(ctx.Attr("starts"));
const auto& ends = paddle::any_cast<std::vector<int>>(ctx.Attr("ends"));
const auto& strides = paddle::any_cast<std::vector<int>>(ctx.Attr("strides"));
bool use_attr_starts = !ctx.IsRuntime() && !starts.empty();
bool use_attr_ends = !ctx.IsRuntime() && !ends.empty();
bool use_attr_strides = !ctx.IsRuntime() && !strides.empty();
std::string starts_key =
ctx.HasInput("StartsTensor")
? "StartsTensor"
: (ctx.InputSize("StartsTensorList") > 0
? (use_attr_starts ? "starts" : "StartsTensorList")
: "starts");
std::string ends_key =
ctx.HasInput("EndsTensor")
? "EndsTensor"
: (ctx.InputSize("EndsTensorList") > 0
? (use_attr_ends ? "ends" : "EndsTensorList")
: "ends");
std::string strides_key =
ctx.HasInput("StridesTensor")
? "StridesTensor"
: (ctx.InputSize("StridesTensorList") > 0
? (use_attr_strides ? "strides" : "StridesTensorList")
: "strides");
paddle::SmallVector<std::string> inputs = {"Input"};
paddle::SmallVector<std::string> attrs = {"axes",
starts_key,
ends_key,
strides_key,
"infer_flags",
"decrease_axis"};
paddle::SmallVector<std::string> outputs = {"Out"};
std::string op_type;
if (ctx.IsDenseTensorVectorInput("Input")) {
op_type = "strided_slice_array";
} else {
op_type = "strided_slice";
}
// NOTE(dev): Use this to avoid regularization.
KernelSignature sig(op_type, inputs, attrs, outputs);
return sig;
}
KernelSignature StridedSliceGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
const auto& starts = paddle::any_cast<std::vector<int>>(ctx.Attr("starts"));
const auto& ends = paddle::any_cast<std::vector<int>>(ctx.Attr("ends"));
const auto& strides = paddle::any_cast<std::vector<int>>(ctx.Attr("strides"));
bool use_attr_starts = !ctx.IsRuntime() && !starts.empty();
bool use_attr_ends = !ctx.IsRuntime() && !ends.empty();
bool use_attr_strides = !ctx.IsRuntime() && !strides.empty();
std::string starts_key =
ctx.HasInput("StartsTensor")
? "StartsTensor"
: (ctx.InputSize("StartsTensorList") > 0
? (use_attr_starts ? "starts" : "StartsTensorList")
: "starts");
std::string ends_key =
ctx.HasInput("EndsTensor")
? "EndsTensor"
: (ctx.InputSize("EndsTensorList") > 0
? (use_attr_ends ? "ends" : "EndsTensorList")
: "ends");
std::string strides_key =
ctx.HasInput("StridesTensor")
? "StridesTensor"
: (ctx.InputSize("StridesTensorList") > 0
? (use_attr_strides ? "strides" : "StridesTensorList")
: "strides");
paddle::SmallVector<std::string> inputs = {"Input", GradVarName("Out")};
paddle::SmallVector<std::string> attrs = {"axes",
starts_key,
ends_key,
strides_key,
"infer_flags",
"decrease_axis"};
paddle::SmallVector<std::string> outputs = {GradVarName("Input")};
std::string op_type;
if (ctx.IsDenseTensorVectorInput("Input")) {
op_type = "strided_slice_array_grad";
} else {
op_type = "strided_slice_grad";
}
// NOTE(dev): Use this to avoid regularization.
KernelSignature sig(op_type, inputs, attrs, outputs);
return sig;
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(strided_slice, phi::StridedSliceOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(strided_slice_grad,
phi::StridedSliceGradOpArgumentMapping);
/*
******************************************************************
NOTE: The following codes are for 'get_compat_kernel_signature.py'
DO NOT EDIT IT if you don't know the mechanism.
******************************************************************
############################ Forward ############################
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensor", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensor", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensor", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensor", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensor", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensor", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensor", "ends", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensorList", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensorList", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensorList", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensorList", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensorList", "EndsTensorList",
"starts","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensorList", "ends",
"StartsTensorList","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "StartsTensorList", "ends", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "starts", "EndsTensor", "StartsTensor","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "starts", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "starts", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "starts", "EndsTensorList", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "starts", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "starts", "ends", "StartsTensorList","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice}", {"Input"},
{"axes", "starts", "ends", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensor", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensor", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensor", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensor", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensor", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensor", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensor", "ends", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensorList", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensorList", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensorList", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensorList", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensorList", "EndsTensorList",
"starts","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensorList", "ends",
"StartsTensorList","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "StartsTensorList", "ends", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "starts", "EndsTensor", "StartsTensor","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "starts", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "starts", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "starts", "EndsTensorList", "starts","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "starts", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "starts", "ends", "StartsTensorList","infer_flags",
"decrease_axis"},
{"Out"});
return KernelSignature("{strided_slice_array}", {"Input"},
{"axes", "starts", "ends", "starts","infer_flags",
"decrease_axis"},
{"Out"});
############################ Backward ############################
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensor", "ends", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensorList",
"starts","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "ends",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "ends", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensor", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensorList", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "ends", "StartsTensorList","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_grad}", {"Input", GradVarName("Out")},
{"axes", "starts", "ends", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "EndsTensorList", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "ends", "StartsTensorList","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensor", "ends", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensor",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensor",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensorList",
"StartsTensor","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "EndsTensorList",
"starts","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "ends",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "StartsTensorList", "ends", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "EndsTensor", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "EndsTensor", "StartsTensorList","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "EndsTensor", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "EndsTensorList", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "EndsTensorList",
"StartsTensorList","infer_flags", "decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "EndsTensorList", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "ends", "StartsTensor","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "ends", "StartsTensorList","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
return KernelSignature("{strided_slice_array_grad}", {"Input",
GradVarName("Out")},
{"axes", "starts", "ends", "starts","infer_flags",
"decrease_axis"},
{GradVarName("Input")});
*/
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册