From ce6a27d91202b1812332f1d816d94943e91abb52 Mon Sep 17 00:00:00 2001 From: jakpiase <62569058+jakpiase@users.noreply.github.com> Date: Wed, 13 Oct 2021 07:01:15 +0200 Subject: [PATCH] fix for matmul_v2 6D x 2D (#36379) --- .../operators/mkldnn/matmul_v2_mkldnn_op.cc | 8 +++---- .../mkldnn/test_matmul_v2_mkldnn_op.py | 21 ++++++++++++++++++- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index 57a3c385593..c332b919416 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -148,8 +148,8 @@ class MatMulV2MKLDNNKernel if (x_dims.size() == 1) { x_bd_dims[x_bd_dims.size() - 1] = x_dims[0]; } else if (x_dims.size() == 2) { - x_bd_dims[2] = x_dims[1]; - x_bd_dims[1] = x_dims[0]; + x_bd_dims[x_bd_dims.size() - 1] = x_dims[1]; + x_bd_dims[x_bd_dims.size() - 2] = x_dims[0]; } else { for (size_t i = 0; i < x_dims.size(); ++i) { x_bd_dims[i] = x_dims[i]; @@ -158,8 +158,8 @@ class MatMulV2MKLDNNKernel if (y_dims.size() == 1) { y_bd_dims[x_bd_dims.size() - 2] = y_dims[0]; } else if (y_dims.size() == 2) { - y_bd_dims[2] = y_dims[1]; - y_bd_dims[1] = y_dims[0]; + y_bd_dims[y_bd_dims.size() - 1] = y_dims[1]; + y_bd_dims[y_bd_dims.size() - 2] = y_dims[0]; } else { for (size_t i = 0; i < y_dims.size(); ++i) { y_bd_dims[i] = y_dims[i]; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py index 5cc6651bb0e..994d78126bd 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py @@ -235,6 +235,22 @@ class TestMatMulV2MatrixXMatrix5DTranposeYOneDNNOp( self.trans_y = True +class TestMatMulV2MatrixXMatrix6Dx2DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): + def config(self): + self.x_shape = (1, 1, 2, 1, 8, 9) + self.y_shape = (9, 12) + self.trans_x = False + self.trans_y = False + + +class TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): + def config(self): + self.x_shape = (20, 5) + self.y_shape = (1, 2, 1, 5, 11) + self.trans_x = False + self.trans_y = False + + # BF16 TESTS def create_bf16_test_class(parent): @OpTestTool.skip_if_not_cpu_bf16() @@ -274,7 +290,8 @@ def create_bf16_test_class(parent): 2: [1, 0], 3: [0, 2, 1], 4: [0, 1, 3, 2], - 5: [0, 1, 2, 4, 3] + 5: [0, 1, 2, 4, 3], + 6: [0, 1, 2, 3, 5, 4] } # expand vector so it will be a valid matrix for multiplication @@ -370,6 +387,8 @@ create_bf16_test_class(TestMatMulV2Matrix3DXVectorOneDNNOp) create_bf16_test_class(TestMatMulV2MatrixXMatrixTransposeXTransposeYOneDNNOp) create_bf16_test_class(TestMatMulV2MatrixXMatrixTransposeY2OneDNNOp) create_bf16_test_class(TestMatMulV2MatrixXMatrix5DTranposeYOneDNNOp) +create_bf16_test_class(TestMatMulV2MatrixXMatrix6Dx2DOneDNNOp) +create_bf16_test_class(TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp) if __name__ == "__main__": paddle.enable_static() -- GitLab