From 39a5424ed129e90283a00e29481cdb4983b6e334 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Tue, 9 Mar 2021 03:58:23 +0100 Subject: [PATCH] [oneDNN] elementwise add bf16 grad kernel with broadcasting (#31385) --- .../operators/elementwise/elementwise_op.h | 5 ++- .../mkldnn/elementwise_add_mkldnn_op.cc | 31 +++++++++---- paddle/fluid/platform/mkldnn_reuse.h | 44 +++++++++++++++++++ .../test_elementwise_add_bf16_mkldnn_op.py | 41 +++++++++++++++-- .../mkldnn/test_elementwise_add_mkldnn_op.py | 12 +++-- .../unittests/mkldnn/test_reshape_bf16_op.py | 7 +-- .../paddle/fluid/tests/unittests/op_test.py | 10 +++++ 7 files changed, 131 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index a09fe4b676..6ec73b02ad 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -277,7 +277,10 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN // If broadcasting is needed, use native implementation auto CanMKLDNNElementwiseAddGradBeUsed = [&]() { - return (ctx.Input("X")->dims() == ctx.Input("Y")->dims()); + auto dx_dims = ctx.Input("X")->dims(); + auto dy_dims = ctx.Input("Y")->dims(); + // No broadcast or broadcasting of data on inner dims is supported + return (dx_dims[dx_dims.size() - 1] == dy_dims[dy_dims.size() - 1]); }; if (this->CanMKLDNNBeUsed(ctx, input_data_type) && diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index 13acd3fa63..4db4adfe9e 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -64,14 +64,29 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel { } if (dy) { - auto reorder_dst_memory_p = - handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace()); - auto reorder_p = - handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); - platform::RecordEvent record_reorder("int_reorder", - platform::EventRole::kUniqueOp); - reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); - astream.wait(); + // Direct copy + if (dout->dims() == dy->dims()) { + auto reorder_dst_memory_p = + handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace()); + auto reorder_p = + handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); + platform::RecordEvent record_reorder("int_reorder", + platform::EventRole::kUniqueOp); + reorder_p->execute(astream, *reorder_src_memory_p, + *reorder_dst_memory_p); + astream.wait(); + } else { + // Broadcasting + platform::ReductionMKLDNNHandler handler_sum( + dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine, + ctx.GetPlace(), dout, dy, + ctx.InputName(framework::GradVarName("Out"))); + auto dy_memory_p = handler_sum.AcquireDstMemory(dy); + auto reduction_p = handler_sum.AcquireForwardPrimitive(); + reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p}, + {DNNL_ARG_DST, *dy_memory_p}}); + astream.wait(); + } } } }; diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 3e02a8672c..0503c3f71a 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include #include #include @@ -621,6 +622,49 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT { } }; +template +class ReductionMKLDNNHandler + : public platform::MKLDNNHandlerT { + public: + ReductionMKLDNNHandler(const dnnl::algorithm algo, const float p, + const float eps, const MKLDNNDeviceContext& dev_ctx, + const mkldnn::engine engine, platform::Place cpu_place, + const Tensor* x, const Tensor* y, + const std::string& uniq_name) + : platform::MKLDNNHandlerT( + dev_ctx, engine, cpu_place, + platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), + uniq_name, + (std::to_string(static_cast(algo))))) { + if (!this->isCached()) { + PADDLE_ENFORCE_EQ( + x->layout(), DataLayout::kMKLDNN, + platform::errors::InvalidArgument("Wrong layout set for X tensor.")); + PADDLE_ENFORCE_NE( + x->format(), MKLDNNMemoryFormat::undef, + platform::errors::InvalidArgument("Wrong format set for X tensor.")); + + const auto src_tz = framework::vectorize(x->dims()); + const auto dst_tz = framework::vectorize(y->dims()); + + // For oneDNN dimensionality should match so we need to + // extend Y tensor dims with values of 1 (before and after pattern) + int j = 0; + std::vector dst_tz_ex(src_tz.size(), 1); + for (size_t i = 0; i < src_tz.size(); ++i) { + dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++]; + } + + const auto src_md = dnnl::memory::desc( + src_tz, platform::MKLDNNGetDataType(), x->format()); + const auto dst_md = memory::desc( + dst_tz_ex, platform::MKLDNNGetDataType(), x->format()); + + this->AcquireForwardPrimitiveDescriptor(algo, src_md, dst_md, p, eps); + } + } +}; + template class ActivationMKLDNNHandler : public MKLDNNHandlerT