From 567e2fc8902d87affddcb2e5de83f83d4214397f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C5=82awomir=20Siwek?= Date: Thu, 22 Sep 2022 10:10:43 +0200 Subject: [PATCH] [PHI] Migrate gelu kernels (#45596) * gaussian random * mkldnn to onednn renaming * fix merge conflicts * remove fluid code * onednn renaming * gelu fwd * sort activations * gelu gradient * remove unused macros * merge conflicts * fix merge conflicts * remove extra contraint from gelu op --- cmake/operators.cmake | 2 +- .../mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc | 2 +- paddle/fluid/operators/gelu_op.cc | 3 +- .../operators/mkldnn/activation_mkldnn_op.cc | 90 --------------- .../kernels/onednn/activation_grad_kernel.cc | 108 +++++++++--------- .../phi/kernels/onednn/activation_kernel.cc | 64 +++++++---- 6 files changed, 103 insertions(+), 166 deletions(-) diff --git a/cmake/operators.cmake b/cmake/operators.cmake index c560dddfef..bbf77b6615 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -510,7 +510,7 @@ function(op_library TARGET) if(WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0) # Append first implemented MKLDNN activation operator if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op") - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(gelu, MKLDNN);\n") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n") elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n") diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc index 6f7bb614cc..49db8b8f7f 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc @@ -38,7 +38,7 @@ USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN); USE_OP_ITSELF(elementwise_add); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP_ITSELF(gelu); -USE_OP_DEVICE_KERNEL(gelu, MKLDNN); +PD_DECLARE_KERNEL(gelu, OneDNN, ALL_LAYOUT); PD_DECLARE_ARG_MAPPING_FN(gelu); namespace paddle { diff --git a/paddle/fluid/operators/gelu_op.cc b/paddle/fluid/operators/gelu_op.cc index c5ec8d1b21..15b0a04ab2 100644 --- a/paddle/fluid/operators/gelu_op.cc +++ b/paddle/fluid/operators/gelu_op.cc @@ -77,8 +77,7 @@ class GeluGradOp : public framework::OperatorWithKernel { const framework::ExecutionContext &ctx) const override { auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - auto it = this->Attrs().find("use_mkldnn"); - if (it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) { + if (this->CanMKLDNNBeUsed(ctx, data_type)) { return framework::OpKernelType(data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc index 728d86cd94..a71fd3fa02 100644 --- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc @@ -52,42 +52,6 @@ class MKLDNNActivationGradKernel } }; -template -void eltwise_forward(const framework::ExecutionContext &ctx, - dnnl::algorithm algorithm) { - PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), - true, - paddle::platform::errors::PreconditionNotMet( - "Operator DNNL eletwise_forward must use CPUPlace")); - auto &dev_ctx = ctx.template device_context(); - const auto &mkldnn_engine = dev_ctx.GetEngine(); - - const auto *x = ctx.Input("X"); - auto *out = ctx.Output("Out"); - - bool is_inplaced = x->IsSharedBufferWith(*out); - - platform::ActivationMKLDNNHandler handler( - algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x); - - auto src_memory_p = handler.AcquireSrcMemory(x); - std::shared_ptr dst_memory_p = nullptr; - if (is_inplaced) { - dst_memory_p = src_memory_p; - out->mutable_data(ctx.GetPlace()); - } else { - dst_memory_p = handler.AcquireDstMemory(out); - } - auto activation_p = handler.AcquireForwardPrimitive(); - - auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - activation_p->execute( - astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}}); - astream.wait(); - - out->set_mem_desc(dst_memory_p->get_desc()); -} - template void eltwise_grad(const framework::ExecutionContext &ctx, dnnl::algorithm algorithm) { @@ -116,34 +80,6 @@ void eltwise_grad(const framework::ExecutionContext &ctx, dx->set_mem_desc(diff_src_memory_p->get_desc()); } -template -void eltwise_grad_use_out(const framework::ExecutionContext &ctx, - dnnl::algorithm algorithm) { - auto &dev_ctx = ctx.template device_context(); - const auto &mkldnn_engine = dev_ctx.GetEngine(); - - const auto *out = ctx.Input("Out"); - const auto *dout = ctx.Input(framework::GradVarName("Out")); - auto *dx = ctx.Output(framework::GradVarName("X")); - - platform::ActivationMKLDNNHandler handler( - algorithm, ctx, mkldnn_engine, ctx.GetPlace(), out, dout); - - auto dst_memory_p = handler.AcquireBackwardSrcMemory(out); - auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout); - auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx); - auto activation_backward_p = handler.AcquireBackwardPrimitive(); - - auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - activation_backward_p->execute(astream, - {{DNNL_ARG_DST, *dst_memory_p}, - {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, - {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); - astream.wait(); - - dx->set_mem_desc(diff_src_memory_p->get_desc()); -} - template struct MKLDNNActivationGradFunc : public BaseActivationFunctor { void operator()(const framework::ExecutionContext &ctx) const { @@ -151,30 +87,6 @@ struct MKLDNNActivationGradFunc : public BaseActivationFunctor { } }; -template -struct GeluMKLDNNFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - const bool approximate = ctx.Attr("approximate"); - if (approximate) { - eltwise_forward(ctx, dnnl::algorithm::eltwise_gelu_tanh); - } else { - eltwise_forward(ctx, dnnl::algorithm::eltwise_gelu_erf); - } - } -}; - -template -struct GeluMKLDNNGradFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - const bool approximate = ctx.Attr("approximate"); - if (approximate) { - eltwise_grad(ctx, dnnl::algorithm::eltwise_gelu_tanh); - } else { - eltwise_grad(ctx, dnnl::algorithm::eltwise_gelu_erf); - } - } -}; - template struct SoftplusMKLDNNFunctor : public BaseActivationFunctor { void operator()(const framework::ExecutionContext &ctx) const { @@ -209,6 +121,4 @@ namespace ops = paddle::operators; ops::grad_functor>); REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(softplus, SoftplusMKLDNNFunctor); -REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(gelu, GeluMKLDNNFunctor); -REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(gelu, GeluMKLDNNGradFunctor); REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(relu6, Relu6MKLDNNGradFunctor); diff --git a/paddle/phi/kernels/onednn/activation_grad_kernel.cc b/paddle/phi/kernels/onednn/activation_grad_kernel.cc index 4ad073cb00..a8c988a69d 100644 --- a/paddle/phi/kernels/onednn/activation_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/activation_grad_kernel.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/phi/kernels/activation_grad_kernel.h" +#include "paddle/phi/kernels/gelu_grad_kernel.h" #include "paddle/phi/backends/onednn/onednn_context.h" #include "paddle/phi/backends/onednn/onednn_reuse.h" @@ -23,16 +24,6 @@ namespace phi { -#define DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPX(name, functor_class) \ - template \ - void name##GradKernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - const DenseTensor& dout, \ - DenseTensor* dx) { \ - functor_class functor; \ - functor(dev_ctx, x, dout, 0, 0, dx); \ - } - #define DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX( \ name, functor_class, attr) \ template \ @@ -55,18 +46,6 @@ namespace phi { functor(dev_ctx, out, dout, 0, 0, dx); \ } -#define DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT( \ - name, functor_class, attr) \ - template \ - void name##GradKernel(const Context& dev_ctx, \ - const DenseTensor& out, \ - const DenseTensor& dout, \ - float attr, \ - DenseTensor* dx) { \ - functor_class functor; \ - functor(dev_ctx, out, dout, attr, 0, dx); \ - } - template void eltwise_grad(const OneDNNContext& dev_ctx, const DenseTensor& x, @@ -158,12 +137,14 @@ using AbsOneDNNGradFunctor = OneDNNActivationGradFunc; template -using ReluOneDNNGradFunctor = - OneDNNActivationGradFunc; +using EluOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< + T, + dnnl::algorithm::eltwise_elu_use_dst_for_bwd>; template -using SwishOneDNNGradFunctor = - OneDNNActivationGradFunc; +using ExpOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< + T, + dnnl::algorithm::eltwise_exp_use_dst_for_bwd>; template using HardSwishOneDNNGradFunctor = @@ -174,14 +155,21 @@ using MishOneDNNGradFunctor = OneDNNActivationGradFunc; template -using SigmoidOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< - T, - dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>; +using GeluTanhOneDNNGradFunctor = + OneDNNActivationGradFunc; template -using TanhOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< +using GeluErfOneDNNGradFunctor = + OneDNNActivationGradFunc; + +template +using ReluOneDNNGradFunctor = + OneDNNActivationGradFunc; + +template +using SigmoidOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< T, - dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>; + dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>; template using SqrtOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< @@ -189,22 +177,21 @@ using SqrtOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>; template -using EluOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< - T, - dnnl::algorithm::eltwise_elu_use_dst_for_bwd>; +using SwishOneDNNGradFunctor = + OneDNNActivationGradFunc; template -using ExpOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< +using TanhOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< T, - dnnl::algorithm::eltwise_exp_use_dst_for_bwd>; + dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>; -DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhOneDNNGradUseOutFunctor); -DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtOneDNNGradUseOutFunctor); -DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid, - SigmoidOneDNNGradUseOutFunctor); -DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpOneDNNGradUseOutFunctor); DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Abs, AbsOneDNNGradFunctor); +DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpOneDNNGradUseOutFunctor); DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluOneDNNGradFunctor); +DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid, + SigmoidOneDNNGradUseOutFunctor); +DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtOneDNNGradUseOutFunctor); +DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhOneDNNGradUseOutFunctor); DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, ReluOneDNNGradFunctor, @@ -215,17 +202,6 @@ DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish, DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, SwishOneDNNGradFunctor, beta); -template -void HardSwishGradKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& dout, - float threshold, - float scale, - float offset, - DenseTensor* dx) { - HardSwishOneDNNGradFunctor functor; - functor(dev_ctx, x, dout, threshold, 0, dx); -} template void EluGradKernel(const Context& dev_ctx, @@ -238,6 +214,33 @@ void EluGradKernel(const Context& dev_ctx, functor(dev_ctx, out, dout, alpha, 0, dx); } +template +void GeluGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + bool approximate, + DenseTensor* x_grad) { + if (approximate) { + GeluTanhOneDNNGradFunctor functor; + functor(dev_ctx, x, out_grad, 0, 0, x_grad); + } else { + GeluErfOneDNNGradFunctor functor; + functor(dev_ctx, x, out_grad, 0, 0, x_grad); + } +} + +template +void HardSwishGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + float threshold, + float scale, + float offset, + DenseTensor* dx) { + HardSwishOneDNNGradFunctor functor; + functor(dev_ctx, x, dout, threshold, 0, dx); +} + } // namespace phi PD_REGISTER_KERNEL(relu_grad, @@ -254,6 +257,7 @@ PD_REGISTER_KERNEL(relu_grad, PD_REGISTER_ACTIVATION_GRAD_KERNEL(abs_grad, AbsGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(gelu_grad, GeluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel) diff --git a/paddle/phi/kernels/onednn/activation_kernel.cc b/paddle/phi/kernels/onednn/activation_kernel.cc index 40f2d8fd4c..b90ea1c61b 100644 --- a/paddle/phi/kernels/onednn/activation_kernel.cc +++ b/paddle/phi/kernels/onednn/activation_kernel.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/phi/kernels/activation_kernel.h" +#include "paddle/phi/kernels/gelu_grad_kernel.h" #include "paddle/phi/backends/onednn/onednn_context.h" #include "paddle/phi/backends/onednn/onednn_reuse.h" @@ -91,16 +92,18 @@ template using AbsOneDNNFunctor = OneDNNActivationFunc; template -using ReluOneDNNFunctor = - OneDNNActivationFunc; +using EluOneDNNFunctor = OneDNNActivationFunc; template -using Relu6OneDNNFunctor = - OneDNNActivationFunc; +using ExpOneDNNFunctor = OneDNNActivationFunc; template -using SwishOneDNNFunctor = - OneDNNActivationFunc; +using GeluTanhOneDNNFunctor = + OneDNNActivationFunc; + +template +using GeluErfOneDNNFunctor = + OneDNNActivationFunc; template using HardSwishOneDNNFunctor = @@ -111,40 +114,46 @@ using MishOneDNNFunctor = OneDNNActivationFunc; template -using SigmoidOneDNNFunctor = - OneDNNActivationFunc; +using ReluOneDNNFunctor = + OneDNNActivationFunc; template -using TanhOneDNNFunctor = - OneDNNActivationFunc; +using Relu6OneDNNFunctor = + OneDNNActivationFunc; template -using SqrtOneDNNFunctor = - OneDNNActivationFunc; +using RoundOneDNNFunctor = + OneDNNActivationFunc; template -using EluOneDNNFunctor = OneDNNActivationFunc; +using SigmoidOneDNNFunctor = + OneDNNActivationFunc; template -using ExpOneDNNFunctor = OneDNNActivationFunc; +using SqrtOneDNNFunctor = + OneDNNActivationFunc; template -using RoundOneDNNFunctor = - OneDNNActivationFunc; +using SwishOneDNNFunctor = + OneDNNActivationFunc; + +template +using TanhOneDNNFunctor = + OneDNNActivationFunc; DEFINE_ONEDNN_ACTIVATION_KERNEL(Abs, AbsOneDNNFunctor) -DEFINE_ONEDNN_ACTIVATION_KERNEL(Relu, ReluOneDNNFunctor) -DEFINE_ONEDNN_ACTIVATION_KERNEL(Tanh, TanhOneDNNFunctor) DEFINE_ONEDNN_ACTIVATION_KERNEL(Exp, ExpOneDNNFunctor) -DEFINE_ONEDNN_ACTIVATION_KERNEL(Sqrt, SqrtOneDNNFunctor) +DEFINE_ONEDNN_ACTIVATION_KERNEL(Relu, ReluOneDNNFunctor) DEFINE_ONEDNN_ACTIVATION_KERNEL(Sigmoid, SigmoidOneDNNFunctor) +DEFINE_ONEDNN_ACTIVATION_KERNEL(Sqrt, SqrtOneDNNFunctor) +DEFINE_ONEDNN_ACTIVATION_KERNEL(Tanh, TanhOneDNNFunctor) // round eltwise primitive doesn't support BF16, nor does it support grad DEFINE_ONEDNN_ACTIVATION_KERNEL(Round, RoundOneDNNFunctor) +DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, ReluOneDNNFunctor, alpha) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishOneDNNFunctor, threshold) -DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Relu6, Relu6OneDNNFunctor, threshold) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishOneDNNFunctor, beta) @@ -159,6 +168,20 @@ void HardSwishKernel(const Context& dev_ctx, functor(dev_ctx, x, threshold, 0, out); } +template +void GeluKernel(const Context& dev_ctx, + const DenseTensor& x, + bool approximate, + DenseTensor* out) { + if (approximate) { + GeluTanhOneDNNFunctor functor; + functor(dev_ctx, x, 0, 0, out); + } else { + GeluErfOneDNNFunctor functor; + functor(dev_ctx, x, 0, 0, out); + } +} + } // namespace phi PD_REGISTER_KERNEL(round, OneDNN, ALL_LAYOUT, phi::RoundKernel, float) {} @@ -170,6 +193,7 @@ PD_REGISTER_KERNEL(round, OneDNN, ALL_LAYOUT, phi::RoundKernel, float) {} PD_REGISTER_ACTIVATION_KERNEL(abs, AbsKernel) PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel) PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) +PD_REGISTER_ACTIVATION_KERNEL(gelu, GeluKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel) -- GitLab