未验证 提交 c33b4f95 编写于 作者: A Aurelius84 提交者: GitHub

[Phi] Migrate strided_slice into Phi (#40708)

* [Phi] Migrate strided_slice into Phi

* [Phi] Migrate strided_slice into Phi

* fix compilation problem
上级 fd0c0e3c
...@@ -12,12 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,12 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h"
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/slice_op.h" #include "paddle/fluid/operators/slice_op.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -28,149 +33,6 @@ class StridedSliceOp : public framework::OperatorWithKernel { ...@@ -28,149 +33,6 @@ class StridedSliceOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "StridedSlice");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "StridedSlice");
auto input_var_type = ctx->GetInputsVarType("Input")[0];
if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
if (ctx->IsRuntime()) {
// shape is determined by Runtime.
return;
}
}
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_LT(
in_dims.size(), 7,
platform::errors::InvalidArgument(
"The dimension of StridedSlice operator's input should be less "
"than 7, but received dimension is %d.",
in_dims.size()));
auto starts_int = ctx->Attrs().Get<std::vector<int>>("starts");
auto ends_int = ctx->Attrs().Get<std::vector<int>>("ends");
auto strides_int = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
std::vector<int64_t> strides(strides_int.begin(), strides_int.end());
auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis");
auto starts_size = starts.size();
auto ends_size = ends.size();
auto strides_size = strides.size();
for (size_t i = 0; i < axes.size(); ++i) {
PADDLE_ENFORCE_GE(axes[i], 0,
platform::errors::InvalidArgument(
"The axis should be greater than or equal to 0."
"But received %d of axes[%d]",
axes[i], i));
PADDLE_ENFORCE_LT(
axes[i], in_dims.size(),
platform::errors::InvalidArgument(
"The axes should be less than or equal to input tensor's rank."
"But received %d of axes[%d], input tensor shape [%d]",
axes[i], i, in_dims.size()));
}
if (ctx->HasInputs("StartsTensorList")) {
auto StartsTensorList = ctx->Inputs("StartsTensorList");
PADDLE_ENFORCE_GT(
StartsTensorList.size(), 0,
platform::errors::InvalidArgument(
"StridedSlice operator's StartsTensorList is empty."));
starts_size = StartsTensorList.size();
}
if (ctx->HasInputs("EndsTensorList")) {
auto EndsTensorList = ctx->Inputs("EndsTensorList");
PADDLE_ENFORCE_GT(
EndsTensorList.size(), 0,
platform::errors::InvalidArgument(
"StridedSlice operator's EndsTensorList is empty."));
ends_size = EndsTensorList.size();
}
if (ctx->HasInputs("StridesTensorList")) {
auto StridesTensorList = ctx->Inputs("StridesTensorList");
PADDLE_ENFORCE_GT(
StridesTensorList.size(), 0,
platform::errors::InvalidArgument(
"StridedSlice operator's StridesTensorList is empty."));
strides_size = StridesTensorList.size();
}
auto tensor_input = false;
if (ctx->HasInput("EndsTensor") || ctx->HasInput("StartsTensor") ||
ctx->HasInput("StridesTensor")) {
tensor_input = true;
}
if (!ctx->HasInput("EndsTensor")) {
PADDLE_ENFORCE_EQ(
ends_size, axes.size(),
platform::errors::InvalidArgument(
"The size of ends attribute in StridedSlice operator is not "
"equal to the size of axes attribute. The ends attribute's size "
"is %d, axes attribute's size is %d.",
ends_size, axes.size()));
}
if (!ctx->HasInput("StartsTensor")) {
PADDLE_ENFORCE_EQ(
starts_size, axes.size(),
platform::errors::InvalidArgument(
"The size of starts attribute in StridedSlice operator is not "
"equal to the size of axes attribute. The starts attribute's "
"size is %d, axes attribute's size is %d.",
starts_size, axes.size()));
}
if (!ctx->HasInput("StridesTensor")) {
PADDLE_ENFORCE_EQ(
strides_size, axes.size(),
platform::errors::InvalidArgument(
"The size of strides attribute in StridedSlice operator is not "
"equal to the size of axes attribute. The strides attribute's "
"size is %d, axes attribute's size is %d.",
strides_size, axes.size()));
}
// we need to analysis strided slice op is valid for
// the parameter that we get from python front
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
if (!tensor_input) {
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
decrease_axis, out_dims_vector.data(), axes.size(),
true);
}
framework::DDim out_dims(phi::make_ddim(out_dims_vector));
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
if (ctx->IsRuntime() && infer_flags[i] != -1) {
PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
platform::errors::InvalidArgument(
"the size of decrease dimension should be 1, "
"but received %d.",
out_dims[decrease_axis[i]]));
}
out_dims[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims.size(); ++i) {
if (out_dims[i] != 0) {
new_out_shape.push_back(out_dims[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims = phi::make_ddim(new_out_shape);
}
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("Input", /*->*/ "Out");
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
...@@ -304,26 +166,6 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel { ...@@ -304,26 +166,6 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
"StridedSliceGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "StridedSliceGrad");
auto input_var_type = ctx->GetInputsVarType("Input")[0];
if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
if (ctx->IsRuntime()) {
// shape is determined by Runtime
return;
}
}
auto x_dims = ctx->GetInputDim("Input");
auto x_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
...@@ -384,35 +226,19 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer, ...@@ -384,35 +226,19 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(strided_slice, StridedSliceInferShape,
PD_INFER_META(phi::StridedSliceInferMeta));
REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker, REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker,
ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>, ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>,
ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>, ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>,
ops::StridedSliceOpVarTypeInference); ops::StridedSliceOpVarTypeInference, StridedSliceInferShape);
DECLARE_INFER_SHAPE_FUNCTOR(strided_slice_grad, StridedSliceGradInferShape,
PD_INFER_META(phi::GeneralUnaryGradInferMeta));
REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad, REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad,
ops::StridedSliceOpGradNoNeedBufferVarsInferer, ops::StridedSliceOpGradNoNeedBufferVarsInferer,
ops::StridedSliceGradOpVarTypeInference); ops::StridedSliceGradOpVarTypeInference,
StridedSliceGradInferShape);
REGISTER_OP_CPU_KERNEL(
strided_slice,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, bool>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, float>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, double>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
strided_slice_grad,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h"
#include "paddle/fluid/platform/complex.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
strided_slice,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, bool>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
strided_slice_grad,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
此差异已折叠。
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h" #include "paddle/phi/kernels/funcs/strided_slice.h"
#include "paddle/fluid/operators/slice_op.h" #include "paddle/fluid/operators/slice_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
...@@ -112,16 +112,16 @@ class StridedSliceNPUKernel : public framework::OpKernel<T> { ...@@ -112,16 +112,16 @@ class StridedSliceNPUKernel : public framework::OpKernel<T> {
// out dims calculation // out dims calculation
std::vector<int64_t> out_dims_vector(in_dims.size(), -1); std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims, phi::funcs::StridedSliceOutDims(starts, ends, strides, axes, infer_flags,
decrease_axis, out_dims_vector.data(), axes.size(), in_dims, decrease_axis,
false); out_dims_vector.data(), axes.size(), false);
framework::DDim out_dims(phi::make_ddim(out_dims_vector)); framework::DDim out_dims(phi::make_ddim(out_dims_vector));
// check whether need to reverse (false: stride > 0; true: stride < 0) // check whether need to reverse (false: stride > 0; true: stride < 0)
std::vector<int> reverse_vector(starts.size(), 0); std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), phi::funcs::StridedSliceFunctor(starts.data(), ends.data(), strides.data(),
reverse_vector.data(), in_dims, infer_flags, axes.data(), reverse_vector.data(), in_dims,
decrease_axis, starts.size()); infer_flags, decrease_axis, starts.size());
// construct the starts_indices, ends_indices and strides_indices tensor for // construct the starts_indices, ends_indices and strides_indices tensor for
// calling StridedSlice op // calling StridedSlice op
...@@ -317,14 +317,15 @@ class StridedSliceGradNPUKernel : public framework::OpKernel<T> { ...@@ -317,14 +317,15 @@ class StridedSliceGradNPUKernel : public framework::OpKernel<T> {
} }
std::vector<int64_t> out_dims_vector(input_dims.size(), -1); std::vector<int64_t> out_dims_vector(input_dims.size(), -1);
StridedSliceOutDims(starts, ends, strides, axes, infer_flags, input_dims, phi::funcs::StridedSliceOutDims(starts, ends, strides, axes, infer_flags,
decrease_axis, out_dims_vector.data(), axes.size(), input_dims, decrease_axis,
false); out_dims_vector.data(), axes.size(), false);
std::vector<int> reverse_vector(starts.size(), 0); std::vector<int> reverse_vector(starts.size(), 0);
StridedSliceFunctor(starts.data(), ends.data(), strides.data(), axes.data(), phi::funcs::StridedSliceFunctor(starts.data(), ends.data(), strides.data(),
reverse_vector.data(), input_dims, infer_flags, axes.data(), reverse_vector.data(),
decrease_axis, starts.size()); input_dims, infer_flags, decrease_axis,
starts.size());
std::vector<int64_t> starts_indices_vector(D, 0); std::vector<int64_t> starts_indices_vector(D, 0);
std::vector<int64_t> ends_indices_vector(out_dims_vector.begin(), std::vector<int64_t> ends_indices_vector(out_dims_vector.begin(),
......
...@@ -24,6 +24,7 @@ limitations under the License. */ ...@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h" #include "paddle/phi/kernels/funcs/parse_qr_mode.h"
#include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
#include "paddle/phi/kernels/funcs/unfold_functor.h" #include "paddle/phi/kernels/funcs/unfold_functor.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h" #include "paddle/phi/kernels/funcs/unsqueeze.h"
...@@ -1708,6 +1709,136 @@ void SqueezeInferMeta(const MetaTensor& x, ...@@ -1708,6 +1709,136 @@ void SqueezeInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
} }
void StridedSliceInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
MetaTensor* out,
MetaConfig config) {
auto in_dims = x.dims();
PADDLE_ENFORCE_LT(
in_dims.size(),
7,
errors::InvalidArgument(
"The dimension of StridedSlice operator's input should be less "
"than 7, but received dimension is %d.",
in_dims.size()));
auto starts_ = starts.GetData();
auto ends_ = ends.GetData();
auto strides_ = strides.GetData();
auto starts_size = starts_.size();
auto ends_size = ends_.size();
auto strides_size = strides_.size();
for (size_t i = 0; i < axes.size(); ++i) {
PADDLE_ENFORCE_GE(
axes[i],
0,
errors::InvalidArgument("The axis should be greater than or equal to 0."
"But received %d of axes[%d]",
axes[i],
i));
PADDLE_ENFORCE_LT(
axes[i],
in_dims.size(),
errors::InvalidArgument(
"The axes should be less than or equal to input tensor's rank."
"But received %d of axes[%d], input tensor shape [%d]",
axes[i],
i,
in_dims.size()));
}
auto tensor_input = false;
auto HasInput = [](const ScalarArray& arr) { return arr.FromTensor(); };
if (HasInput(starts) || HasInput(ends) || HasInput(strides)) {
tensor_input = true;
}
if (!HasInput(ends)) {
PADDLE_ENFORCE_EQ(
ends_size,
axes.size(),
errors::InvalidArgument(
"The size of ends attribute in StridedSlice operator is not "
"equal to the size of axes attribute. The ends attribute's size "
"is %d, axes attribute's size is %d.",
ends_size,
axes.size()));
}
if (!HasInput(starts)) {
PADDLE_ENFORCE_EQ(
starts_size,
axes.size(),
errors::InvalidArgument(
"The size of starts attribute in StridedSlice operator is not "
"equal to the size of axes attribute. The starts attribute's "
"size is %d, axes attribute's size is %d.",
starts_size,
axes.size()));
}
if (!HasInput(strides)) {
PADDLE_ENFORCE_EQ(
strides_size,
axes.size(),
errors::InvalidArgument(
"The size of strides attribute in StridedSlice operator is not "
"equal to the size of axes attribute. The strides attribute's "
"size is %d, axes attribute's size is %d.",
strides_size,
axes.size()));
}
// we need to analysis strided slice op is valid for
// the parameter that we get from python front
std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
if (!tensor_input || config.is_runtime) {
phi::funcs::StridedSliceOutDims(starts_,
ends_,
strides_,
axes,
infer_flags,
in_dims,
decrease_axis,
out_dims_vector.data(),
axes.size(),
true);
}
DDim out_dims(phi::make_ddim(out_dims_vector));
// generate new shape
if (decrease_axis.size() > 0) {
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < decrease_axis.size(); ++i) {
if (config.is_runtime && infer_flags[i] != -1) {
PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]],
1,
errors::InvalidArgument(
"the size of decrease dimension should be 1, "
"but received %d.",
out_dims[decrease_axis[i]]));
}
out_dims[decrease_axis[i]] = 0;
}
for (int i = 0; i < out_dims.size(); ++i) {
if (out_dims[i] != 0) {
new_out_shape.push_back(out_dims[i]);
}
}
if (new_out_shape.size() == 0) {
new_out_shape.push_back(1);
}
out_dims = phi::make_ddim(new_out_shape);
}
VLOG(1) << "out_dims: " << out_dims;
out->set_dims(out_dims);
out->share_lod(x);
out->set_dtype(x.dtype());
}
/* Why not use SumRawInferMeta directly? /* Why not use SumRawInferMeta directly?
Because we need make InferMetaFunction's args follow the design of api.yaml Because we need make InferMetaFunction's args follow the design of api.yaml
*/ */
......
...@@ -267,6 +267,16 @@ void SqueezeInferMeta(const MetaTensor& x, ...@@ -267,6 +267,16 @@ void SqueezeInferMeta(const MetaTensor& x,
MetaTensor* xshape, MetaTensor* xshape,
MetaTensor* out); MetaTensor* out);
void StridedSliceInferMeta(const MetaTensor& x,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
MetaTensor* out,
MetaConfig config = MetaConfig());
void SumInferMeta(const MetaTensor& x, void SumInferMeta(const MetaTensor& x,
const std::vector<int64_t>& axis, const std::vector<int64_t>& axis,
DataType dtype, DataType dtype,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/strided_slice_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice_grad,
CPU,
ALL_LAYOUT,
phi::StridedSliceGradKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(strided_slice_array_grad,
CPU,
ALL_LAYOUT,
phi::StridedSliceArrayGradKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/strided_slice_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice,
CPU,
ALL_LAYOUT,
phi::StridedSliceKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(strided_slice_array,
CPU,
ALL_LAYOUT,
phi::StridedSliceArrayKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
此差异已折叠。
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/strided_slice_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_grad_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice_grad,
GPU,
ALL_LAYOUT,
phi::StridedSliceGradKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(strided_slice_array_grad,
GPU,
ALL_LAYOUT,
phi::StridedSliceArrayGradKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/strided_slice_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/strided_slice_kernel_impl.h"
PD_REGISTER_KERNEL(strided_slice,
GPU,
ALL_LAYOUT,
phi::StridedSliceKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(strided_slice_array,
GPU,
ALL_LAYOUT,
phi::StridedSliceArrayKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
...@@ -22,8 +22,7 @@ ...@@ -22,8 +22,7 @@
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
#include "paddle/fluid/operators/strided_slice_op.h"
namespace phi { namespace phi {
...@@ -73,29 +72,29 @@ void SetValueGradImpl(const Context& dev_ctx, ...@@ -73,29 +72,29 @@ void SetValueGradImpl(const Context& dev_ctx,
std::vector<int64_t> starts_local = starts.GetData(); std::vector<int64_t> starts_local = starts.GetData();
std::vector<int64_t> ends_local = ends.GetData(); std::vector<int64_t> ends_local = ends.GetData();
std::vector<int64_t> steps_local = steps.GetData(); std::vector<int64_t> steps_local = steps.GetData();
paddle::operators::StridedSliceOutDims(starts_local, funcs::StridedSliceOutDims(starts_local,
ends_local, ends_local,
steps_local, steps_local,
axes_int32, axes_int32,
infer_flags, infer_flags,
in_dims, in_dims,
decrease_axis_int32, decrease_axis_int32,
out_dims_vector.data(), out_dims_vector.data(),
axes.size(), axes.size(),
false); false);
DDim out_dims(phi::make_ddim(out_dims_vector)); DDim out_dims(phi::make_ddim(out_dims_vector));
std::vector<int> reverse_vector(starts_local.size(), 0); std::vector<int> reverse_vector(starts_local.size(), 0);
paddle::operators::StridedSliceFunctor(starts_local.data(), funcs::StridedSliceFunctor(starts_local.data(),
ends_local.data(), ends_local.data(),
steps_local.data(), steps_local.data(),
axes_int32.data(), axes_int32.data(),
reverse_vector.data(), reverse_vector.data(),
in_dims, in_dims,
infer_flags, infer_flags,
decrease_axis_int32, decrease_axis_int32,
starts_local.size()); starts_local.size());
auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>(); auto starts_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>(); auto ends_indices = Eigen::DSizes<Eigen::DenseIndex, RANK>();
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/strided_slice_grad_kernel.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
namespace phi {
template <typename T, typename Context>
void StridedSliceGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* x_grad) {
int rank = x.dims().size();
#define SLICE_CASE(Rank) \
case Rank: \
funcs::StridedSliceGradCompute<Context, T, Rank>(dev_ctx, \
x, \
out_grad, \
axes, \
starts, \
ends, \
strides, \
infer_flags, \
decrease_axis, \
x_grad); \
break;
switch (rank) {
SLICE_CASE(1)
SLICE_CASE(2)
SLICE_CASE(3)
SLICE_CASE(4)
SLICE_CASE(5)
SLICE_CASE(6)
}
#undef SLICE_CASE
}
template <typename T, typename Context>
void StridedSliceArrayGradKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<const DenseTensor*>& out_grad,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
std::vector<DenseTensor*> x_grad) {
funcs::StridedSliceGradCompute<Context, T, 1>(dev_ctx,
x,
out_grad,
axes,
starts,
ends,
strides,
infer_flags,
decrease_axis,
x_grad);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/strided_slice_kernel.h"
#include "paddle/phi/kernels/funcs/strided_slice.h"
namespace phi {
template <typename T, typename Context>
void StridedSliceKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* out) {
int rank = x.dims().size();
#define SLICE_CASE(Rank) \
case Rank: \
funcs::StridedSliceCompute<Context, T, Rank>(dev_ctx, \
x, \
axes, \
starts, \
ends, \
strides, \
infer_flags, \
decrease_axis, \
out); \
break;
switch (rank) {
SLICE_CASE(1)
SLICE_CASE(2)
SLICE_CASE(3)
SLICE_CASE(4)
SLICE_CASE(5)
SLICE_CASE(6)
}
#undef SLICE_CASE
}
template <typename T, typename Context>
void StridedSliceArrayKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
std::vector<DenseTensor*> out) {
funcs::StridedSliceCompute<Context, T, 1>(
dev_ctx, x, axes, starts, ends, strides, infer_flags, decrease_axis, out);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void StridedSliceGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* x_grad);
template <typename T, typename Context>
void StridedSliceArrayGradKernel(
const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<const DenseTensor*>& out_grad,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
std::vector<DenseTensor*> x_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void StridedSliceKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
DenseTensor* out);
template <typename T, typename Context>
void StridedSliceArrayKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
const std::vector<int>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const ScalarArray& strides,
const std::vector<int>& infer_flags,
const std::vector<int>& decrease_axis,
std::vector<DenseTensor*> out);
} // namespace phi
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册