diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
index 44edc224795706621050fb5b2c22037db86fc9bd..8630515a9fdafb55164724f1eea782db3df48e7c 100644
--- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
@@ -83,9 +83,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   const auto &mkldnn_engine = dev_ctx.GetEngine();
 
   const auto *x = ctx.Input<Tensor>("X");
-  auto *y = ctx.Output<Tensor>("Out");
+  auto *out = ctx.Output<Tensor>("Out");
 
-  bool is_inplaced = x->IsSharedBufferWith(*y);
+  bool is_inplaced = x->IsSharedBufferWith(*out);
 
   platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
                                                ctx.GetPlace(), x);
@@ -94,9 +94,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
   if (is_inplaced) {
     dst_memory_p = src_memory_p;
-    y->mutable_data<T>(ctx.GetPlace());
+    out->mutable_data<T>(ctx.GetPlace());
   } else {
-    dst_memory_p = handler.AcquireDstMemory(y);
+    dst_memory_p = handler.AcquireDstMemory(out);
   }
   auto activation_p = handler.AcquireForwardPrimitive();
 
@@ -105,8 +105,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
       astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
   astream.wait();
 
-  y->set_layout(DataLayout::kMKLDNN);
-  y->set_format(GetMKLDNNFormat(*dst_memory_p));
+  out->set_layout(DataLayout::kMKLDNN);
+  out->set_format(GetMKLDNNFormat(*dst_memory_p));
 }
 
 template <typename T>
@@ -116,15 +116,15 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   const auto &mkldnn_engine = dev_ctx.GetEngine();
 
   const auto *x = ctx.Input<Tensor>("X");
-  const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
-  auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+  const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+  auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
 
   platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
-                                               ctx.GetPlace(), x, diff_y);
+                                               ctx.GetPlace(), x, dout);
 
   auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
-  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
-  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
+  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
+  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
   auto activation_backward_p = handler.AcquireBackwardPrimitive();
 
   auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
@@ -134,8 +134,37 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
                                   {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
   astream.wait();
 
-  diff_x->set_layout(DataLayout::kMKLDNN);
-  diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
+  dx->set_layout(DataLayout::kMKLDNN);
+  dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
+}
+
+template <typename T>
+void eltwise_grad_use_out(const framework::ExecutionContext &ctx,
+                          dnnl::algorithm algorithm) {
+  auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+  const auto &mkldnn_engine = dev_ctx.GetEngine();
+
+  const auto *out = ctx.Input<Tensor>("Out");
+  const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+  auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+  platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
+                                               ctx.GetPlace(), out, dout);
+
+  auto dst_memory_p = handler.AcquireBackwardSrcMemory(out);
+  auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
+  auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
+  auto activation_backward_p = handler.AcquireBackwardPrimitive();
+
+  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
+  activation_backward_p->execute(astream,
+                                 {{DNNL_ARG_DST, *dst_memory_p},
+                                  {DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
+                                  {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
+  astream.wait();
+
+  dx->set_layout(DataLayout::kMKLDNN);
+  dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
 }
 
 template <typename T, dnnl::algorithm algorithm>
@@ -152,6 +181,13 @@ struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T, dnnl::algorithm algorithm>
+struct MKLDNNActivationGradUseOutFunc : public BaseActivationFunctor<T> {
+  void operator()(const framework::ExecutionContext &ctx) const {
+    eltwise_grad_use_out<T>(ctx, algorithm);
+  }
+};
+
 template <typename T>
 struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
   void operator()(const framework::ExecutionContext &ctx) const {
@@ -217,6 +253,9 @@ using AbsMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_abs>;
 template <typename T>
 using EluMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;
 
+template <typename T>
+using ExpMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;
+
 template <typename T>
 using ReluMKLDNNGradFunctor =
     MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
@@ -234,24 +273,29 @@ using HardSwishMKLDNNGradFunctor =
     MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_hardswish>;
 
 template <typename T>
-using SigmoidMKLDNNGradFunctor =
-    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_logistic>;
+using SigmoidMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;
 
 template <typename T>
-using TanhMKLDNNGradFunctor =
-    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_tanh>;
+using TanhMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;
 
 template <typename T>
-using SqrtMKLDNNGradFunctor =
-    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_sqrt>;
+using SqrtMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>;
 
 template <typename T>
 using AbsMKLDNNGradFunctor =
     MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_abs>;
 
 template <typename T>
-using EluMKLDNNGradFunctor =
-    MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_elu>;
+using EluMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;
+
+template <typename T>
+using ExpMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+    T, dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;
+
 
 }  // namespace operators
 }  // namespace paddle
@@ -281,9 +325,10 @@ namespace ops = paddle::operators;
   __macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor);           \
   __macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor);              \
   __macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
-  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor);                 \
+  __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradUseOutFunctor);           \
   __macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);                    \
-  __macro(elu, EluMKLDNNFunctor, EluMKLDNNGradFunctor);
+  __macro(elu, EluMKLDNNFunctor, EluMKLDNNGradUseOutFunctor);              \
+  __macro(exp, ExpMKLDNNFunctor, ExpMKLDNNGradUseOutFunctor);
 
 FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
 REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
@@ -291,9 +336,9 @@ REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
 REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(gelu, GeluMKLDNNFunctor,
                                        GeluMKLDNNGradFunctor);
 REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sigmoid, SigmoidMKLDNNFunctor,
-                                       SigmoidMKLDNNGradFunctor);
+                                       SigmoidMKLDNNGradUseOutFunctor);
 REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sqrt, SqrtMKLDNNFunctor,
-                                       SqrtMKLDNNGradFunctor);
+                                       SqrtMKLDNNGradUseOutFunctor);
 
 namespace ops = paddle::operators;
 REGISTER_OP_KERNEL(
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py
index 6ee266a93d56a280cf9efbbf3572885651d1962e..8af2101346fec258e58680edc84bd0f2871e0d31 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py
@@ -349,6 +349,16 @@ class TestMKLDNNEluCustomAlpha(TestMKLDNNEluDefaultAlpha):
         self.alpha = 2.5
 
 
+class TestMKLDNNExpOp(TestActivation):
+    def setUp(self):
+        self.op_type = "exp"
+        x = np.random.random((5, 5, 4)).astype("float32")
+
+        self.inputs = {'X': x}
+        self.attrs = {'use_mkldnn': True}
+        self.outputs = {'Out': np.exp(x)}
+
+
 # Check if primitives already exist in backward
 class TestMKLDNNAbsPrimitivesAlreadyExist(unittest.TestCase):
     def setUp(self):
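
Note (reviewer sketch, not part of the patch): the new `*GradUseOutFunctor` aliases bind the oneDNN `*_use_dst_for_bwd` algorithms, which compute the gradient from the forward result `Out` instead of the input `X`. For the newly added `exp` kernel this works because d/dx exp(x) = exp(x) = Out, and analogous identities hold for sigmoid, tanh, sqrt and elu. A minimal NumPy sketch of that identity, mirroring the shape used in `TestMKLDNNExpOp` (variable names here are illustrative only):

import numpy as np

# Same shape/dtype as the new TestMKLDNNExpOp unit test.
x = np.random.random((5, 5, 4)).astype("float32")
dout = np.ones_like(x)  # upstream gradient d(loss)/d(Out)

out = np.exp(x)

# Gradient computed from the input (what a plain src-based backward needs) ...
dx_from_x = dout * np.exp(x)
# ... and from the forward output (what eltwise_exp_use_dst_for_bwd consumes).
dx_from_out = dout * out

assert np.allclose(dx_from_x, dx_from_out)

# The same trick applies to the other switched functors:
#   sigmoid': out * (1 - out)
#   tanh':    1 - out**2
#   sqrt':    0.5 / out
#   elu':     1 for x > 0, out + alpha for x <= 0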