From b532315d1b51ada62d79bc602517680e6d31385e Mon Sep 17 00:00:00 2001 From: wuyefeilin <30919197+wuyefeilin@users.noreply.github.com> Date: Tue, 29 Mar 2022 11:38:02 +0800 Subject: [PATCH] [Phi] Move elementwise_floordiv and elementwise_pow to phi (#40993) * mv floordiv to phi * mv elementwise_pow to phi * fix as review --- .../elementwise/elementwise_floordiv_op.cc | 8 -- .../elementwise/elementwise_floordiv_op.cu | 46 -------- .../elementwise/elementwise_floordiv_op.h | 54 --------- .../elementwise/elementwise_functor.h | 17 --- .../elementwise/elementwise_pow_op.cc | 15 --- .../elementwise/elementwise_pow_op.cu | 54 --------- .../elementwise/elementwise_pow_op.h | 107 ------------------ .../elementwise/elementwise_pow_op_npu.cc | 1 - .../kernels/cpu/elementwise_grad_kernel.cc | 8 ++ paddle/phi/kernels/cpu/elementwise_kernel.cc | 44 +++++++ paddle/phi/kernels/elementwise_grad_kernel.h | 9 ++ paddle/phi/kernels/elementwise_kernel.cc | 39 +++++++ paddle/phi/kernels/elementwise_kernel.h | 49 ++++++++ .../phi/kernels/funcs/elementwise_functor.h | 35 ++++++ .../kernels/gpu/elementwise_grad_kernel.cu | 8 ++ paddle/phi/kernels/gpu/elementwise_kernel.cu | 18 +++ .../impl/elementwise_grad_kernel_impl.h | 40 +++++++ paddle/phi/ops/compat/elementwise_sig.cc | 32 ++++++ 18 files changed, 282 insertions(+), 302 deletions(-) delete mode 100644 paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu delete mode 100644 paddle/fluid/operators/elementwise/elementwise_floordiv_op.h delete mode 100644 paddle/fluid/operators/elementwise/elementwise_pow_op.cu delete mode 100644 paddle/fluid/operators/elementwise/elementwise_pow_op.h diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cc b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cc index b876438a19..67b9b665c6 100644 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cc @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h" - #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" @@ -63,12 +61,6 @@ namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(elementwise_floordiv, ops::ElementwiseOp, ops::ElementwiseFloorDivOpMaker); -REGISTER_OP_CPU_KERNEL( - elementwise_floordiv, - ops::ElementwiseFloorDivKernel, - ops::ElementwiseFloorDivKernel); - REGISTER_OP_VERSION(elementwise_floordiv) .AddCheckpoint( R"ROC(Register elementwise_floordiv for adding the attribute of Scale_y)ROC", diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu deleted file mode 100644 index 9b146fe727..0000000000 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h" - -namespace paddle { -namespace operators { - -template -class ElementwiseFloorDivKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - std::vector ins; - std::vector outs; - const auto& cuda_ctx = - ctx.template device_context(); - - int axis = PackTensorsIntoVector(ctx, &ins, &outs); - paddle::operators::LaunchElementwiseCudaKernel( - cuda_ctx, ins, &outs, axis, FloorDivFunctor()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL( - elementwise_floordiv, - ops::ElementwiseFloorDivKernel, - ops::ElementwiseFloorDivKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h deleted file mode 100644 index fc8f181619..0000000000 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/operators/elementwise/elementwise_op.h" - -namespace paddle { -namespace operators { - -template -void elementwise_floor_div(const framework::ExecutionContext &ctx, - const framework::Tensor *x, - const framework::Tensor *y, framework::Tensor *z) { - int axis = ctx.Attr("axis"); - auto x_dims = x->dims(); - auto y_dims = y->dims(); - if (x_dims.size() >= y_dims.size()) { - ElementwiseComputeEx, DeviceContext, T>( - ctx, x, y, axis, FloorDivFunctor(), z); - } else { - ElementwiseComputeEx, DeviceContext, T>( - ctx, x, y, axis, InverseFloorDivFunctor(), z); - } -} - -template -class ElementwiseFloorDivKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *x = ctx.Input("X"); - auto *y = ctx.Input("Y"); - auto *z = ctx.Output("Out"); - - z->mutable_data(ctx.GetPlace()); - - // dtype of x and y is int64 or int32 - elementwise_floor_div(ctx, x, y, z); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h index 5dfb7eece9..844b0a1950 100644 --- a/paddle/fluid/operators/elementwise/elementwise_functor.h +++ b/paddle/fluid/operators/elementwise/elementwise_functor.h @@ -49,23 +49,6 @@ using DivFunctor = phi::funcs::DivideFunctor; template using InverseDivFunctor = phi::funcs::InverseDivideFunctor; -// Floor Divide -template -struct FloorDivFunctor { - inline HOSTDEVICE T operator()(const T a, const T b) const { - PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO); - return static_cast(std::trunc(a / b)); - } -}; - -template -struct InverseFloorDivFunctor { - inline HOSTDEVICE T operator()(const T a, const T b) const { - PADDLE_ENFORCE(a != 0, DIV_ERROR_INFO); - return static_cast(std::trunc(b / a)); - } -}; - #undef DIV_ERROR_INFO // Maximum diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.cc b/paddle/fluid/operators/elementwise/elementwise_pow_op.cc index eddbfd3b15..c0dbb0df8c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op.cc @@ -9,8 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/elementwise/elementwise_pow_op.h" - #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" @@ -70,19 +68,6 @@ REGISTER_OPERATOR(elementwise_pow, ops::ElementwiseOp, ops::ElementwisePowOpGradMaker); REGISTER_OPERATOR(elementwise_pow_grad, ops::ElementwiseOpGrad); -REGISTER_OP_CPU_KERNEL( - elementwise_pow, - ops::ElementwisePowKernel, - ops::ElementwisePowKernel, - ops::ElementwisePowKernel, - ops::ElementwisePowKernel); -REGISTER_OP_CPU_KERNEL( - elementwise_pow_grad, - ops::ElementwisePowGradKernel, - ops::ElementwisePowGradKernel, - ops::ElementwisePowGradKernel, - ops::ElementwisePowGradKernel); - REGISTER_OP_VERSION(elementwise_pow) .AddCheckpoint( R"ROC(Register elementwise_pow for adding the attribute of Scale_y)ROC", diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu b/paddle/fluid/operators/elementwise/elementwise_pow_op.cu deleted file mode 100644 index 1286064dac..0000000000 --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.cu +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/elementwise/elementwise_pow_op.h" - -namespace ops = paddle::operators; - -namespace paddle { -namespace operators { - -template -class ElementwisePowKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - std::vector ins; - std::vector outs; - const auto& cuda_ctx = - ctx.template device_context(); - - int axis = PackTensorsIntoVector(ctx, &ins, &outs); - paddle::operators::LaunchElementwiseCudaKernel(cuda_ctx, ins, &outs, - axis, PowFunctor()); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OP_CUDA_KERNEL( - elementwise_pow, - ops::ElementwisePowKernel, - ops::ElementwisePowKernel, - ops::ElementwisePowKernel, - ops::ElementwisePowKernel); -REGISTER_OP_CUDA_KERNEL( - elementwise_pow_grad, - ops::ElementwisePowGradKernel, - ops::ElementwisePowGradKernel, - ops::ElementwisePowGradKernel, - ops::ElementwisePowGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op.h b/paddle/fluid/operators/elementwise/elementwise_pow_op.h deleted file mode 100644 index 1dfe7ed232..0000000000 --- a/paddle/fluid/operators/elementwise/elementwise_pow_op.h +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include - -#include "paddle/fluid/operators/elementwise/elementwise_op.h" - -namespace paddle { -namespace operators { - -template -struct PowFunctor { - inline HOSTDEVICE T operator()(const T a, const T b) const { -// TODO(wujionghao): A potential speed improvement is supporting different -// types in C++. -#if defined(__CUDA_ARCH__) || defined(__HIPCC__) - // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and - // it will return a float number like 2.99... , which floor to 2 - // when cast to int by default and it is wrong. - // Use llrint to cast it to the nearest integer, which is 3. - if (std::is_integral::value) { - return std::llrint( - std::pow(static_cast(a), static_cast(b))); - } -#endif - return std::pow(a, b); - } -}; - -template -class ElementwisePowKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::LoDTensor; - auto* x = ctx.Input("X"); - PADDLE_ENFORCE_EQ(x != nullptr, true, - platform::errors::NotFound( - "Cannot get input Variable X, Variable name = %s", - ctx.InputName("X"))); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, - PowFunctor(), z); - } -}; - -template -struct PowGradDX { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { -#if defined(__CUDA_ARCH__) || defined(__HIPCC__) - if (std::is_integral::value) { - return dout * y * - std::pow(static_cast(x), static_cast(y - 1)); - } -#endif - return dout * y * std::pow(x, y - 1); - } -}; - -template -struct PowGradDY { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { -#if defined(__CUDA_ARCH__) || defined(__HIPCC__) - if (std::is_integral::value) { - return dout * std::log(static_cast(x)) * - std::pow(static_cast(x), static_cast(y)); - } -#endif - return dout * std::log(x) * std::pow(x, y); - } -}; - -template -class ElementwisePowGradKernel : public ElemwiseGradKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - ElemwiseGradKernel::Compute(ctx); - using Tensor = framework::Tensor; - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* out = dout; - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - int axis = ctx.Attr("axis"); - ElemwiseGradCompute, PowGradDY>( - ctx, *x, *y, *out, *dout, axis, dx, dy, PowGradDX(), PowGradDY()); - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_pow_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_pow_op_npu.cc index a2d2276747..c8fbd45612 100644 --- a/paddle/fluid/operators/elementwise/elementwise_pow_op_npu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_pow_op_npu.cc @@ -16,7 +16,6 @@ limitations under the License. */ #include #include "paddle/fluid/operators/elementwise/elementwise_npu.h" -#include "paddle/fluid/operators/elementwise/elementwise_pow_op.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc index b617064987..1548272f86 100644 --- a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc @@ -323,3 +323,11 @@ PD_REGISTER_KERNEL(minimum_grad, int, int64_t, phi::dtype::bfloat16) {} +PD_REGISTER_KERNEL(elementwise_pow_grad, + CPU, + ALL_LAYOUT, + phi::ElementwisePowGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/elementwise_kernel.cc b/paddle/phi/kernels/cpu/elementwise_kernel.cc index 1de40cb946..4ca41de7bb 100644 --- a/paddle/phi/kernels/cpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_kernel.cc @@ -113,6 +113,36 @@ void ModuloRawKernel(const Context& dev_ctx, } } +template +void FloorDivideRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + // allocate memory for out + dev_ctx.template Alloc(out); + auto x_dims = x.dims(); + auto y_dims = y.dims(); + if (x_dims.size() >= y_dims.size()) { + funcs::ElementwiseCompute, T>( + dev_ctx, x, y, axis, funcs::FloorDivideFunctor(), out); + } else { + funcs::ElementwiseCompute, T>( + dev_ctx, x, y, axis, funcs::InverseFloorDivideFunctor(), out); + } +} + +template +void ElementwisePowRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + // allocate memory for out + dev_ctx.template Alloc(out); + funcs::ElementwiseCompute, T>( + dev_ctx, x, y, axis, funcs::ElementwisePowFunctor(), out); +} // Create the definition of Add DEFINE_CPU_ELEMENTWISE_OP(Add) @@ -207,3 +237,17 @@ PD_REGISTER_KERNEL(modulo_raw, double, int, int64_t) {} +PD_REGISTER_KERNEL(floor_divide_raw, + CPU, + ALL_LAYOUT, + phi::FloorDivideRawKernel, + int, + int64_t) {} +PD_REGISTER_KERNEL(elementwise_pow_raw, + CPU, + ALL_LAYOUT, + phi::ElementwisePowRawKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/elementwise_grad_kernel.h b/paddle/phi/kernels/elementwise_grad_kernel.h index 95832013ca..979bb61c2e 100644 --- a/paddle/phi/kernels/elementwise_grad_kernel.h +++ b/paddle/phi/kernels/elementwise_grad_kernel.h @@ -159,4 +159,13 @@ void MinimumGradKernel(const Context& dev_ctx, int axis, DenseTensor* dx, DenseTensor* dy); + +template +void ElementwisePowGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy); } // namespace phi diff --git a/paddle/phi/kernels/elementwise_kernel.cc b/paddle/phi/kernels/elementwise_kernel.cc index 019d4fed5b..4194631ea2 100644 --- a/paddle/phi/kernels/elementwise_kernel.cc +++ b/paddle/phi/kernels/elementwise_kernel.cc @@ -81,6 +81,25 @@ void ModuloKernel(const Context& dev_ctx, int axis = -1; ModuloRawKernel(dev_ctx, x, y, axis, out); } + +template +void FloorDivideKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + int axis = -1; + FloorDivideRawKernel(dev_ctx, x, y, axis, out); +} + +template +void ElementwisePowKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out) { + int axis = -1; + ElementwisePowRawKernel(dev_ctx, x, y, axis, out); +} + } // namespace phi using complex64 = ::phi::dtype::complex; @@ -151,6 +170,16 @@ PD_REGISTER_KERNEL(minimum, phi::dtype::bfloat16) {} PD_REGISTER_KERNEL( modulo, CPU, ALL_LAYOUT, phi::ModuloKernel, float, double, int, int64_t) {} +PD_REGISTER_KERNEL( + floor_divide, CPU, ALL_LAYOUT, phi::FloorDivideKernel, int, int64_t) {} +PD_REGISTER_KERNEL(elementwise_pow, + CPU, + ALL_LAYOUT, + phi::ElementwisePowKernel, + float, + double, + int, + int64_t) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -226,4 +255,14 @@ PD_REGISTER_KERNEL(minimum, phi::dtype::bfloat16) {} PD_REGISTER_KERNEL( modulo, GPU, ALL_LAYOUT, phi::ModuloKernel, float, double, int, int64_t) {} +PD_REGISTER_KERNEL( + floor_divide, GPU, ALL_LAYOUT, phi::FloorDivideKernel, int, int64_t) {} +PD_REGISTER_KERNEL(elementwise_pow, + GPU, + ALL_LAYOUT, + phi::ElementwisePowKernel, + float, + double, + int, + int64_t) {} #endif diff --git a/paddle/phi/kernels/elementwise_kernel.h b/paddle/phi/kernels/elementwise_kernel.h index f9c9c7f713..09b6b02e37 100644 --- a/paddle/phi/kernels/elementwise_kernel.h +++ b/paddle/phi/kernels/elementwise_kernel.h @@ -124,6 +124,32 @@ void ModuloKernel(const Context& dev_ctx, const DenseTensor& y, DenseTensor* out); +template +void FloorDivideRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + +template +void FloorDivideKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); + +template +void ElementwisePowRawKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + +template +void ElementwisePowKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); + template DenseTensor Add(const Context& dev_ctx, const DenseTensor& x, @@ -200,4 +226,27 @@ DenseTensor Modulo(const Context& dev_ctx, ModuloKernel(dev_ctx, x, y, &dense_out); return dense_out; } + +template +DenseTensor FloorDivide(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y) { + DenseTensor dense_out; + MetaTensor meta_out(&dense_out); + ElementwiseInferMeta(x, y, &meta_out); + FloorDivideKernel(dev_ctx, x, y, &dense_out); + return dense_out; +} + +template +DenseTensor ElementwisePow(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y) { + DenseTensor dense_out; + MetaTensor meta_out(&dense_out); + ElementwiseInferMeta(x, y, &meta_out); + ElementwisePowKernel(dev_ctx, x, y, &dense_out); + return dense_out; +} + } // namespace phi diff --git a/paddle/phi/kernels/funcs/elementwise_functor.h b/paddle/phi/kernels/funcs/elementwise_functor.h index 1e39cf5503..0ea5ff0e82 100644 --- a/paddle/phi/kernels/funcs/elementwise_functor.h +++ b/paddle/phi/kernels/funcs/elementwise_functor.h @@ -538,5 +538,40 @@ struct InverseModuloFunctor< return res; } }; + +template +struct FloorDivideFunctor { + inline HOSTDEVICE T operator()(const T a, const T b) const { + PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO); + return static_cast(std::trunc(a / b)); + } +}; + +template +struct InverseFloorDivideFunctor { + inline HOSTDEVICE T operator()(const T a, const T b) const { + PADDLE_ENFORCE(a != 0, DIV_ERROR_INFO); + return static_cast(std::trunc(b / a)); + } +}; + +template +struct ElementwisePowFunctor { + inline HOSTDEVICE T operator()(const T a, const T b) const { +// TODO(wujionghao): A potential speed improvement is supporting different +// types in C++. +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) + // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and + // it will return a float number like 2.99... , which floor to 2 + // when cast to int by default and it is wrong. + // Use llrint to cast it to the nearest integer, which is 3. + if (std::is_integral::value) { + return std::llrint( + std::pow(static_cast(a), static_cast(b))); + } +#endif + return std::pow(a, b); + } +}; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu index 52819fd3de..3750e4b2bd 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu @@ -382,3 +382,11 @@ PD_REGISTER_KERNEL(minimum_grad, int64_t, phi::dtype::float16, phi::dtype::bfloat16) {} +PD_REGISTER_KERNEL(elementwise_pow_grad, + GPU, + ALL_LAYOUT, + phi::ElementwisePowGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/elementwise_kernel.cu b/paddle/phi/kernels/gpu/elementwise_kernel.cu index bd6995cb13..73964d31a3 100644 --- a/paddle/phi/kernels/gpu/elementwise_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_kernel.cu @@ -55,6 +55,10 @@ DEFINE_CUDA_ELEMENTWISE_OP(Maximum) DEFINE_CUDA_ELEMENTWISE_OP(Minimum) // Create the definition of Modulo DEFINE_CUDA_ELEMENTWISE_OP(Modulo) +// Create the definition of FloorDivide +DEFINE_CUDA_ELEMENTWISE_OP(FloorDivide) +// Create the definition of Pow +DEFINE_CUDA_ELEMENTWISE_OP(ElementwisePow) } // namespace phi @@ -148,3 +152,17 @@ PD_REGISTER_KERNEL(modulo_raw, double, int, int64_t) {} +PD_REGISTER_KERNEL(floor_divide_raw, + GPU, + ALL_LAYOUT, + phi::FloorDivideRawKernel, + int, + int64_t) {} +PD_REGISTER_KERNEL(elementwise_pow_raw, + GPU, + ALL_LAYOUT, + phi::ElementwisePowRawKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h index 07e5bf9ae0..aba4a5f5fb 100644 --- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h @@ -666,4 +666,44 @@ struct MinGradDy { return dout * static_cast(x >= y); } }; + +template +struct PowGradDX { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) + if (std::is_integral::value) { + return dout * y * + std::pow(static_cast(x), static_cast(y - 1)); + } +#endif + return dout * y * std::pow(x, y - 1); + } +}; + +template +struct PowGradDY { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { +#if defined(__CUDA_ARCH__) || defined(__HIPCC__) + if (std::is_integral::value) { + return dout * std::log(static_cast(x)) * + std::pow(static_cast(x), static_cast(y)); + } +#endif + return dout * std::log(x) * std::pow(x, y); + } +}; + +template +void ElementwisePowGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy) { + funcs::ElementwiseGradPreProcess(dout, dx); + phi::funcs::ElemwiseGradCompute, PowGradDY>( + dev_ctx, x, y, dout, dout, axis, dx, dy, PowGradDX(), PowGradDY()); +} + } // namespace phi diff --git a/paddle/phi/ops/compat/elementwise_sig.cc b/paddle/phi/ops/compat/elementwise_sig.cc index 7f00af6f9a..cf6f9d4dfb 100644 --- a/paddle/phi/ops/compat/elementwise_sig.cc +++ b/paddle/phi/ops/compat/elementwise_sig.cc @@ -82,6 +82,24 @@ KernelSignature ElementwiseModOpArgumentMapping( return KernelSignature("modulo_raw", {"X", "Y"}, {"axis"}, {"Out"}); } +KernelSignature ElementwiseFloorDivOpArgumentMapping( + const ArgumentMappingContext& ctx) { + int axis = paddle::any_cast(ctx.Attr("axis")); + if (axis == -1) { + return KernelSignature("floor_divide", {"X", "Y"}, {}, {"Out"}); + } + return KernelSignature("floor_divide_raw", {"X", "Y"}, {"axis"}, {"Out"}); +} + +KernelSignature ElementwisePowOpArgumentMapping( + const ArgumentMappingContext& ctx) { + int axis = paddle::any_cast(ctx.Attr("axis")); + if (axis == -1) { + return KernelSignature("elementwise_pow", {"X", "Y"}, {}, {"Out"}); + } + return KernelSignature("elementwise_pow_raw", {"X", "Y"}, {"axis"}, {"Out"}); +} + KernelSignature ElementwiseAddGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("add_grad", @@ -200,6 +218,13 @@ KernelSignature ElementwiseMinGradOpArgumentMapping( {"axis"}, {GradVarName("X"), GradVarName("Y")}); } +KernelSignature ElementwisePowGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("elementwise_pow_grad", + {"X", "Y", GradVarName("Out")}, + {"axis"}, + {GradVarName("X"), GradVarName("Y")}); +} } // namespace phi PD_REGISTER_BASE_KERNEL_NAME(elementwise_add, add); @@ -209,6 +234,7 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide); PD_REGISTER_BASE_KERNEL_NAME(elementwise_max, maximum); PD_REGISTER_BASE_KERNEL_NAME(elementwise_min, minimum); PD_REGISTER_BASE_KERNEL_NAME(elementwise_mod, modulo); +PD_REGISTER_BASE_KERNEL_NAME(elementwise_floordiv, floor_divide); PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad); PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad_grad, add_double_grad); PD_REGISTER_BASE_KERNEL_NAME(elementwise_add_triple_grad, add_triple_grad); @@ -240,6 +266,10 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_min, phi::ElementwiseMinOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(elementwise_mod, phi::ElementwiseModOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(elementwise_floordiv, + phi::ElementwiseFloorDivOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(elementwise_pow, + phi::ElementwisePowOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad, phi::ElementwiseAddGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(elementwise_add_grad_grad, @@ -272,3 +302,5 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_max_grad, phi::ElementwiseMaxGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(elementwise_min_grad, phi::ElementwiseMinGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(elementwise_pow_grad, + phi::ElementwisePowGradOpArgumentMapping); -- GitLab