diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc index a71fd3fa02d047ad497e4053d7120e5979629079..4909fdc32ba6fc96be6d900cdb4ba0cb735a44f7 100644 --- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc @@ -42,51 +42,6 @@ class MKLDNNActivationKernel } }; -template -class MKLDNNActivationGradKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - Functor functor; - functor(ctx); - } -}; - -template -void eltwise_grad(const framework::ExecutionContext &ctx, - dnnl::algorithm algorithm) { - auto &dev_ctx = ctx.template device_context(); - const auto &mkldnn_engine = dev_ctx.GetEngine(); - - const auto *x = ctx.Input("X"); - const auto *dout = ctx.Input(framework::GradVarName("Out")); - auto *dx = ctx.Output(framework::GradVarName("X")); - - platform::ActivationMKLDNNHandler handler( - algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x, dout); - - auto src_memory_p = handler.AcquireBackwardSrcMemory(x); - auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout); - auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx); - auto activation_backward_p = handler.AcquireBackwardPrimitive(); - - auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - activation_backward_p->execute(astream, - {{DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, - {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); - astream.wait(); - - dx->set_mem_desc(diff_src_memory_p->get_desc()); -} - -template -struct MKLDNNActivationGradFunc : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - eltwise_grad(ctx, algorithm); - } -}; - template struct SoftplusMKLDNNFunctor : public BaseActivationFunctor { void operator()(const framework::ExecutionContext &ctx) const { @@ -94,10 +49,6 @@ struct SoftplusMKLDNNFunctor : public BaseActivationFunctor { } }; -template -using Relu6MKLDNNGradFunctor = - MKLDNNActivationGradFunc; - } // namespace operators } // namespace paddle @@ -111,14 +62,4 @@ namespace ops = paddle::operators; ops::MKLDNNActivationKernel>, \ ops::MKLDNNActivationKernel>); -#define REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(act_type, grad_functor) \ - REGISTER_OP_KERNEL( \ - act_type##_grad, \ - MKLDNN, \ - ::paddle::platform::CPUPlace, \ - ops::MKLDNNActivationGradKernel>, \ - ops::MKLDNNActivationGradKernel< \ - ops::grad_functor>); - REGISTER_FWD_ACTIVATION_MKLDNN_KERNEL(softplus, SoftplusMKLDNNFunctor); -REGISTER_GRAD_ACTIVATION_MKLDNN_KERNEL(relu6, Relu6MKLDNNGradFunctor); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index ca099cb65d67c04188055dfb1805c916932ae58b..8e10429b9dc813c814c4c464c17b2ee9efe0cb94 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -294,103 +294,6 @@ class MatMulV2MKLDNNHandler } }; -template -class ActivationMKLDNNHandler - : public MKLDNNHandlerNoCachingT { - public: - ActivationMKLDNNHandler(dnnl::algorithm algorithm, - const framework::ExecutionContext& ctx, - const dnnl::engine engine, - Place cpu_place, - const framework::Tensor* x) - : platform::MKLDNNHandlerNoCachingT(engine, - cpu_place) { - float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 0; - float beta = ctx.HasAttr("beta") ? ctx.Attr("beta") : 0; - - if (ctx.Type() == "scale") { - bool bias_after_scale = ctx.Attr("bias_after_scale"); - auto* scale_tensor = ctx.Input("ScaleTensor"); - alpha = (scale_tensor == nullptr) - ? ctx.Attr("scale") - : static_cast(*(scale_tensor->data())); - beta = ctx.Attr("bias"); - // if bias_after_scale == true - // out = scale*X + bias - // else - // out = scale*(X + bias) = scale*X + scale*bias - if (!bias_after_scale) { - beta *= alpha; - } - } else if (ctx.Type() == "clip") { - alpha = ctx.HasInput("Min") ? ctx.Input("Min")->data()[0] - : ctx.Attr("min"); - beta = ctx.HasInput("Max") ? ctx.Input("Max")->data()[0] - : ctx.Attr("max"); - } else { - // paddle uses beta but mkldnn uses alpha for swish - if (algorithm == dnnl::algorithm::eltwise_swish) { - std::swap(alpha, beta); - } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) { - alpha = ctx.Attr("threshold"); - } - } - - this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training, - algorithm, - x->mem_desc(), - alpha, - beta); - } - - ActivationMKLDNNHandler(dnnl::algorithm algorithm, - const framework::ExecutionContext& ctx, - const dnnl::engine engine, - Place cpu_place, - const framework::Tensor* x, - const Tensor* dout) - : platform::MKLDNNHandlerNoCachingT(engine, - cpu_place) { - float alpha = ctx.HasAttr("alpha") ? ctx.Attr("alpha") : 0; - float beta = ctx.HasAttr("beta") ? ctx.Attr("beta") : 0; - - // paddle uses beta but mkldnn uses alpha for swish - if (algorithm == dnnl::algorithm::eltwise_swish) { - std::swap(alpha, beta); - } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) { - alpha = ctx.Attr("threshold"); - } - - if (ctx.Type() == "clip_grad") { - alpha = ctx.HasInput("Min") ? ctx.Input("Min")->data()[0] - : ctx.Attr("min"); - beta = ctx.HasInput("Max") ? ctx.Input("Max")->data()[0] - : ctx.Attr("max"); - } - - this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training, - algorithm, - x->mem_desc(), - alpha, - beta); - this->AcquireBackwardPrimitiveDescriptor( - algorithm, dout->mem_desc(), x->mem_desc(), alpha, beta); - } - - std::shared_ptr AcquireBackwardSrcMemory( - const framework::Tensor* input) { - const T* input_data = input->data(); - return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_desc(), - to_void_cast(input_data)); - } -}; - static std::unordered_map GetAttributeMap( std::string act_type) { std::unordered_map attr_map; diff --git a/paddle/phi/kernels/onednn/activation_grad_kernel.cc b/paddle/phi/kernels/onednn/activation_grad_kernel.cc index a8c988a69d3f94459219a62359786b325bd74a76..9e183abf0287fa595c6905f1a44665eef62ead4a 100644 --- a/paddle/phi/kernels/onednn/activation_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/activation_grad_kernel.cc @@ -166,6 +166,11 @@ template using ReluOneDNNGradFunctor = OneDNNActivationGradFunc; +template +using Relu6OneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< + T, + dnnl::algorithm::eltwise_clip_v2_use_dst_for_bwd>; + template using SigmoidOneDNNGradUseOutFunctor = OneDNNActivationGradUseOutFunc< T, @@ -241,6 +246,16 @@ void HardSwishGradKernel(const Context& dev_ctx, functor(dev_ctx, x, dout, threshold, 0, dx); } +template +void Relu6GradKernel(const Context& dev_ctx, + const DenseTensor& out, + const DenseTensor& dout, + float threshold, + DenseTensor* dx) { + Relu6OneDNNGradUseOutFunctor functor; + functor(dev_ctx, out, dout, 0, threshold, dx); +} + } // namespace phi PD_REGISTER_KERNEL(relu_grad, @@ -261,6 +276,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(gelu_grad, GeluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) diff --git a/paddle/phi/kernels/onednn/activation_kernel.cc b/paddle/phi/kernels/onednn/activation_kernel.cc index b90ea1c61b416f49240319505fbabdad564589f5..36ba1be724ccf103ff47f65a8873992289fe9f44 100644 --- a/paddle/phi/kernels/onednn/activation_kernel.cc +++ b/paddle/phi/kernels/onednn/activation_kernel.cc @@ -119,7 +119,7 @@ using ReluOneDNNFunctor = template using Relu6OneDNNFunctor = - OneDNNActivationFunc; + OneDNNActivationFunc; template using RoundOneDNNFunctor = @@ -154,7 +154,6 @@ DEFINE_ONEDNN_ACTIVATION_KERNEL(Round, RoundOneDNNFunctor) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Elu, EluOneDNNFunctor, alpha) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, ReluOneDNNFunctor, alpha) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishOneDNNFunctor, threshold) -DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Relu6, Relu6OneDNNFunctor, threshold) DEFINE_ONEDNN_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishOneDNNFunctor, beta) template @@ -182,6 +181,15 @@ void GeluKernel(const Context& dev_ctx, } } +template +void Relu6Kernel(const Context& dev_ctx, + const DenseTensor& x, + float threshold, + DenseTensor* out) { + Relu6OneDNNFunctor functor; + functor(dev_ctx, x, 0, threshold, out); +} + } // namespace phi PD_REGISTER_KERNEL(round, OneDNN, ALL_LAYOUT, phi::RoundKernel, float) {}