From bb801960a24e6364b5a156d829a05668cf85eb0b Mon Sep 17 00:00:00 2001 From: Xiaoxu Chen Date: Mon, 14 Mar 2022 17:21:04 +0800 Subject: [PATCH] [phi]migrate fmax,fmin kernel to phi (#40140) --- .../elementwise/elementwise_functor.h | 83 ------- .../elementwise/elementwise_max_op.cc | 18 -- .../elementwise/elementwise_max_op.cu | 18 -- .../elementwise/elementwise_max_op.h | 98 -------- .../elementwise/elementwise_min_op.cc | 18 -- .../elementwise/elementwise_min_op.cu | 18 -- .../elementwise/elementwise_min_op.h | 99 -------- .../kernels/cpu/elementwise_grad_kernel.cc | 17 ++ paddle/phi/kernels/cpu/elementwise_kernel.cc | 35 +++ paddle/phi/kernels/elementwise_grad_kernel.h | 18 ++ paddle/phi/kernels/elementwise_kernel.h | 36 +++ .../phi/kernels/funcs/elementwise_functor.h | 213 ++++++++++++++++++ .../kernels/gpu/elementwise_grad_kernel.cu | 17 ++ paddle/phi/kernels/gpu/elementwise_kernel.cu | 35 +++ .../impl/elementwise_grad_kernel_impl.h | 96 ++++++++ .../kernels/impl/elementwise_kernel_impl.h | 47 ++++ paddle/phi/ops/compat/elementwise_sig.cc | 22 ++ 17 files changed, 536 insertions(+), 352 deletions(-) create mode 100644 paddle/phi/kernels/cpu/elementwise_kernel.cc create mode 100644 paddle/phi/kernels/elementwise_kernel.h create mode 100644 paddle/phi/kernels/gpu/elementwise_kernel.cu create mode 100644 paddle/phi/kernels/impl/elementwise_kernel_impl.h diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h index 14baeaa74d..54931d9929 100644 --- a/paddle/fluid/operators/elementwise/elementwise_functor.h +++ b/paddle/fluid/operators/elementwise/elementwise_functor.h @@ -1,11 +1,8 @@ /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -90,86 +87,6 @@ struct MinFunctor { template using Complex = paddle::platform::complex; -// Fmax -template -struct FMaxFunctor { - inline HOSTDEVICE T operator()(const T a, const T b) const { - return std::fmax(a, b); - } -}; - -template <> -struct FMaxFunctor { - inline HOSTDEVICE paddle::platform::float16 operator()( - const paddle::platform::float16 a, - const paddle::platform::float16 b) const { - float float_a = static_cast(a); - float float_b = static_cast(b); - auto result = std::fmax(float_a, float_b); - return static_cast(result); - } -}; - -template <> -struct FMaxFunctor { - inline HOSTDEVICE int operator()(const int a, const int b) const { - float float_a = static_cast(a); - float float_b = static_cast(b); - auto result = std::fmax(float_a, float_b); - return std::lrint(result); - } -}; - -template <> -struct FMaxFunctor { - inline HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const { - double double_a = static_cast(a); - double double_b = static_cast(b); - auto result = std::fmax(double_a, double_b); - return std::llrint(result); - } -}; - -// Fmin -template -struct FMinFunctor { - inline HOSTDEVICE T operator()(const T a, const T b) const { - return std::fmin(a, b); - } -}; - -template <> -struct FMinFunctor { - inline HOSTDEVICE paddle::platform::float16 operator()( - const paddle::platform::float16 a, - const paddle::platform::float16 b) const { - float float_a = static_cast(a); - float float_b = static_cast(b); - auto result = std::fmin(float_a, float_b); - return static_cast(result); - } -}; - -template <> -struct FMinFunctor { - inline HOSTDEVICE int operator()(const int a, const int b) const { - float float_a = static_cast(a); - float float_b = static_cast(b); - auto result = std::fmin(float_a, float_b); - return std::lrint(result); - } -}; - -template <> -struct FMinFunctor { - inline HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const { - double double_a = static_cast(a); - double double_b = static_cast(b); - auto result = std::fmin(double_a, double_b); - return std::llrint(result); - } -}; - template struct MinGradXFunctor { inline HOSTDEVICE T operator()(const T x, const T y, const T dout) const { diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cc b/paddle/fluid/operators/elementwise/elementwise_max_op.cc index 91da732ef0..d91315cc51 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cc @@ -151,21 +151,3 @@ REGISTER_OPERATOR(elementwise_fmax, ops::ElementwiseOp, ops::ElementwiseFMaxGradOpMaker); REGISTER_OPERATOR(elementwise_fmax_grad, ops::ElementwiseOpGrad); - -REGISTER_OP_CPU_KERNEL( - elementwise_fmax, - ops::ElementwiseFMaxKernel, - ops::ElementwiseFMaxKernel, - ops::ElementwiseFMaxKernel, - ops::ElementwiseFMaxKernel, - ops::ElementwiseFMaxKernel); -REGISTER_OP_CPU_KERNEL( - elementwise_fmax_grad, - ops::ElementwiseFMaxGradKernel, - ops::ElementwiseFMaxGradKernel, - ops::ElementwiseFMaxGradKernel, - ops::ElementwiseFMaxGradKernel, - ops::ElementwiseFMaxGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu index 123332a4a2..0d5f56fda1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu @@ -86,21 +86,3 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseMaxGradKernel, ops::ElementwiseMaxGradKernel); - -REGISTER_OP_CUDA_KERNEL( - elementwise_fmax, - ops::ElementwiseFMaxKernel, - ops::ElementwiseFMaxKernel, - ops::ElementwiseFMaxKernel, - ops::ElementwiseFMaxKernel, - ops::ElementwiseFMaxKernel); -REGISTER_OP_CUDA_KERNEL( - elementwise_fmax_grad, - ops::ElementwiseFMaxGradKernel, - ops::ElementwiseFMaxGradKernel, - ops::ElementwiseFMaxGradKernel, - ops::ElementwiseFMaxGradKernel, - ops::ElementwiseFMaxGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h index cff30be50a..afe1073d89 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h @@ -35,21 +35,6 @@ class ElementwiseMaxKernel : public framework::OpKernel { } }; -template -class ElementwiseFMaxKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); - - z->mutable_data(ctx.GetPlace()); - int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, - FMaxFunctor(), z); - } -}; - template struct MaxGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { @@ -104,88 +89,5 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel { } }; -template -struct FMaxGradDx { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast((x >= y) || isnan(y)); - } -}; - -template <> -struct FMaxGradDx { - HOSTDEVICE paddle::platform::float16 operator()( - paddle::platform::float16 x, paddle::platform::float16 y, - paddle::platform::float16 out, paddle::platform::float16 dout) const { - return dout * static_cast( - (x >= y) || paddle::platform::isnan(y)); - } -}; - -template <> -struct FMaxGradDx { - HOSTDEVICE int operator()(int x, int y, int out, int dout) const { - return dout * static_cast((x >= y)); - } -}; - -template <> -struct FMaxGradDx { - HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out, - int64_t dout) const { - return dout * static_cast((x >= y)); - } -}; - -template -struct FMaxGradDy { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast(!((x >= y) || isnan(y))); - } -}; - -template <> -struct FMaxGradDy { - HOSTDEVICE paddle::platform::float16 operator()( - paddle::platform::float16 x, paddle::platform::float16 y, - paddle::platform::float16 out, paddle::platform::float16 dout) const { - return dout * static_cast( - !((x >= y) || paddle::platform::isnan(y))); - } -}; - -template <> -struct FMaxGradDy { - HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out, - int64_t dout) const { - return dout * static_cast(!((x >= y))); - } -}; - -template <> -struct FMaxGradDy { - HOSTDEVICE int operator()(int x, int y, int out, int dout) const { - return dout * static_cast(!((x >= y))); - } -}; - -template -class ElementwiseFMaxGradKernel : public ElemwiseGradKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - ElemwiseGradKernel::Compute(ctx); - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - auto* out = dout; // Fake out, not used - int axis = ctx.Attr("axis"); - ElemwiseGradCompute, FMaxGradDy>( - ctx, *x, *y, *out, *dout, axis, dx, dy, FMaxGradDx(), - FMaxGradDy()); - } -}; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cc b/paddle/fluid/operators/elementwise/elementwise_min_op.cc index 3a19519995..dad80a2c33 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cc @@ -147,21 +147,3 @@ REGISTER_OPERATOR(elementwise_fmin, ops::ElementwiseOp, ops::ElementwiseFMinGradOpMaker); REGISTER_OPERATOR(elementwise_fmin_grad, ops::ElementwiseOpGrad); - -REGISTER_OP_CPU_KERNEL( - elementwise_fmin, - ops::ElementwiseFMinKernel, - ops::ElementwiseFMinKernel, - ops::ElementwiseFMinKernel, - ops::ElementwiseFMinKernel, - ops::ElementwiseFMinKernel); -REGISTER_OP_CPU_KERNEL( - elementwise_fmin_grad, - ops::ElementwiseFMinGradKernel, - ops::ElementwiseFMinGradKernel, - ops::ElementwiseFMinGradKernel, - ops::ElementwiseFMinGradKernel, - ops::ElementwiseFMinGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu index 5af985567d..fb8bc9ac7f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu @@ -82,21 +82,3 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseMinGradKernel, ops::ElementwiseMinGradKernel); - -REGISTER_OP_CUDA_KERNEL( - elementwise_fmin, - ops::ElementwiseFMinKernel, - ops::ElementwiseFMinKernel, - ops::ElementwiseFMinKernel, - ops::ElementwiseFMinKernel, - ops::ElementwiseFMinKernel); -REGISTER_OP_CUDA_KERNEL( - elementwise_fmin_grad, - ops::ElementwiseFMinGradKernel, - ops::ElementwiseFMinGradKernel, - ops::ElementwiseFMinGradKernel, - ops::ElementwiseFMinGradKernel, - ops::ElementwiseFMinGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.h b/paddle/fluid/operators/elementwise/elementwise_min_op.h index 88fb044d42..283ad2adde 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.h @@ -35,21 +35,6 @@ class ElementwiseMinKernel : public framework::OpKernel { } }; -template -class ElementwiseFMinKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); - - z->mutable_data(ctx.GetPlace()); - int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, - FMinFunctor(), z); - } -}; - template struct MinGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { @@ -124,89 +109,5 @@ class ElementwiseMinGradKernel : public ElemwiseGradKernel { ElementwiseMinGrad(ctx, x, y, out, dout, dx, dy); } }; - -template -struct FMinGradDx { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast((x <= y) || isnan(y)); - } -}; - -template <> -struct FMinGradDx { - HOSTDEVICE paddle::platform::float16 operator()( - paddle::platform::float16 x, paddle::platform::float16 y, - paddle::platform::float16 out, paddle::platform::float16 dout) const { - return dout * static_cast( - (x <= y) || paddle::platform::isnan(y)); - } -}; - -template <> -struct FMinGradDx { - HOSTDEVICE int operator()(int x, int y, int out, int dout) const { - return dout * static_cast((x <= y)); - } -}; - -template <> -struct FMinGradDx { - HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out, - int64_t dout) const { - return dout * static_cast((x <= y)); - } -}; - -template -struct FMinGradDy { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return dout * static_cast(!((x <= y) || isnan(y))); - } -}; - -template <> -struct FMinGradDy { - HOSTDEVICE paddle::platform::float16 operator()( - paddle::platform::float16 x, paddle::platform::float16 y, - paddle::platform::float16 out, paddle::platform::float16 dout) const { - return dout * static_cast( - !((x <= y) || paddle::platform::isnan(y))); - } -}; - -template <> -struct FMinGradDy { - HOSTDEVICE int operator()(int x, int y, int out, int dout) const { - return dout * static_cast(!((x <= y))); - } -}; - -template <> -struct FMinGradDy { - HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out, - int64_t dout) const { - return dout * static_cast(!((x <= y))); - } -}; - -template -class ElementwiseFMinGradKernel : public ElemwiseGradKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - ElemwiseGradKernel::Compute(ctx); - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - auto* out = dout; // Fake out, not used - int axis = ctx.Attr("axis"); - ElemwiseGradCompute, FMinGradDy>( - ctx, *x, *y, *out, *dout, axis, dx, dy, FMinGradDx(), - FMinGradDy()); - } -}; } // namespace operators } // namespace paddle diff --git a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc index cd513e809f..bf6ec012b2 100644 --- a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc @@ -259,3 +259,20 @@ PD_REGISTER_KERNEL(multiply_triple_grad, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} +PD_REGISTER_KERNEL(elementwise_fmax_grad, + CPU, + ALL_LAYOUT, + phi::ElementwiseFMaxGradKernel, + float, + double, + int, + int64_t) {} + +PD_REGISTER_KERNEL(elementwise_fmin_grad, + CPU, + ALL_LAYOUT, + phi::ElementwiseFMinGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/cpu/elementwise_kernel.cc b/paddle/phi/kernels/cpu/elementwise_kernel.cc new file mode 100644 index 0000000000..37ad18df56 --- /dev/null +++ b/paddle/phi/kernels/cpu/elementwise_kernel.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h" + +PD_REGISTER_KERNEL(elementwise_fmax, + CPU, + ALL_LAYOUT, + phi::ElementwiseFMaxKernel, + float, + double, + int, + int64_t) {} + +PD_REGISTER_KERNEL(elementwise_fmin, + CPU, + ALL_LAYOUT, + phi::ElementwiseFMinKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/elementwise_grad_kernel.h b/paddle/phi/kernels/elementwise_grad_kernel.h index 58ae11a9c4..fb2633cc9f 100644 --- a/paddle/phi/kernels/elementwise_grad_kernel.h +++ b/paddle/phi/kernels/elementwise_grad_kernel.h @@ -124,4 +124,22 @@ void MultiplyTripleGradKernel(const Context& dev_ctx, DenseTensor* d_ddx, DenseTensor* d_ddy); +template +void ElementwiseFMaxGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + int axis, + DenseTensor* x_grad, + DenseTensor* y_grad); + +template +void ElementwiseFMinGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + int axis, + DenseTensor* x_grad, + DenseTensor* y_grad); + } // namespace phi diff --git a/paddle/phi/kernels/elementwise_kernel.h b/paddle/phi/kernels/elementwise_kernel.h new file mode 100644 index 0000000000..c1e73ad91c --- /dev/null +++ b/paddle/phi/kernels/elementwise_kernel.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" + +namespace phi { + +template +void ElementwiseFMaxKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + +template +void ElementwiseFMinKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/funcs/elementwise_functor.h b/paddle/phi/kernels/funcs/elementwise_functor.h index b01d50015f..f9e66836a6 100644 --- a/paddle/phi/kernels/funcs/elementwise_functor.h +++ b/paddle/phi/kernels/funcs/elementwise_functor.h @@ -159,6 +159,219 @@ struct DivGradYFunctor> { return -a * out_div_c_conj; } }; +// Fmin +template +struct FMinFunctor { + inline HOSTDEVICE T operator()(const T a, const T b) const { + return std::fmin(a, b); + } +}; + +template <> +struct FMinFunctor { + inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a, + const dtype::float16 b) const { + float float_a = static_cast(a); + float float_b = static_cast(b); + auto result = std::fmin(float_a, float_b); + return static_cast(result); + } +}; + +template <> +struct FMinFunctor { + inline HOSTDEVICE int operator()(const int a, const int b) const { + float float_a = static_cast(a); + float float_b = static_cast(b); + auto result = std::fmin(float_a, float_b); + return std::lrint(result); + } +}; + +template <> +struct FMinFunctor { + inline HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const { + double double_a = static_cast(a); + double double_b = static_cast(b); + auto result = std::fmin(double_a, double_b); + return std::llrint(result); + } +}; + +// Fmax +template +struct FMaxFunctor { + inline HOSTDEVICE T operator()(const T a, const T b) const { + return std::fmax(a, b); + } +}; + +template <> +struct FMaxFunctor { + inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a, + const dtype::float16 b) const { + float float_a = static_cast(a); + float float_b = static_cast(b); + auto result = std::fmax(float_a, float_b); + return static_cast(result); + } +}; + +template <> +struct FMaxFunctor { + inline HOSTDEVICE int operator()(const int a, const int b) const { + float float_a = static_cast(a); + float float_b = static_cast(b); + auto result = std::fmax(float_a, float_b); + return std::lrint(result); + } +}; + +template <> +struct FMaxFunctor { + inline HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const { + double double_a = static_cast(a); + double double_b = static_cast(b); + auto result = std::fmax(double_a, double_b); + return std::llrint(result); + } +}; + +template +struct FMaxGradDx { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * static_cast((x >= y) || isnan(y)); + } +}; + +template <> +struct FMaxGradDx { + HOSTDEVICE dtype::float16 operator()(dtype::float16 x, + dtype::float16 y, + dtype::float16 out, + dtype::float16 dout) const { + return dout * static_cast((x >= y) || dtype::isnan(y)); + } +}; + +template <> +struct FMaxGradDx { + HOSTDEVICE int operator()(int x, int y, int out, int dout) const { + return dout * static_cast((x >= y)); + } +}; + +template <> +struct FMaxGradDx { + HOSTDEVICE int64_t operator()(int64_t x, + int64_t y, + int64_t out, + int64_t dout) const { + return dout * static_cast((x >= y)); + } +}; + +template +struct FMaxGradDy { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * static_cast(!((x >= y) || isnan(y))); + } +}; + +template <> +struct FMaxGradDy { + HOSTDEVICE dtype::float16 operator()(dtype::float16 x, + dtype::float16 y, + dtype::float16 out, + dtype::float16 dout) const { + return dout * static_cast(!((x >= y) || dtype::isnan(y))); + } +}; + +template <> +struct FMaxGradDy { + HOSTDEVICE int64_t operator()(int64_t x, + int64_t y, + int64_t out, + int64_t dout) const { + return dout * static_cast(!((x >= y))); + } +}; + +template <> +struct FMaxGradDy { + HOSTDEVICE int operator()(int x, int y, int out, int dout) const { + return dout * static_cast(!((x >= y))); + } +}; + +template +struct FMinGradDx { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * static_cast((x <= y) || isnan(y)); + } +}; + +template <> +struct FMinGradDx { + HOSTDEVICE dtype::float16 operator()(dtype::float16 x, + dtype::float16 y, + dtype::float16 out, + dtype::float16 dout) const { + return dout * static_cast((x <= y) || dtype::isnan(y)); + } +}; + +template <> +struct FMinGradDx { + HOSTDEVICE int operator()(int x, int y, int out, int dout) const { + return dout * static_cast((x <= y)); + } +}; + +template <> +struct FMinGradDx { + HOSTDEVICE int64_t operator()(int64_t x, + int64_t y, + int64_t out, + int64_t dout) const { + return dout * static_cast((x <= y)); + } +}; + +template +struct FMinGradDy { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * static_cast(!((x <= y) || isnan(y))); + } +}; + +template <> +struct FMinGradDy { + HOSTDEVICE dtype::float16 operator()(dtype::float16 x, + dtype::float16 y, + dtype::float16 out, + dtype::float16 dout) const { + return dout * static_cast(!((x <= y) || dtype::isnan(y))); + } +}; + +template <> +struct FMinGradDy { + HOSTDEVICE int operator()(int x, int y, int out, int dout) const { + return dout * static_cast(!((x <= y))); + } +}; + +template <> +struct FMinGradDy { + HOSTDEVICE int64_t operator()(int64_t x, + int64_t y, + int64_t out, + int64_t dout) const { + return dout * static_cast(!((x <= y))); + } +}; template struct MultiplyGradFunctor { diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu index 81f7fac108..c4481bf6ce 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu @@ -282,3 +282,20 @@ PD_REGISTER_KERNEL(multiply_triple_grad, phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} +PD_REGISTER_KERNEL(elementwise_fmax_grad, + GPU, + ALL_LAYOUT, + phi::ElementwiseFMaxGradKernel, + float, + double, + int, + int64_t) {} + +PD_REGISTER_KERNEL(elementwise_fmin_grad, + GPU, + ALL_LAYOUT, + phi::ElementwiseFMinGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/elementwise_kernel.cu b/paddle/phi/kernels/gpu/elementwise_kernel.cu new file mode 100644 index 0000000000..2cffc68fa0 --- /dev/null +++ b/paddle/phi/kernels/gpu/elementwise_kernel.cu @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h" + +PD_REGISTER_KERNEL(elementwise_fmax, + GPU, + ALL_LAYOUT, + phi::ElementwiseFMaxKernel, + float, + double, + int, + int64_t) {} + +PD_REGISTER_KERNEL(elementwise_fmin, + GPU, + ALL_LAYOUT, + phi::ElementwiseFMinKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h index 65427e8750..0b7a5d3bcb 100644 --- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h @@ -258,6 +258,102 @@ void DivideDoubleGradKernel(const Context& dev_ctx, dout_result.device(place) = static_cast(-1) * dout_result; } } +template +void ElementwiseFMaxGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + int axis, + DenseTensor* x_grad, + DenseTensor* y_grad) { + funcs::ElementwiseGradPreProcess(out_grad, x_grad); + + auto out = out_grad; // Fake out, not used + auto x_dim = x.dims(); + auto y_dim = y.dims(); + if (x.dims() == y.dims()) { + funcs::ElemwiseGradComputeNoBroadcast, + funcs::FMaxGradDy>( + dev_ctx, + x_dim, + y_dim, + x, + y, + out, + out_grad, + axis, + x_grad, + y_grad, + funcs::FMaxGradDx(), + funcs::FMaxGradDy()); + } else { + funcs::ElemwiseGradComputeWithBroadcast, + funcs::FMaxGradDy>( + dev_ctx, + x_dim, + y_dim, + x, + y, + out, + out_grad, + axis, + x_grad, + y_grad, + funcs::FMaxGradDx(), + funcs::FMaxGradDy()); + } +} + +template +void ElementwiseFMinGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + int axis, + DenseTensor* x_grad, + DenseTensor* y_grad) { + funcs::ElementwiseGradPreProcess(out_grad, x_grad); + auto out = out_grad; // Fake out, not used + auto x_dim = x.dims(); + auto y_dim = y.dims(); + if (x.dims() == y.dims()) { + funcs::ElemwiseGradComputeNoBroadcast, + funcs::FMinGradDy>( + dev_ctx, + x_dim, + y_dim, + x, + y, + out, + out_grad, + axis, + x_grad, + y_grad, + funcs::FMinGradDx(), + funcs::FMinGradDy()); + } else { + funcs::ElemwiseGradComputeWithBroadcast, + funcs::FMinGradDy>( + dev_ctx, + x_dim, + y_dim, + x, + y, + out, + out_grad, + axis, + x_grad, + y_grad, + funcs::FMinGradDx(), + funcs::FMinGradDy()); + } +} template struct MulGradDX { diff --git a/paddle/phi/kernels/impl/elementwise_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_kernel_impl.h new file mode 100644 index 0000000000..775a91bf02 --- /dev/null +++ b/paddle/phi/kernels/impl/elementwise_kernel_impl.h @@ -0,0 +1,47 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/elementwise_kernel.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#endif + +namespace phi { +template +void ElementwiseFMaxKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + dev_ctx.template Alloc(out); + funcs::ElementwiseCompute, T, T>( + dev_ctx, x, y, axis, funcs::FMaxFunctor(), out); +} + +template +void ElementwiseFMinKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + int axis, + DenseTensor* out) { + dev_ctx.template Alloc(out); + funcs::ElementwiseCompute, T, T>( + dev_ctx, x, y, axis, funcs::FMinFunctor(), out); +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/elementwise_sig.cc b/paddle/phi/ops/compat/elementwise_sig.cc index fc890fa3a4..1d2aaa04f0 100644 --- a/paddle/phi/ops/compat/elementwise_sig.cc +++ b/paddle/phi/ops/compat/elementwise_sig.cc @@ -114,6 +114,14 @@ KernelSignature ElementwiseDivGradOpArgumentMapping( {GradVarName("X"), GradVarName("Y")}); } +KernelSignature ElementwiseFMinGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("elementwise_fmin_grad", + {"X", "Y", GradVarName("Out")}, + {"axis"}, + {GradVarName("X"), GradVarName("Y")}); +} + KernelSignature ElementwiseDivDoubleGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("divide_double_grad", @@ -130,6 +138,14 @@ KernelSignature ElementwiseMulGradOpArgumentMapping( {GradVarName("X"), GradVarName("Y")}); } +KernelSignature ElementwiseFMaxGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("elementwise_fmax_grad", + {"X", "Y", GradVarName("Out")}, + {"axis"}, + {GradVarName("X"), GradVarName("Y")}); +} + KernelSignature ElementwiseMulDoubleGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature("multiply_double_grad", @@ -192,3 +208,9 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad, phi::ElementwiseMulDoubleGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad, phi::ElementwiseMulTripleGradOpArgumentMapping); + +PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax_grad, + phi::ElementwiseFMaxGradOpArgumentMapping); + +PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad, + phi::ElementwiseFMinGradOpArgumentMapping); -- GitLab