From 7fc0c61930c6d17eef7c988328eef65c05a8f0bf Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Tue, 22 Mar 2022 14:41:58 +0800
Subject: [PATCH] [Phi] Move reverse kernel and infershape into phi (#40791)

* add reverse phi kernel
* add reverse infermeta
* remove original reverse op kernel & infershape
---
 paddle/fluid/operators/reverse_op.cc          |  69 ++---------
 paddle/fluid/operators/reverse_op.h           | 113 ------------------
 paddle/phi/infermeta/unary.cc                 |  27 +++++
 paddle/phi/infermeta/unary.h                  |   4 +
 paddle/phi/kernels/cpu/reverse_kernel.cc      |  30 +++++
 paddle/phi/kernels/gpu/reverse_kernel.cu.cc   |  30 +++++
 paddle/phi/kernels/impl/reverse_kernel_impl.h |  91 ++++++++++++++
 paddle/phi/kernels/reverse_kernel.cc          |  74 ++++++++++++
 paddle/phi/kernels/reverse_kernel.h           |  35 ++++++
 paddle/phi/ops/compat/reverse_sig.cc          |  29 +++++
 10 files changed, 328 insertions(+), 174 deletions(-)
 delete mode 100644 paddle/fluid/operators/reverse_op.h
 create mode 100644 paddle/phi/kernels/cpu/reverse_kernel.cc
 create mode 100644 paddle/phi/kernels/gpu/reverse_kernel.cu.cc
 create mode 100644 paddle/phi/kernels/impl/reverse_kernel_impl.h
 create mode 100644 paddle/phi/kernels/reverse_kernel.cc
 create mode 100644 paddle/phi/kernels/reverse_kernel.h
 create mode 100644 paddle/phi/ops/compat/reverse_sig.cc

diff --git a/paddle/fluid/operators/reverse_op.cc b/paddle/fluid/operators/reverse_op.cc
index 98a1610be60..975eecafc06 100644
--- a/paddle/fluid/operators/reverse_op.cc
+++ b/paddle/fluid/operators/reverse_op.cc
@@ -12,60 +12,20 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/reverse_op.h"
 #include <memory>
 #include <vector>
 
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
+
 namespace paddle {
 namespace operators {
 
 class ReverseOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Reverse");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Reverse");
-
-    auto x_var_type = ctx->GetInputsVarType("X")[0];
-    const auto& axis = ctx->Attrs().Get<std::vector<int>>("axis");
-    if (x_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
-      PADDLE_ENFORCE_EQ(
-          axis.size(), 1,
-          platform::errors::InvalidArgument(
-              "The size of axis must be 1 when the Input(X) is LoDTensorArray, "
-              "but received %d.",
-              axis.size()));
-      PADDLE_ENFORCE_EQ(axis[0], 0, platform::errors::InvalidArgument(
-                                        "The value of axis should be 1 when "
-                                        "the Input(X) is LoDTensorArray, "
-                                        "but received %d.",
-                                        axis[0]));
-      // In runtime, shape is determined by RunImpl.
-      if (!ctx->IsRuntime()) {
-        const auto& x_dims = ctx->GetInputDim("X");
-        ctx->SetOutputDim("Out", x_dims);
-      }
-      return;
-    }
-    const auto& x_dims = ctx->GetInputDim("X");
-    PADDLE_ENFORCE_NE(axis.empty(), true, platform::errors::InvalidArgument(
-                                              "'axis' can not be empty."));
-    for (int a : axis) {
-      PADDLE_ENFORCE_LT(a, x_dims.size(),
-                        paddle::platform::errors::OutOfRange(
-                            "The axis must be less than input tensor's rank. "
-                            "but got %d >= %d",
-                            a, x_dims.size()));
-      PADDLE_ENFORCE_GE(
-          a, -x_dims.size(),
-          paddle::platform::errors::OutOfRange(
-              "The axis must be greater than the negative number of "
-              "input tensor's rank, but got %d < %d",
-              a, -x_dims.size()));
-    }
-    ctx->SetOutputDim("Out", x_dims);
-  }
 };
 
 class ReverseOpVarTypeInference : public framework::VarTypeInference {
@@ -134,23 +94,10 @@ class ReverseGradMaker : public framework::SingleGradOpMaker<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(reverse, ReverseInferShapeFunctor,
+                            PD_INFER_META(phi::ReverseInferMeta));
 REGISTER_OPERATOR(reverse, ops::ReverseOp, ops::ReverseOpMaker,
                   ops::ReverseGradMaker<paddle::framework::OpDesc>,
                   ops::ReverseGradMaker<paddle::imperative::OpBase>,
-                  ops::ReverseOpVarTypeInference);
+                  ops::ReverseOpVarTypeInference, ReverseInferShapeFunctor);
 REGISTER_OPERATOR(reverse_grad, ops::ReverseOp, ops::ReverseOpVarTypeInference);
-REGISTER_OP_CPU_KERNEL(
-    reverse, ops::ReverseKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::ReverseKernel<paddle::platform::CPUDeviceContext, uint8_t>,
-    ops::ReverseKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::ReverseKernel<paddle::platform::CPUDeviceContext, bool>,
-    ops::ReverseKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ReverseKernel<paddle::platform::CPUDeviceContext, double>);
-
-REGISTER_OP_CUDA_KERNEL(
-    reverse, ops::ReverseKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::ReverseKernel<paddle::platform::CUDADeviceContext, uint8_t>,
-    ops::ReverseKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::ReverseKernel<paddle::platform::CUDADeviceContext, bool>,
-    ops::ReverseKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::ReverseKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/reverse_op.h b/paddle/fluid/operators/reverse_op.h
deleted file mode 100644
index d5e331e2fe5..00000000000
--- a/paddle/fluid/operators/reverse_op.h
+++ /dev/null
@@ -1,113 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-#include <vector>
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/eigen/eigen_function.h"
-
-namespace paddle {
-namespace operators {
-template <typename DeviceContext, typename T, int Rank>
-struct ReverseFunctor {
-  void operator()(const DeviceContext& context, const framework::LoDTensor& in,
-                  framework::LoDTensor* out, const std::vector<int>& axis) {
-    Eigen::DSizes<bool, Rank> reverse_axis;
-    for (int i = 0; i < Rank; ++i) {
-      reverse_axis[i] = false;
-    }
-    for (int a : axis) {
-      if (a >= 0) {
-        reverse_axis[a] = true;
-      } else {
-        reverse_axis[Rank + a] = true;
-      }
-    }
-
-    auto in_eigen = framework::EigenTensor<T, Rank>::From(in);
-    auto out_eigen = framework::EigenTensor<T, Rank>::From(*out);
-    auto& dev = *context.eigen_device();
-
-    EigenReverse<std::decay_t<decltype(dev)>, T, Rank>::Eval(
-        dev, out_eigen, in_eigen, reverse_axis);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class ReverseKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* x_var = context.InputVar("X");
-    const auto& axis = context.Attr<std::vector<int>>("axis");
-    if (x_var->IsType<framework::LoDTensorArray>()) {
-      auto& x_array = x_var->Get<framework::LoDTensorArray>();
-      auto* out_array = context.Output<framework::LoDTensorArray>("Out");
-
-      out_array->resize(x_array.size());
-      for (size_t offset = 0; offset < x_array.size(); offset++) {
-        auto& x_tensor = x_array.at(offset);
-        PADDLE_ENFORCE_GT(
-            x_tensor.memory_size(), 0,
-            platform::errors::PreconditionNotMet(
-                "The input LoDTensorArray X[%d] holds no memory.", offset));
-        auto out_offset = x_array.size() - offset - 1;
-        auto* out_tensor = &out_array->at(out_offset);
-
-        out_tensor->set_lod(x_tensor.lod());
-        paddle::framework::TensorCopy(x_tensor, context.GetPlace(), out_tensor);
-      }
-      return;
-    }
-    auto* x = context.Input<framework::LoDTensor>("X");
-    auto* out = context.Output<framework::LoDTensor>("Out");
-    out->mutable_data<T>(context.GetPlace());
-    int rank = x->dims().size();
-    auto& dev_ctx = context.template device_context<DeviceContext>();
-
-    switch (rank) {
-      case 1:
-        ReverseFunctor<DeviceContext, T, 1> functor1;
-        functor1(dev_ctx, *x, out, axis);
-        break;
-      case 2:
-        ReverseFunctor<DeviceContext, T, 2> functor2;
-        functor2(dev_ctx, *x, out, axis);
-        break;
-      case 3:
-        ReverseFunctor<DeviceContext, T, 3> functor3;
-        functor3(dev_ctx, *x, out, axis);
-        break;
-      case 4:
-        ReverseFunctor<DeviceContext, T, 4> functor4;
-        functor4(dev_ctx, *x, out, axis);
-        break;
-      case 5:
-        ReverseFunctor<DeviceContext, T, 5> functor5;
-        functor5(dev_ctx, *x, out, axis);
-        break;
-      case 6:
-        ReverseFunctor<DeviceContext, T, 6> functor6;
-        functor6(dev_ctx, *x, out, axis);
-        break;
-      default:
-        PADDLE_THROW(paddle::platform::errors::OutOfRange(
-            "The reserve operator does not support input tensors"
-            "whose ranks are greater than 6."));
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 7c5f38744f8..80503dd2430 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -1240,6 +1240,33 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
   ReshapeInferMeta(x, shape, out, config);
 }
 
+void ReverseInferMeta(const MetaTensor& x,
+                      const std::vector<int>& axis,
+                      MetaTensor* out) {
+  PADDLE_ENFORCE_NE(axis.empty(),
+                    true,
+                    phi::errors::InvalidArgument("'axis' can not be empty."));
+  const auto& x_dims = x.dims();
+  for (int a : axis) {
+    PADDLE_ENFORCE_LT(a,
+                      x_dims.size(),
+                      phi::errors::OutOfRange(
+                          "The axis must be less than input tensor's rank. "
+                          "but got %d >= %d",
+                          a,
+                          x_dims.size()));
+    PADDLE_ENFORCE_GE(
+        a,
+        -x_dims.size(),
+        phi::errors::OutOfRange(
+            "The axis must be greater than the negative number of "
+            "input tensor's rank, but got %d < %d",
+            a,
+            -x_dims.size()));
+  }
+  out->share_meta(x);
+}
+
 void RollInferMeta(const MetaTensor& x,
                    const ScalarArray& shifts,
                    const std::vector<int64_t>& axis,
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index d84283a65c4..0322a18fc31 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -198,6 +198,10 @@ void ReshapeWithXShapeInferMeta(const MetaTensor& x,
                                 MetaTensor* out,
                                 MetaConfig config = MetaConfig());
 
+void ReverseInferMeta(const MetaTensor& x,
+                      const std::vector<int>& axis,
+                      MetaTensor* out);
+
 void RollInferMeta(const MetaTensor& x,
                    const ScalarArray& shifts,
                    const std::vector<int64_t>& axis,
diff --git a/paddle/phi/kernels/cpu/reverse_kernel.cc b/paddle/phi/kernels/cpu/reverse_kernel.cc
new file mode 100644
index 00000000000..43eff7c0550
--- /dev/null
+++ b/paddle/phi/kernels/cpu/reverse_kernel.cc
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/reverse_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/reverse_kernel_impl.h"
+
+PD_REGISTER_KERNEL(reverse,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::ReverseKernel,
+                   int,
+                   uint8_t,
+                   int64_t,
+                   bool,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/gpu/reverse_kernel.cu.cc b/paddle/phi/kernels/gpu/reverse_kernel.cu.cc
new file mode 100644
index 00000000000..f11eaa11bcd
--- /dev/null
+++ b/paddle/phi/kernels/gpu/reverse_kernel.cu.cc
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/reverse_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/reverse_kernel_impl.h"
+
+PD_REGISTER_KERNEL(reverse,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ReverseKernel,
+                   int,
+                   uint8_t,
+                   int64_t,
+                   bool,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/impl/reverse_kernel_impl.h b/paddle/phi/kernels/impl/reverse_kernel_impl.h
new file mode 100644
index 00000000000..acdd46a0865
--- /dev/null
+++ b/paddle/phi/kernels/impl/reverse_kernel_impl.h
@@ -0,0 +1,91 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/kernels/reverse_kernel.h"
+
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+
+namespace phi {
+
+template <typename Context, typename T, int Rank>
+struct ReverseFunctor {
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& in,
+                  DenseTensor* out,
+                  const std::vector<int>& axis) {
+    Eigen::DSizes<bool, Rank> reverse_axis;
+    for (int i = 0; i < Rank; ++i) {
+      reverse_axis[i] = false;
+    }
+    for (int a : axis) {
+      if (a >= 0) {
+        reverse_axis[a] = true;
+      } else {
+        reverse_axis[Rank + a] = true;
+      }
+    }
+
+    auto in_eigen = EigenTensor<T, Rank>::From(in);
+    auto out_eigen = EigenTensor<T, Rank>::From(*out);
+    auto& dev = *dev_ctx.eigen_device();
+
+    funcs::EigenReverse<std::decay_t<decltype(dev)>, T, Rank>::Eval(
+        dev, out_eigen, in_eigen, reverse_axis);
+  }
+};
+
+template <typename T, typename Context>
+void ReverseKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const std::vector<int>& axis,
+                   DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  int rank = x.dims().size();
+
+  switch (rank) {
+    case 1:
+      ReverseFunctor<Context, T, 1> functor1;
+      functor1(dev_ctx, x, out, axis);
+      break;
+    case 2:
+      ReverseFunctor<Context, T, 2> functor2;
+      functor2(dev_ctx, x, out, axis);
+      break;
+    case 3:
+      ReverseFunctor<Context, T, 3> functor3;
+      functor3(dev_ctx, x, out, axis);
+      break;
+    case 4:
+      ReverseFunctor<Context, T, 4> functor4;
+      functor4(dev_ctx, x, out, axis);
+      break;
+    case 5:
+      ReverseFunctor<Context, T, 5> functor5;
+      functor5(dev_ctx, x, out, axis);
+      break;
+    case 6:
+      ReverseFunctor<Context, T, 6> functor6;
+      functor6(dev_ctx, x, out, axis);
+      break;
+    default:
+      PADDLE_THROW(phi::errors::OutOfRange(
+          "The reverse operator does not support input tensors "
+          "whose ranks are greater than 6."));
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/reverse_kernel.cc b/paddle/phi/kernels/reverse_kernel.cc
new file mode 100644
index 00000000000..c6c2781a07b
--- /dev/null
+++ b/paddle/phi/kernels/reverse_kernel.cc
@@ -0,0 +1,74 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/reverse_kernel.h"
+
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/copy_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ReverseArrayKernel(const Context& dev_ctx,
+                        const std::vector<const DenseTensor*>& x,
+                        const std::vector<int>& axis,
+                        std::vector<DenseTensor*> out) {
+  PADDLE_ENFORCE_EQ(
+      x.size(),
+      out.size(),
+      phi::errors::InvalidArgument("The input size(%d) and output size(%d) of "
+                                   "ReverseArrayKernel is different.",
+                                   x.size(),
+                                   out.size()));
+  for (size_t offset = 0; offset < x.size(); ++offset) {
+    auto* x_tensor = x.at(offset);
+    PADDLE_ENFORCE_GT(
+        x_tensor->memory_size(),
+        0,
+        phi::errors::PreconditionNotMet(
+            "The input LoDTensorArray X[%d] holds no memory.", offset));
+    auto out_offset = x.size() - offset - 1;
+    auto* out_tensor = out.at(out_offset);
+
+    out_tensor->set_lod(x_tensor->lod());
+    phi::Copy<Context>(
+        dev_ctx, *x_tensor, dev_ctx.GetPlace(), false, out_tensor);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(reverse_array,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::ReverseArrayKernel,
+                   int,
+                   uint8_t,
+                   int64_t,
+                   bool,
+                   float,
+                   double) {}
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PD_REGISTER_KERNEL(reverse_array,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::ReverseArrayKernel,
+                   int,
+                   uint8_t,
+                   int64_t,
+                   bool,
+                   float,
+                   double) {}
+#endif
diff --git a/paddle/phi/kernels/reverse_kernel.h b/paddle/phi/kernels/reverse_kernel.h
new file mode 100644
index 00000000000..2b81f4018c2
--- /dev/null
+++ b/paddle/phi/kernels/reverse_kernel.h
@@ -0,0 +1,35 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ReverseKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const std::vector<int>& axis,
+                   DenseTensor* out);
+
+template <typename T, typename Context>
+void ReverseArrayKernel(const Context& dev_ctx,
+                        const std::vector<const DenseTensor*>& x,
+                        const std::vector<int>& axis,
+                        std::vector<DenseTensor*> out);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/reverse_sig.cc b/paddle/phi/ops/compat/reverse_sig.cc
new file mode 100644
index 00000000000..0b70893fa78
--- /dev/null
+++ b/paddle/phi/ops/compat/reverse_sig.cc
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature ReverseOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.IsDenseTensorVectorInput("X")) {
+    return KernelSignature("reverse_array", {"X"}, {"axis"}, {"Out"});
+  } else {
+    return KernelSignature("reverse", {"X"}, {"axis"}, {"Out"});
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(reverse, phi::ReverseOpArgumentMapping);
-- 
GitLab
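
For reference, the axis-flip logic that both the removed fluid ReverseFunctor and the new phi::ReverseFunctor delegate to Eigen can be exercised in isolation. The sketch below is illustrative only and is not part of this patch; it assumes only that Eigen's unsupported Tensor module is available. It builds the same per-dimension boolean mask used in the kernel (a negative axis a wraps to Rank + a) and evaluates Eigen's reverse() on a small tensor.

// Standalone sketch (not Paddle code): the reverse-along-axes technique
// used by ReverseFunctor, expressed directly against Eigen.
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>
#include <vector>

int main() {
  constexpr int kRank = 2;
  Eigen::Tensor<float, kRank> in(2, 3);
  in.setValues({{0, 1, 2}, {3, 4, 5}});

  // Same mask construction as in the kernel: flip the listed axes,
  // mapping a negative axis a to kRank + a.
  std::vector<int> axis = {-1};  // reverse the last dimension
  Eigen::array<bool, kRank> reverse_axis;
  for (int i = 0; i < kRank; ++i) reverse_axis[i] = false;
  for (int a : axis) reverse_axis[a >= 0 ? a : kRank + a] = true;

  // Evaluate the reversed expression into a new tensor.
  Eigen::Tensor<float, kRank> out = in.reverse(reverse_axis);
  std::cout << out << "\n";  // rows become {2, 1, 0} and {5, 4, 3}
  return 0;
}

In the patch itself, the only differences are that the mask type is Eigen::DSizes<bool, Rank>, the device is supplied by the phi Context, and the rank is dispatched at runtime through the switch in ReverseKernel.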