diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index 5cb6ae34dcecf0ab82f4b6b7c4ceae96deb59780..a8d4b852ca3c248d31b247b4c2f6d2aa09647e51 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -295,7 +295,7 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel { x_bd_dims[x_bd_dims.size() - 2] = x_dims[0]; } else { for (size_t i = 0; i < x_dims.size(); ++i) { - x_bd_dims[i] = x_dims[i]; + x_bd_dims[x_bd_dims.size() - x_dims.size() + i] = x_dims[i]; } } if (y_dims.size() == 1) { @@ -305,21 +305,21 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel { y_bd_dims[y_bd_dims.size() - 2] = y_dims[0]; } else { for (size_t i = 0; i < y_dims.size(); ++i) { - y_bd_dims[i] = y_dims[i]; + y_bd_dims[y_bd_dims.size() - y_dims.size() + i] = y_dims[i]; } } - if ((y_dims.size() == x_dims.size()) && y_dims.size() > 2 && - !IsOutputFused(ctx)) { - for (size_t i = 0; i < x_dims.size() - 2; ++i) { + if (!IsOutputFused(ctx) && x_dims.size() > 2 && y_dims.size() > 2) { + for (size_t i = 0; i < x_bd_dims.size() - 2; ++i) { PADDLE_ENFORCE_EQ( - x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true, - paddle::platform::errors::InvalidArgument( - "Tensor dimensions are incorrect for broadcasting." - "Dimensions in X and Y must be same or equal to 1, but " - "received x_dim[%d]=%d and y_dims[%d]= %d", - i, x_dims[i], i, y_dims[i])); - out_dims[i] = std::max(x_dims[i], y_dims[i]); + x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] == 1 || + y_bd_dims[i] == 1, + true, paddle::platform::errors::InvalidArgument( + "Tensor dimensions are incorrect for broadcasting." + "Dimensions in X and Y must be same or equal to 1, but " + "received x_dim[%d]=%d and y_dims[%d]= %d", + i, x_bd_dims[i], i, y_bd_dims[i])); + out_dims[i] = std::max(x_bd_dims[i], y_bd_dims[i]); } out->Resize(make_ddim(out_dims)); } @@ -382,11 +382,11 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { dy_tmp->mutable_data(ctx.GetPlace()); } - void ReduceSumForMatmulGradOutput(const ExecutionContext& ctx, - const MKLDNNDeviceContext& dev_ctx, - const dnnl::engine onednn_engine, - const Tensor* dx_tmp, Tensor* dx, - std::vector dx_dims) const { + void ReduceSumForMatmulGradOutput( + const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx, + const dnnl::engine onednn_engine, const Tensor* dx_tmp, Tensor* dx, + std::vector& dx_dims, + const std::vector& squeezed_dims) const { paddle::platform::ReductionMKLDNNHandler handler( dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine, ctx.GetPlace(), dx_tmp, dx, dx_dims); @@ -402,6 +402,19 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { reduction_p->execute(astream, reduction_args); astream.wait(); + + dx->set_format(paddle::platform::GetMKLDNNFormat( + dst_memory_p->get_desc().reshape(squeezed_dims))); + } + + std::vector ExtendDimsWithOnes(const std::vector& dims, + int new_size) const { + std::vector new_dims(new_size, 1); + for (size_t i = 0; i < dims.size(); ++i) { + new_dims[new_size - dims.size() + i] = dims[i]; + } + + return new_dims; } void RunKernel(const ExecutionContext& ctx) const { @@ -440,8 +453,14 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { bool trans_y = ctx.Attr("trans_y"); auto dout_dims = vectorize(dout->dims()); - int ndims = std::max(x->dims().size(), y->dims().size()); - ndims = std::max(ndims, 3); + size_t ndims = std::max(x->dims().size(), y->dims().size()); + ndims = std::max(ndims, 3); + + if (x_dims.size() != ndims) { + x_dims = ExtendDimsWithOnes(x_dims, ndims); + } else if (y_dims.size() != ndims) { + y_dims = ExtendDimsWithOnes(y_dims, ndims); + } // in broadcasting scenario new memory is required because // reduce sum must be calculated upon broadcasted dims @@ -481,21 +500,21 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel { if (x_dims != dx_bd_dims) { ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx, - x_dims); + x_dims, + paddle::framework::vectorize(x->dims())); } else { *dx = std::move(dx_tmp); } if (y_dims != dy_bd_dims) { ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy, - y_dims); + y_dims, + paddle::framework::vectorize(y->dims())); } else { *dy = std::move(dy_tmp); } - dx->set_layout(paddle::framework::DataLayout::kMKLDNN); - dx->set_format(x->format()); - dy->set_layout(paddle::framework::DataLayout::kMKLDNN); - dy->set_format(y->format()); + dx->Resize(x->dims()); + dy->Resize(y->dims()); } private: diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py index 5dd1795818c2bf11b7fbf7b89f3dc8bc626ccf5d..25701b797ec4a1cbdc540dad703c83101f2dd1ec 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py @@ -262,6 +262,41 @@ class TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): self.trans_y = False +class TestMatMulV2MatrixXMatrix4Dx3DTransposeXOneDNNOp( + TestMatMulV2VectorXVectorOneDNNOp): + def config(self): + self.x_shape = (5, 4, 15, 10) + self.y_shape = (1, 15, 20) + self.trans_x = True + self.trans_y = False + + +class TestMatMulV2MatrixXMatrix3Dx4DTransposeYOneDNNOp( + TestMatMulV2VectorXVectorOneDNNOp): + def config(self): + self.x_shape = (2, 10, 15) + self.y_shape = (4, 2, 20, 15) + self.trans_x = False + self.trans_y = True + + +class TestMatMulV2MatrixXMatrix5Dx3DTransposeXTransposeYOneDNNOp( + TestMatMulV2VectorXVectorOneDNNOp): + def config(self): + self.x_shape = (4, 3, 2, 15, 10) + self.y_shape = (1, 20, 15) + self.trans_x = True + self.trans_y = True + + +class TestMatMulV2MatrixXMatrix3Dx4DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): + def config(self): + self.x_shape = (1, 1, 32, 16) + self.y_shape = (16, 16, 16) + self.trans_x = False + self.trans_y = False + + # BF16 TESTS def create_bf16_test_class(parent): @OpTestTool.skip_if_not_cpu_bf16()