diff --git a/cmake/operators.cmake b/cmake/operators.cmake index c560dddfef5e79a113b00466be67bcd508afdbee..bbf77b6615dfb731242cfef5feb8f9802b1ca3fa 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -510,7 +510,7 @@ function(op_library TARGET) if(WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0) # Append first implemented MKLDNN activation operator if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op") - file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(gelu, MKLDNN);\n") + file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n") elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n") diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc index 6f7bb614cc79f52c4063519c335f1f94bac79043..49db8b8f7f8e546259ec2906a3e2d4121a59efbb 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_conv_bn_fuse_pass_tester.cc @@ -38,7 +38,7 @@ USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN); USE_OP_ITSELF(elementwise_add); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP_ITSELF(gelu); -USE_OP_DEVICE_KERNEL(gelu, MKLDNN); +PD_DECLARE_KERNEL(gelu, OneDNN, ALL_LAYOUT); PD_DECLARE_ARG_MAPPING_FN(gelu); namespace paddle { diff --git a/paddle/fluid/operators/gelu_op.cc b/paddle/fluid/operators/gelu_op.cc index add87fdd3c11211a441db1e3394f77cee8de0648..e7debf896a2861fef70ed92507853395cc47ad43 100644 --- a/paddle/fluid/operators/gelu_op.cc +++ b/paddle/fluid/operators/gelu_op.cc @@ -80,11 +80,11 @@ class GeluGradOp : public framework::OperatorWithKernel { framework::DataLayout layout = framework::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - auto it = this->Attrs().find("use_mkldnn"); - if (library == framework::LibraryType::kPlain && - it != 
this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, data_type)) { + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc index 728d86cd94e33da80f1670f584a0d3e4a9aabd7d..4909fdc32ba6fc96be6d900cdb4ba0cb735a44f7 100644 --- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc @@ -42,139 +42,6 @@ class MKLDNNActivationKernel } }; -template -class MKLDNNActivationGradKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - Functor functor; - functor(ctx); - } -}; - -template -void eltwise_forward(const framework::ExecutionContext &ctx, - dnnl::algorithm algorithm) { - PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), - true, - paddle::platform::errors::PreconditionNotMet( - "Operator DNNL eletwise_forward must use CPUPlace")); - auto &dev_ctx = ctx.template device_context(); - const auto &mkldnn_engine = dev_ctx.GetEngine(); - - const auto *x = ctx.Input("X"); - auto *out = ctx.Output("Out"); - - bool is_inplaced = x->IsSharedBufferWith(*out); - - platform::ActivationMKLDNNHandler handler( - algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x); - - auto src_memory_p = handler.AcquireSrcMemory(x); - std::shared_ptr dst_memory_p = nullptr; - if (is_inplaced) { - dst_memory_p = src_memory_p; - out->mutable_data(ctx.GetPlace()); - } else { - dst_memory_p = handler.AcquireDstMemory(out); - } - auto activation_p = handler.AcquireForwardPrimitive(); - - auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - 
activation_p->execute( - astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}}); - astream.wait(); - - out->set_mem_desc(dst_memory_p->get_desc()); -} - -template -void eltwise_grad(const framework::ExecutionContext &ctx, - dnnl::algorithm algorithm) { - auto &dev_ctx = ctx.template device_context(); - const auto &mkldnn_engine = dev_ctx.GetEngine(); - - const auto *x = ctx.Input("X"); - const auto *dout = ctx.Input(framework::GradVarName("Out")); - auto *dx = ctx.Output(framework::GradVarName("X")); - - platform::ActivationMKLDNNHandler handler( - algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x, dout); - - auto src_memory_p = handler.AcquireBackwardSrcMemory(x); - auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout); - auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx); - auto activation_backward_p = handler.AcquireBackwardPrimitive(); - - auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - activation_backward_p->execute(astream, - {{DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, - {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); - astream.wait(); - - dx->set_mem_desc(diff_src_memory_p->get_desc()); -} - -template -void eltwise_grad_use_out(const framework::ExecutionContext &ctx, - dnnl::algorithm algorithm) { - auto &dev_ctx = ctx.template device_context(); - const auto &mkldnn_engine = dev_ctx.GetEngine(); - - const auto *out = ctx.Input("Out"); - const auto *dout = ctx.Input(framework::GradVarName("Out")); - auto *dx = ctx.Output(framework::GradVarName("X")); - - platform::ActivationMKLDNNHandler handler( - algorithm, ctx, mkldnn_engine, ctx.GetPlace(), out, dout); - - auto dst_memory_p = handler.AcquireBackwardSrcMemory(out); - auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout); - auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx); - auto activation_backward_p = handler.AcquireBackwardPrimitive(); - - auto &astream = 
paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - activation_backward_p->execute(astream, - {{DNNL_ARG_DST, *dst_memory_p}, - {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, - {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); - astream.wait(); - - dx->set_mem_desc(diff_src_memory_p->get_desc()); -} - -template -struct MKLDNNActivationGradFunc : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - eltwise_grad(ctx, algorithm); - } -}; - -template -struct GeluMKLDNNFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - const bool approximate = ctx.Attr("approximate"); - if (approximate) { - eltwise_forward(ctx, dnnl::algorithm::eltwise_gelu_tanh); - } else { - eltwise_forward(ctx, dnnl::algorithm::eltwise_gelu_erf); - } - } -}; - -template -struct GeluMKLDNNGradFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - const bool approximate = ctx.Attr("approximate"); - if (approximate) { - eltwise_grad(ctx, dnnl::algorithm::eltwise_gelu_tanh); - } else { - eltwise_grad(ctx, dnnl::algorithm::eltwise_gelu_erf); - } - } -}; - template struct SoftplusMKLDNNFunctor : public BaseActivationFunctor { void operator()(const framework::ExecutionContext &ctx) const { @@ -182,10 +49,6 @@ struct SoftplusMKLDNNFunctor : public BaseActivationFunctor { } }; -template -using Relu6MKLDNNGradFunctor = - MKLDNNActivationGradFunc; - } // namespace operators } // namespace paddle @@ -199,16 +62,4 @@ namespace ops = paddle::operators; ops::MKLDNNActivationKernel>, \ ops::MKLDNNActivationKernel>); -#define REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(act_type, grad_functor) \ - REGISTER_OP_KERNEL( \ - act_type##_grad, \ - MKLDNN, \ - ::paddle::platform::CPUPlace, \ - ops::MKLDNNActivationGradKernel>, \ - ops::MKLDNNActivationGradKernel< \ - ops::grad_functor>); - REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(softplus, SoftplusMKLDNNFunctor); 
-REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(gelu, GeluMKLDNNFunctor); -REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(gelu, GeluMKLDNNGradFunctor); -REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(relu6, Relu6MKLDNNGradFunctor); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 933ac4f12e3c4ef4a300689729fb4151f35cb82f..959329995613023a8c5d210da91ca11c76fab82d 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -293,103 +293,6 @@ class MatMulV2MKLDNNHandler } }; -template -class ActivationMKLDNNHandler - : public MKLDNNHandlerNoCachingT { - public: - ActivationMKLDNNHandler(dnnl::algorithm algorithm, - const framework::ExecutionContext& ctx, - const dnnl::engine engine, - Place cpu_place, - const framework::Tensor* x) - : platform::MKLDNNHandlerNoCachingT(engine, - cpu_place) { - float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 0; - float beta = ctx.HasAttr("beta") ? ctx.Attr("beta") : 0; - - if (ctx.Type() == "scale") { - bool bias_after_scale = ctx.Attr("bias_after_scale"); - auto* scale_tensor = ctx.Input("ScaleTensor"); - alpha = (scale_tensor == nullptr) - ? ctx.Attr("scale") - : static_cast(*(scale_tensor->data())); - beta = ctx.Attr("bias"); - // if bias_after_scale == true - // out = scale*X + bias - // else - // out = scale*(X + bias) = scale*X + scale*bias - if (!bias_after_scale) { - beta *= alpha; - } - } else if (ctx.Type() == "clip") { - alpha = ctx.HasInput("Min") ? ctx.Input("Min")->data()[0] - : ctx.Attr("min"); - beta = ctx.HasInput("Max") ? 
ctx.Input("Max")->data()[0] - : ctx.Attr("max"); - } else { - // paddle uses beta but mkldnn uses alpha for swish - if (algorithm == dnnl::algorithm::eltwise_swish) { - std::swap(alpha, beta); - } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) { - alpha = ctx.Attr("threshold"); - } - } - - this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training, - algorithm, - x->mem_desc(), - alpha, - beta); - } - - ActivationMKLDNNHandler(dnnl::algorithm algorithm, - const framework::ExecutionContext& ctx, - const dnnl::engine engine, - Place cpu_place, - const framework::Tensor* x, - const Tensor* dout) - : platform::MKLDNNHandlerNoCachingT(engine, - cpu_place) { - float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 0; - float beta = ctx.HasAttr("beta") ? ctx.Attr("beta") : 0; - - // paddle uses beta but mkldnn uses alpha for swish - if (algorithm == dnnl::algorithm::eltwise_swish) { - std::swap(alpha, beta); - } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) { - alpha = ctx.Attr("threshold"); - } - - if (ctx.Type() == "clip_grad") { - alpha = ctx.HasInput("Min") ? ctx.Input("Min")->data()[0] - : ctx.Attr("min"); - beta = ctx.HasInput("Max") ? 
ctx.Input("Max")->data()[0] - : ctx.Attr("max"); - } - - this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training, - algorithm, - x->mem_desc(), - alpha, - beta); - this->AcquireBackwardPrimitiveDescriptor( - algorithm, dout->mem_desc(), x->mem_desc(), alpha, beta); - } - - std::shared_ptr AcquireBackwardSrcMemory( - const framework::Tensor* input) { - const T* input_data = input->data(); - return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_desc(), - to_void_cast(input_data)); - } -}; - static std::unordered_map GetAttributeMap( std::string act_type) { std::unordered_map attr_map; diff --git a/paddle/phi/kernels/onednn/activation_grad_kernel.cc b/paddle/phi/kernels/onednn/activation_grad_kernel.cc index 4ad073cb00d62f20145a882b9e227a4381c6c08f..9e183abf0287fa595c6905f1a44665eef62ead4a 100644 --- a/paddle/phi/kernels/onednn/activation_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/activation_grad_kernel.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/phi/kernels/activation_grad_kernel.h" +#include "paddle/phi/kernels/gelu_grad_kernel.h" #include "paddle/phi/backends/onednn/onednn_context.h" #include "paddle/phi/backends/onednn/onednn_reuse.h" @@ -23,16 +24,6 @@ namespace phi { -#define DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPX(name, functor_class) \ - template \ - void name##GradKernel(const Context& dev_ctx, \ - const DenseTensor& x, \ - const DenseTensor& dout, \ - DenseTensor* dx) { \ - functor_class functor; \ - functor(dev_ctx, x, dout, 0, 0, dx); \ - } - #define DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX( \ name, functor_class, attr) \ template \ @@ -55,18 +46,6 @@ namespace phi { functor(dev_ctx, out, dout, 0, 0, dx); \ } -#define DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT( \ - name, functor_class, attr) \ - template \ - void name##GradKernel(const Context& dev_ctx, \ - const DenseTensor& out, \ - const DenseTensor& dout, \ - float attr, \ - DenseTensor* dx) { \ - functor_class functor; \ - 
functor(dev_ctx, out, dout, attr, 0, dx); \ - } - template void eltwise_grad(const OneDNNContext& dev_ctx, const DenseTensor& x, @@ -158,12 +137,14 @@ using AbsOneDNNGradFunctor = OneDNNActivationGradFunc; template -using ReluOneDNNGradFunctor = - OneDNNActivationGradFunc; +using EluOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< + T, + dnnl::algorithm::eltwise_elu_use_dst_for_bwd>; template -using SwishOneDNNGradFunctor = - OneDNNActivationGradFunc; +using ExpOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< + T, + dnnl::algorithm::eltwise_exp_use_dst_for_bwd>; template using HardSwishOneDNNGradFunctor = @@ -174,14 +155,26 @@ using MishOneDNNGradFunctor = OneDNNActivationGradFunc; template -using SigmoidOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< +using GeluTanhOneDNNGradFunctor = + OneDNNActivationGradFunc; + +template +using GeluErfOneDNNGradFunctor = + OneDNNActivationGradFunc; + +template +using ReluOneDNNGradFunctor = + OneDNNActivationGradFunc; + +template +using Relu6OneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< T, - dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>; + dnnl::algorithm::eltwise_clip_v2_use_dst_for_bwd>; template -using TanhOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< +using SigmoidOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< T, - dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>; + dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>; template using SqrtOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< @@ -189,22 +182,21 @@ using SqrtOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>; template -using EluOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< - T, - dnnl::algorithm::eltwise_elu_use_dst_for_bwd>; +using SwishOneDNNGradFunctor = + OneDNNActivationGradFunc; template -using ExpOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< +using TanhOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< 
T, - dnnl::algorithm::eltwise_exp_use_dst_for_bwd>; + dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>; -DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhOneDNNGradUseOutFunctor); -DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtOneDNNGradUseOutFunctor); -DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid, - SigmoidOneDNNGradUseOutFunctor); -DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpOneDNNGradUseOutFunctor); DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Abs, AbsOneDNNGradFunctor); +DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpOneDNNGradUseOutFunctor); DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluOneDNNGradFunctor); +DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid, + SigmoidOneDNNGradUseOutFunctor); +DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtOneDNNGradUseOutFunctor); +DEFINE_ONEDNN_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhOneDNNGradUseOutFunctor); DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, ReluOneDNNGradFunctor, @@ -215,6 +207,33 @@ DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish, DEFINE_ONEDNN_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, SwishOneDNNGradFunctor, beta); + +template +void EluGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& dout, + float alpha, + DenseTensor* dx) { + EluOneDNNGradUseOutFunctor functor; + functor(dev_ctx, out, dout, alpha, 0, dx); +} + +template +void GeluGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + bool approximate, + DenseTensor* x_grad) { + if (approximate) { + GeluTanhOneDNNGradFunctor functor; + functor(dev_ctx, x, out_grad, 0, 0, x_grad); + } else { + GeluErfOneDNNGradFunctor functor; + functor(dev_ctx, x, out_grad, 0, 0, x_grad); + } +} + template void HardSwishGradKernel(const Context& dev_ctx, const DenseTensor& x, @@ -228,14 +247,13 @@ void HardSwishGradKernel(const Context& dev_ctx, } template -void EluGradKernel(const Context& dev_ctx, - const 
DenseTensor& x, - const DenseTensor& out, - const DenseTensor& dout, - float alpha, - DenseTensor* dx) { - EluOneDNNGradUseOutFunctor functor; - functor(dev_ctx, out, dout, alpha, 0, dx); +void Relu6GradKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& dout, + float threshold, + DenseTensor* dx) { + Relu6OneDNNGradUseOutFunctor functor; + functor(dev_ctx, out, dout, 0, threshold, dx); } } // namespace phi @@ -254,9 +272,11 @@ PD_REGISTER_KERNEL(relu_grad, PD_REGISTER_ACTIVATION_GRAD_KERNEL(abs_grad, AbsGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(exp_grad, ExpGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(gelu_grad, GeluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) diff --git a/paddle/phi/kernels/onednn/activation_kernel.cc b/paddle/phi/kernels/onednn/activation_kernel.cc index 40f2d8fd4c49e602480362c68de4a39c73b67ffa..36ba1be724ccf103ff47f65a8873992289fe9f44 100644 --- a/paddle/phi/kernels/onednn/activation_kernel.cc +++ b/paddle/phi/kernels/onednn/activation_kernel.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "paddle/phi/kernels/activation_kernel.h" +#include "paddle/phi/kernels/gelu_kernel.h" #include "paddle/phi/backends/onednn/onednn_context.h" #include "paddle/phi/backends/onednn/onednn_reuse.h" @@ -91,16 +92,18 @@ template using AbsOneDNNFunctor = OneDNNActivationFunc; template -using ReluOneDNNFunctor = - OneDNNActivationFunc; +using EluOneDNNFunctor = OneDNNActivationFunc; template -using Relu6OneDNNFunctor = - OneDNNActivationFunc; +using ExpOneDNNFunctor = OneDNNActivationFunc; template -using SwishOneDNNFunctor = - OneDNNActivationFunc; +using GeluTanhOneDNNFunctor = + OneDNNActivationFunc; + +template +using GeluErfOneDNNFunctor = + OneDNNActivationFunc; template using HardSwishOneDNNFunctor = @@ -111,41 +114,46 @@ using MishOneDNNFunctor = OneDNNActivationFunc; template -using SigmoidOneDNNFunctor = - OneDNNActivationFunc; +using ReluOneDNNFunctor = + OneDNNActivationFunc; template -using TanhOneDNNFunctor = - OneDNNActivationFunc; +using Relu6OneDNNFunctor = + OneDNNActivationFunc; template -using SqrtOneDNNFunctor = - OneDNNActivationFunc; +using RoundOneDNNFunctor = + OneDNNActivationFunc; template -using EluOneDNNFunctor = OneDNNActivationFunc; +using SigmoidOneDNNFunctor = + OneDNNActivationFunc; template -using ExpOneDNNFunctor = OneDNNActivationFunc; +using SqrtOneDNNFunctor = + OneDNNActivationFunc; template -using RoundOneDNNFunctor = - OneDNNActivationFunc; +using SwishOneDNNFunctor = + OneDNNActivationFunc; + +template +using TanhOneDNNFunctor = + OneDNNActivationFunc; DEFINE_ONEDNN_ACTIVATION_KERNEL(Abs, AbsOneDNNFunctor) -DEFINE_ONEDNN_ACTIVATION_KERNEL(Relu, ReluOneDNNFunctor) -DEFINE_ONEDNN_ACTIVATION_KERNEL(Tanh, TanhOneDNNFunctor) DEFINE_ONEDNN_ACTIVATION_KERNEL(Exp, ExpOneDNNFunctor) -DEFINE_ONEDNN_ACTIVATION_KERNEL(Sqrt, SqrtOneDNNFunctor) +DEFINE_ONEDNN_ACTIVATION_KERNEL(Relu, ReluOneDNNFunctor) DEFINE_ONEDNN_ACTIVATION_KERNEL(Sigmoid, SigmoidOneDNNFunctor) +DEFINE_ONEDNN_ACTIVATION_KERNEL(Sqrt, SqrtOneDNNFunctor)
+DEFINE_ONEDNN_ACTIVATION_KERNEL(Tanh, TanhOneDNNFunctor) // round eltwise primitive doesn't support BF16, nor does it support grad DEFINE_ONEDNN_ACTIVATION_KERNEL(Round, RoundOneDNNFunctor) +DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, ReluOneDNNFunctor, alpha) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishOneDNNFunctor, threshold) -DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha) -DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Relu6, Relu6OneDNNFunctor, threshold) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishOneDNNFunctor, beta) template @@ -159,6 +167,29 @@ void HardSwishKernel(const Context& dev_ctx, functor(dev_ctx, x, threshold, 0, out); } +template +void GeluKernel(const Context& dev_ctx, + const DenseTensor& x, + bool approximate, + DenseTensor* out) { + if (approximate) { + GeluTanhOneDNNFunctor functor; + functor(dev_ctx, x, 0, 0, out); + } else { + GeluErfOneDNNFunctor functor; + functor(dev_ctx, x, 0, 0, out); + } +} + +template +void Relu6Kernel(const Context& dev_ctx, + const DenseTensor& x, + float threshold, + DenseTensor* out) { + Relu6OneDNNFunctor functor; + functor(dev_ctx, x, 0, threshold, out); +} + } // namespace phi PD_REGISTER_KERNEL(round, OneDNN, ALL_LAYOUT, phi::RoundKernel, float) {} @@ -170,6 +201,7 @@ PD_REGISTER_KERNEL(round, OneDNN, ALL_LAYOUT, phi::RoundKernel, float) {} PD_REGISTER_ACTIVATION_KERNEL(abs, AbsKernel) PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel) PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) +PD_REGISTER_ACTIVATION_KERNEL(gelu, GeluKernel) PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)