From 07790ba13eeeafa45da0b7aa2348db0042ffd7d7 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Sat, 19 Dec 2020 09:05:16 +0100 Subject: [PATCH] [oneDNN] Reimplemented elementwise_add grad (#29747) * - Reimplemented elementwise_add grad - lint * - fix after review * - Fix to fix after review --- .../mkldnn/elementwise_add_mkldnn_op.cc | 42 +++++++++++++------ paddle/fluid/platform/mkldnn_reuse.h | 7 ++-- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index 54902015ce1..db634813230 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -33,27 +33,45 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel { ElemwiseGradKernel::Compute(ctx); using Tensor = framework::Tensor; + auto& dev_ctx = + ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); + auto* dout = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); - auto set_mkldnn_format = [](Tensor* in, const Tensor* out) { - in->set_layout(DataLayout::kMKLDNN); - in->set_format(out->format()); - }; + auto tz = paddle::framework::vectorize(dout->dims()); + memory::data_type dout_type = framework::ToMKLDNNDataType(dout->type()); + std::string key = platform::CreateKey(dev_ctx, tz, dout->format(), + dout->format(), dout_type); + platform::ReorderMKLDNNHandler handler(tz, dout->type(), dout_type, dev_ctx, + onednn_engine, key); + + mkldnn::stream astream(onednn_engine); + auto reorder_src_memory_p = handler.AcquireSrcMemory( + dout->format(), platform::to_void_cast(dout->data())); - // TODO(jczaja): Double check if vcopy works for blocked data - auto blas = math::GetBlas(ctx); if (dx) { - blas.VCOPY(dout->numel(), dout->data(), 
- dx->mutable_data(ctx.GetPlace())); - set_mkldnn_format(dx, dout); + auto reorder_dst_memory_p = + handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace()); + auto reorder_p = + handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); + platform::RecordEvent record_reorder("int_reorder", + platform::EventRole::kUniqueOp); + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); + astream.wait(); } if (dy) { - blas.VCOPY(dout->numel(), dout->data(), - dy->mutable_data(ctx.GetPlace())); - set_mkldnn_format(dy, dout); + auto reorder_dst_memory_p = + handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace()); + auto reorder_p = + handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); + platform::RecordEvent record_reorder("int_reorder", + platform::EventRole::kUniqueOp); + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); + astream.wait(); } } }; diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index c053815aea7..58a8f6263ff 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -1054,13 +1054,14 @@ class ReorderMKLDNNHandler : public MKLDNNHandler { std::static_pointer_cast(dev_ctx_.GetBlob(local_key)); if (mem_p == nullptr) { auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_, fmt); - - auto dst_data = output->mutable_data(place, vtype_); + auto dst_data = output->mutable_data(place, vtype_, dst_md.get_size()); mem_p = std::make_shared(dst_md, engine_, dst_data); dev_ctx_.SetBlob(local_key, mem_p); } else { - auto dst_data = output->mutable_data(place, vtype_); + // Even if the memory object exists, we may be using it for a different tensor + auto dst_data = + output->mutable_data(place, vtype_, mem_p->get_desc().get_size()); mem_p->set_data_handle(dst_data); } return mem_p; -- GitLab