From 21516e5cbecf7f59ff9fafa55d9192a73c6d3373 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 5 Jul 2018 22:15:47 +0800 Subject: [PATCH] add unit test of smm --- .../operators/math/math_function_test.cc | 57 ++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/math/math_function_test.cc b/paddle/fluid/operators/math/math_function_test.cc index b545671b43d..7aabb9981f7 100644 --- a/paddle/fluid/operators/math/math_function_test.cc +++ b/paddle/fluid/operators/math/math_function_test.cc @@ -54,8 +54,63 @@ TEST(math_function, gemm_notrans_cblas) { EXPECT_EQ(input3_ptr[6], 86); EXPECT_EQ(input3_ptr[7], 99); } +#ifdef PADDLE_WITH_LIBXSMM +template +void MklSmmCompare(int m, int n, int k) { + paddle::framework::Tensor mat_a; + paddle::framework::Tensor mat_b; + paddle::framework::Tensor mat_c_smm; + paddle::framework::Tensor mat_c_mkl; + auto* cpu_place = new paddle::platform::CPUPlace(); + + T* A = mat_a.mutable_data({m, k}, *cpu_place); + T* B = mat_b.mutable_data({k, n}, *cpu_place); + T* CSMM = mat_c_smm.mutable_data({m, n}, *cpu_place); + T* CMKL = mat_c_mkl.mutable_data({m, n}, *cpu_place); + T alpha = static_cast(1); + T beta = static_cast(0); + for (int i = 0; i < mat_a.numel(); ++i) { + A[i] = static_cast(i); + } + for (int i = 0; i < mat_b.numel(); ++i) { + B[i] = static_cast(i); + } + + auto smm = [&, m, n, k, alpha, beta]() { + const char transa = 'N'; + const char transb = 'N'; + const int lda = m; + const int ldb = k; + const int ldc = m; + paddle::operators::math::CBlas::SMM_GEMM(&transa, &transb, &m, &n, &k, + &alpha, A, &lda, B, &ldb, &beta, + CSMM, &ldc); + }; + + auto mkl = [&, m, n, k, alpha, beta]() { + int lda = k; + int ldb = n; + int ldc = n; + paddle::operators::math::CBlas::GEMM(CblasRowMajor, CblasNoTrans, + CblasNoTrans, m, n, k, alpha, A, + lda, B, ldb, beta, CMKL, ldc); + }; + smm(); + mkl(); + ASSERT_EQ(mat_c_mkl.numel(), mat_c_smm.numel()); + for (int i = 0; i < mat_c_mkl.numel(); ++i) { + EXPECT_FLOAT_EQ(CSMM[i], CMKL[i]); + } +} +TEST(math_function, gemm_mkl_vs_smm) { + MklSmmCompare(1, 2, 3); + MklSmmCompare(1, 2, 3); + MklSmmCompare(3, 8, 5); + MklSmmCompare(3, 8, 5); +} +#endif -TEST(math_function, gemm_trans_clbas) { +TEST(math_function, gemm_trans_cblas) { paddle::framework::Tensor input1; paddle::framework::Tensor input2; paddle::framework::Tensor input3; -- GitLab