diff --git a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc index 02715cfd3f2f824010fac9e3e08a00b80f116ca2..2351e759f2d8ec42ff64090162e300126fbb0f0a 100644 --- a/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc @@ -105,13 +105,14 @@ template void eltwise_grad(const framework::ExecutionContext &ctx, mkldnn::algorithm algorithm) { auto &dev_ctx = ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto *x = ctx.Input("X"); const auto *diff_y = ctx.Input(framework::GradVarName("Out")); auto *diff_x = ctx.Output(framework::GradVarName("X")); platform::ActivationMKLDNNHandler handler( - algorithm, ctx, dev_ctx, ctx.GetPlace(), x, diff_y, ctx.InputName("X")); + algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x, diff_y); auto src_memory_p = handler.AcquireBackwardSrcMemory(x); auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y); diff --git a/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc index ae17048b5d568baf4722e63299c9ef2ca3fb6bae..2022c1ab910b4199ade41c91997de3b34cd30109 100644 --- a/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/scale_mkldnn_op.cc @@ -29,6 +29,7 @@ class ScaleMKLDNNKernel : public framework::OpKernel { void RunKernel(const framework::ExecutionContext& ctx) const { const auto& dev_ctx = ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); auto* x = ctx.Input("X"); auto* out = ctx.Output("Out"); @@ -36,11 +37,10 @@ class ScaleMKLDNNKernel : public framework::OpKernel { bool is_inplaced = x->IsSharedBufferWith(*out); platform::ActivationMKLDNNHandler handler( - mkldnn::algorithm::eltwise_linear, ctx, dev_ctx, ctx.GetPlace(), x, - ctx.InputName("X"), is_inplaced); + mkldnn::algorithm::eltwise_linear, ctx, mkldnn_engine, ctx.GetPlace(), x); auto src_memory_p = handler.AcquireSrcMemory(x); - auto dst_memory_p = handler.AcquireDstMemory(out); + auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(out); auto activation_p = handler.AcquireForwardPrimitive(); auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();