diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 6f188636ef101f473e9112bde9bf17e1bd0e2515..bb2a0766757b2764245a3682b2a6f2cf326de72b 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -168,6 +168,9 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, T alpha, const T *A, const T *B, T beta, T *C) const { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; #ifdef PADDLE_WITH_LIBXSMM if (M * N * K < 128 * 128 * 128 && transA == CblasNoTrans && transB == CblasNoTrans) { @@ -175,16 +178,10 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, // Note: SMM use ColMajor const char transa = 'N'; const char transb = 'N'; - const int lda = M; - const int ldb = K; - const int ldc = M; CBlas::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda, &beta, C, &ldc); } else { #endif - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); #ifdef PADDLE_WITH_LIBXSMM diff --git a/paddle/fluid/operators/math/math_function_test.cc b/paddle/fluid/operators/math/math_function_test.cc index 71103be492939c7a434f44f165087206a2b2972a..078dd448c385dbb8a00025ee2ba08d0c41a4730a 100644 --- a/paddle/fluid/operators/math/math_function_test.cc +++ b/paddle/fluid/operators/math/math_function_test.cc @@ -75,26 +75,25 @@ void MklSmmCompare(int m, int n, int k) { for (int i = 0; i < mat_b.numel(); ++i) { B[i] = static_cast(i); } + // lda,ldb,ldc follow RowMajor + int lda = k; + int ldb = n; + int ldc = n; - auto smm = [&, m, n, k, alpha, beta]() { + auto smm = [&, m, n, k, lda, ldb, ldc, alpha, beta]() { const char transa = 'N'; const char transb = 'N'; - const int lda = m; - const int ldb = k; - const int ldc = m; paddle::operators::math::CBlas::SMM_GEMM(&transa, &transb, &n, &m, &k, &alpha, B, &ldb, A, &lda, &beta, CSMM, &ldc); }; - auto mkl = [&, m, n, k, alpha, beta]() { - int lda = k; - int ldb = n; - int ldc = n; + auto mkl = [&, m, n, k, lda, ldb, ldc, alpha, beta]() { paddle::operators::math::CBlas::GEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha, A, lda, B, ldb, beta, CMKL, ldc); }; + smm(); mkl(); ASSERT_EQ(mat_c_mkl.numel(), mat_c_smm.numel()); @@ -105,6 +104,8 @@ void MklSmmCompare(int m, int n, int k) { TEST(math_function, gemm_mkl_vs_smm) { MklSmmCompare(1, 2, 3); MklSmmCompare(1, 2, 3); + MklSmmCompare(3, 2, 1); + MklSmmCompare(3, 2, 1); MklSmmCompare(3, 8, 5); MklSmmCompare(3, 8, 5); }