diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index 57a3c385593160d57d70516f4f6ab1243038b3ac..c332b9194164ea3be52cf793febf90f7aea679c6 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -148,8 +148,8 @@ class MatMulV2MKLDNNKernel if (x_dims.size() == 1) { x_bd_dims[x_bd_dims.size() - 1] = x_dims[0]; } else if (x_dims.size() == 2) { - x_bd_dims[2] = x_dims[1]; - x_bd_dims[1] = x_dims[0]; + x_bd_dims[x_bd_dims.size() - 1] = x_dims[1]; + x_bd_dims[x_bd_dims.size() - 2] = x_dims[0]; } else { for (size_t i = 0; i < x_dims.size(); ++i) { x_bd_dims[i] = x_dims[i]; @@ -158,8 +158,8 @@ class MatMulV2MKLDNNKernel if (y_dims.size() == 1) { y_bd_dims[x_bd_dims.size() - 2] = y_dims[0]; } else if (y_dims.size() == 2) { - y_bd_dims[2] = y_dims[1]; - y_bd_dims[1] = y_dims[0]; + y_bd_dims[y_bd_dims.size() - 1] = y_dims[1]; + y_bd_dims[y_bd_dims.size() - 2] = y_dims[0]; } else { for (size_t i = 0; i < y_dims.size(); ++i) { y_bd_dims[i] = y_dims[i]; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py index 5cc6651bb0ec8e4a25a6c017e9e7e44e2a294b04..994d78126bda5852a07cd04cbde82585ea739631 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py @@ -235,6 +235,22 @@ class TestMatMulV2MatrixXMatrix5DTranposeYOneDNNOp( self.trans_y = True +class TestMatMulV2MatrixXMatrix6Dx2DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): + def config(self): + self.x_shape = (1, 1, 2, 1, 8, 9) + self.y_shape = (9, 12) + self.trans_x = False + self.trans_y = False + + +class TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): + def config(self): + self.x_shape = (20, 5) + self.y_shape = (1, 2, 1, 5, 11) + self.trans_x = False + self.trans_y = False + + # BF16 TESTS def create_bf16_test_class(parent): @OpTestTool.skip_if_not_cpu_bf16() @@ -274,7 +290,8 @@ def create_bf16_test_class(parent): 2: [1, 0], 3: [0, 2, 1], 4: [0, 1, 3, 2], - 5: [0, 1, 2, 4, 3] + 5: [0, 1, 2, 4, 3], + 6: [0, 1, 2, 3, 5, 4] } # expand vector so it will be a valid matrix for multiplication @@ -370,6 +387,8 @@ create_bf16_test_class(TestMatMulV2Matrix3DXVectorOneDNNOp) create_bf16_test_class(TestMatMulV2MatrixXMatrixTransposeXTransposeYOneDNNOp) create_bf16_test_class(TestMatMulV2MatrixXMatrixTransposeY2OneDNNOp) create_bf16_test_class(TestMatMulV2MatrixXMatrix5DTranposeYOneDNNOp) +create_bf16_test_class(TestMatMulV2MatrixXMatrix6Dx2DOneDNNOp) +create_bf16_test_class(TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp) if __name__ == "__main__": paddle.enable_static()