From e80acff3a554a4d2394ff73513eca7fa772f95d2 Mon Sep 17 00:00:00 2001 From: jakpiase <62569058+jakpiase@users.noreply.github.com> Date: Wed, 15 Sep 2021 14:05:12 +0200 Subject: [PATCH] added fix for matmul and support for 6 rank tensor (#35740) --- paddle/fluid/operators/matmul_op.cc | 18 ++++++++++++++++++ paddle/fluid/platform/mkldnn_helper.h | 2 ++ 2 files changed, 20 insertions(+) diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index c0d813ccc21..4e435660ff6 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -655,6 +655,24 @@ class MatMulOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument("reshape_out supported rank is 3, " "received %d", reshape_out_size)); + + auto it = std::find(reshape_out.begin(), reshape_out.end(), -1); + + // if "-1" is present then one of reshape dims must be infered + if (it != reshape_out.end()) { + int index = std::distance(reshape_out.begin(), it); + + auto ddim_out_vec = framework::vectorize(ddim_out); + + int ddim_out_product = + std::accumulate(ddim_out_vec.begin(), ddim_out_vec.end(), 1, + std::multiplies()); + int reshape_out_product = std::accumulate( + reshape_out.begin(), reshape_out.end(), -1, std::multiplies()); + + reshape_out[index] = ddim_out_product / reshape_out_product; + } + framework::DDim shape_out = ddim_out.transpose(transpose_out).reshape(reshape_out); context->SetOutputDim("Out", shape_out); diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 01c2d95a078..f14f92cb51f 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -358,6 +358,8 @@ inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size, } else if (data_format == MKLDNNMemoryFormat::nhwc) { return MKLDNNMemoryFormat::ndhwc; } + } else if (dims_size == 6) { + return MKLDNNMemoryFormat::abcdef; } return data_format; } -- GitLab