diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h index a5f287fb78d20d2adffa3536b3b04347809e81ad..e1b829b03a4d49f6e8efe0509ab93035b466895c 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h @@ -78,6 +78,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel { scale_x, scale_y, scale_o, + true, get_post_ops(ctx)); // oneDNN's binary is optimized for broadcasting y into x, so in other case @@ -126,7 +127,14 @@ class EltwiseMKLDNNKernel : public framework::OpKernel { binary_prim->execute(astream, args); astream.wait(); - z->set_mem_desc(dst_memory->get_desc()); + if (handler.use_broadcasting_hack == false) { + z->set_mem_desc(dst_memory->get_desc()); + } else { + auto dims = dst_memory->get_desc().dims(); + dims.insert(dims.begin(), x->dims()[0]); + dims[1] /= dims[0]; + z->set_mem_desc(dst_memory->get_desc().reshape(dims)); + } } }; @@ -210,7 +218,8 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel { dx, 1.0f, 1.0f, - 1.0f); + 1.0f, + false); const auto src_dout_memory = binary_handler.AcquireSrcMemory(dout); const auto src_y_memory = binary_handler.AcquireSecondSrcMemory(y); @@ -276,7 +285,8 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel { nullptr, 1.0f, 1.0f, - 1.0f); + 1.0f, + false); src_1_memory = binary_handler.AcquireSecondSrcMemory(x); @@ -291,7 +301,8 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel { nullptr, 1.0f, 1.0f, - 1.0f); + 1.0f, + false); post_op_memory = post_op_binary_handler.AcquireSrcMemory(y); @@ -310,6 +321,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel { -1.0f, 1.0f, 1.0f, + false, po); src_1_memory = binary_handler.AcquireSecondSrcMemory(out); diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h index 4775a694e1f1af6857de3b82446eecccc0c3e2e2..cd8c076b28503c2dd76493fd8913c67d86816a58 100644 --- a/paddle/phi/backends/onednn/onednn_reuse.h +++ b/paddle/phi/backends/onednn/onednn_reuse.h @@ -825,6 +825,7 @@ class ReorderOneDNNHandler { template class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT { public: + bool use_broadcasting_hack; BinaryOneDNNHandler(const dnnl::algorithm algo, const int axis, const dnnl::engine engine, @@ -835,15 +836,17 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT { float scale_x, float scale_y, float scale_out, + bool allow_hack, const dnnl::post_ops& post_ops = dnnl::post_ops{}) : OneDNNHandlerNoCachingT(engine, cpu_place) { + use_broadcasting_hack = false; const auto src_x_tz = vectorize(x->dims()); const auto src_y_tz = vectorize(y->dims()); // if output tensor(z) is nullptr then we are computing into oneDNN // managed buffer auto rankdiff = x->dims().size() - y->dims().size(); - const auto dst_tz = (out == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz) - : vectorize(out->dims()); + auto dst_tz = (out == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz) + : vectorize(out->dims()); auto src0_md = x->mem_desc(); auto src1_md = y->mem_desc(); @@ -870,12 +873,48 @@ class BinaryOneDNNHandler : public OneDNNHandlerNoCachingT { } src0_md = src0_md.reshape(dims0_ex); } - const auto dst_md = - memory::desc(dst_tz, OneDNNGetDataType(), OneDNNMemoryFormat::any); auto attributes = CreateAttributes(algo, scale_x, scale_y, scale_out, post_ops); + // Workaround for U2++ model which deletes first tensor dimensions to enable + // optimized oneDNNs broadcasting. Output tensor is reshaped back afterwards + // at the end of the kernel, after the computation + if (allow_hack && dst_tz.size() == 4 && + src0_md.dims()[2] != src1_md.dims()[2]) { + auto are_strides_plain = [](int64_t* strides, int ndims) { + for (int i = 0; i < ndims - 1; ++i) { + if (strides[i] < strides[i + 1]) { + return false; + } + } + return true; + }; + + auto src0_strides = src0_md.data.format_desc.blocking.strides; + auto src1_strides = src1_md.data.format_desc.blocking.strides; + auto src0_dims = src0_md.dims(); + auto src1_dims = src1_md.dims(); + + bool can_squeeze = src0_dims[0] == src1_dims[0] && + src0_dims[1] == src1_dims[1] && + src0_dims[3] == src1_dims[3]; + + if (can_squeeze && are_strides_plain(src0_strides, 4) && + are_strides_plain(src1_strides, 4)) { + src0_dims[1] *= dst_tz[0]; + src1_dims[1] *= dst_tz[0]; + dst_tz[1] *= dst_tz[0]; + dst_tz.erase(dst_tz.begin()); + src0_md = src0_md.reshape({src0_dims.begin() + 1, src0_dims.end()}); + src1_md = src1_md.reshape({src1_dims.begin() + 1, src1_dims.end()}); + use_broadcasting_hack = true; + } + } + + auto dst_md = + memory::desc(dst_tz, OneDNNGetDataType(), OneDNNMemoryFormat::any); + if (x->numel() < y->numel()) { if (algo == dnnl::algorithm::binary_sub) { attributes = CreateAttributes(