From 2b24a8013f9448159099bba47da8f0fd1cc69ed5 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Wed, 4 Aug 2021 17:22:06 +0200 Subject: [PATCH] - Removed manual caching of activation --- paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc | 3 ++- paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc index 02715cfd3f2..2351e759f2d 100644 --- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc @@ -105,13 +105,14 @@ template void eltwise_grad(const framework::ExecutionContext &ctx, mkldnn::algorithm algorithm) { auto &dev_ctx = ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto *x = ctx.Input("X"); const auto *diff_y = ctx.Input(framework::GradVarName("Out")); auto *diff_x = ctx.Output(framework::GradVarName("X")); platform::ActivationMKLDNNHandler handler( - algorithm, ctx, dev_ctx, ctx.GetPlace(), x, diff_y, ctx.InputName("X")); + algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x, diff_y); auto src_memory_p = handler.AcquireBackwardSrcMemory(x); auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y); diff --git a/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc index ae17048b5d5..2022c1ab910 100644 --- a/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc @@ -29,6 +29,7 @@ class ScaleMKLDNNKernel : public framework::OpKernel { void RunKernel(const framework::ExecutionContext& ctx) const { const auto& dev_ctx = ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); auto* x = ctx.Input("X"); auto* out = ctx.Output("Out"); @@ -36,11 +37,10 @@ class ScaleMKLDNNKernel : public framework::OpKernel { bool is_inplaced = x->IsSharedBufferWith(*out); platform::ActivationMKLDNNHandler handler( - mkldnn::algorithm::eltwise_linear, ctx, dev_ctx, ctx.GetPlace(), x, - ctx.InputName("X"), is_inplaced); + mkldnn::algorithm::eltwise_linear, ctx, mkldnn_engine, ctx.GetPlace(), x); auto src_memory_p = handler.AcquireSrcMemory(x); - auto dst_memory_p = handler.AcquireDstMemory(out); + auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(out); auto activation_p = handler.AcquireForwardPrimitive(); auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); -- GitLab