From 1e8432f24b846f3bfce8490cff1374aff6e96bfc Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 6 Jan 2022 02:57:57 +0100 Subject: [PATCH] [CHERRY-PICK] Fix for matmul_v2 oneDNN op broadcasting when inputs dims have different lengths (#38733) * fix for matmul_v2 broadcasting * resolved conflicts --- .../operators/mkldnn/matmul_v2_mkldnn_op.cc | 68 ++++++++++++------- .../mkldnn/test_matmul_v2_mkldnn_op.py | 35 ++++++++++ 2 files changed, 79 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index c332b919416..ba3ce00547a 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -152,7 +152,7 @@ class MatMulV2MKLDNNKernel x_bd_dims[x_bd_dims.size() - 2] = x_dims[0]; } else { for (size_t i = 0; i < x_dims.size(); ++i) { - x_bd_dims[i] = x_dims[i]; + x_bd_dims[x_bd_dims.size() - x_dims.size() + i] = x_dims[i]; } } if (y_dims.size() == 1) { @@ -162,20 +162,21 @@ class MatMulV2MKLDNNKernel y_bd_dims[y_bd_dims.size() - 2] = y_dims[0]; } else { for (size_t i = 0; i < y_dims.size(); ++i) { - y_bd_dims[i] = y_dims[i]; + y_bd_dims[y_bd_dims.size() - y_dims.size() + i] = y_dims[i]; } } - if ((y_dims.size() == x_dims.size()) && y_dims.size() > 2) { - for (size_t i = 0; i < x_dims.size() - 2; ++i) { + if (x_dims.size() > 2 && y_dims.size() > 2) { + for (size_t i = 0; i < x_bd_dims.size() - 2; ++i) { PADDLE_ENFORCE_EQ( - x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true, - paddle::platform::errors::InvalidArgument( - "Tensor dimensions are incorrect for broadcasting." - "Dimensions in X and Y must be same or equal to 1, but " - "received x_dim[%d]=%d and y_dims[%d]= %d", - i, x_dims[i], i, y_dims[i])); - out_dims[i] = std::max(x_dims[i], y_dims[i]); + x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] == 1 || + y_bd_dims[i] == 1, + true, paddle::platform::errors::InvalidArgument( + "Tensor dimensions are incorrect for broadcasting." + "Dimensions in X and Y must be same or equal to 1, but " + "received x_dim[%d]=%d and y_dims[%d]= %d", + i, x_bd_dims[i], i, y_bd_dims[i])); + out_dims[i] = std::max(x_bd_dims[i], y_bd_dims[i]); } out->Resize(make_ddim(out_dims)); } @@ -237,11 +238,11 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel { dy_tmp->mutable_data(ctx.GetPlace()); } - void ReduceSumForMatmulGradOutput(const ExecutionContext& ctx, - const MKLDNNDeviceContext& dev_ctx, - const mkldnn::engine onednn_engine, - const Tensor* dx_tmp, Tensor* dx, - std::vector dx_dims) const { + void ReduceSumForMatmulGradOutput( + const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx, + const dnnl::engine onednn_engine, const Tensor* dx_tmp, Tensor* dx, + std::vector& dx_dims, + const std::vector& squeezed_dims) const { paddle::platform::ReductionMKLDNNHandler handler( dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine, ctx.GetPlace(), dx_tmp, dx, dx_dims); @@ -257,6 +258,19 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel { reduction_p->execute(astream, reduction_args); astream.wait(); + + dx->set_format(paddle::platform::GetMKLDNNFormat( + dst_memory_p->get_desc().reshape(squeezed_dims))); + } + + std::vector ExtendDimsWithOnes(const std::vector& dims, + int new_size) const { + std::vector new_dims(new_size, 1); + for (size_t i = 0; i < dims.size(); ++i) { + new_dims[new_size - dims.size() + i] = dims[i]; + } + + return new_dims; } void RunKernel(const ExecutionContext& ctx) const { @@ -295,8 +309,14 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel { bool trans_y = ctx.Attr("trans_y"); auto dout_dims = vectorize(dout->dims()); - int ndims = std::max(x->dims().size(), y->dims().size()); - ndims = std::max(ndims, 3); + size_t ndims = std::max(x->dims().size(), y->dims().size()); + ndims = std::max(ndims, 3); + + if (x_dims.size() != ndims) { + x_dims = ExtendDimsWithOnes(x_dims, ndims); + } else if (y_dims.size() != ndims) { + y_dims = ExtendDimsWithOnes(y_dims, ndims); + } // in broadcasting scenario new memory is required because // reduce sum must be calculated upon broadcasted dims @@ -340,21 +360,21 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel { if (x_dims != dx_bd_dims) { ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx, - x_dims); + x_dims, + paddle::framework::vectorize(x->dims())); } else { *dx = std::move(dx_tmp); } if (y_dims != dy_bd_dims) { ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy, - y_dims); + y_dims, + paddle::framework::vectorize(y->dims())); } else { *dy = std::move(dy_tmp); } - dx->set_layout(paddle::framework::DataLayout::kMKLDNN); - dx->set_format(x->format()); - dy->set_layout(paddle::framework::DataLayout::kMKLDNN); - dy->set_format(y->format()); + dx->Resize(x->dims()); + dy->Resize(y->dims()); } }; } // anonymous namespace diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py index 994d78126bd..2fe28c934b1 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py @@ -251,6 +251,41 @@ class TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): self.trans_y = False +class TestMatMulV2MatrixXMatrix4Dx3DTransposeXOneDNNOp( + TestMatMulV2VectorXVectorOneDNNOp): + def config(self): + self.x_shape = (5, 4, 15, 10) + self.y_shape = (1, 15, 20) + self.trans_x = True + self.trans_y = False + + +class TestMatMulV2MatrixXMatrix3Dx4DTransposeYOneDNNOp( + TestMatMulV2VectorXVectorOneDNNOp): + def config(self): + self.x_shape = (2, 10, 15) + self.y_shape = (4, 2, 20, 15) + self.trans_x = False + self.trans_y = True + + +class TestMatMulV2MatrixXMatrix5Dx3DTransposeXTransposeYOneDNNOp( + TestMatMulV2VectorXVectorOneDNNOp): + def config(self): + self.x_shape = (4, 3, 2, 15, 10) + self.y_shape = (1, 20, 15) + self.trans_x = True + self.trans_y = True + + +class TestMatMulV2MatrixXMatrix3Dx4DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): + def config(self): + self.x_shape = (1, 1, 32, 16) + self.y_shape = (16, 16, 16) + self.trans_x = False + self.trans_y = False + + # BF16 TESTS def create_bf16_test_class(parent): @OpTestTool.skip_if_not_cpu_bf16() -- GitLab