diff --git a/paddle/fluid/operators/math/math_function_test.cc b/paddle/fluid/operators/math/math_function_test.cc index 078dd448c385dbb8a00025ee2ba08d0c41a4730a..2343e0ee965303c9fdb2ad3faf9ddf6e5bb7782f 100644 --- a/paddle/fluid/operators/math/math_function_test.cc +++ b/paddle/fluid/operators/math/math_function_test.cc @@ -228,3 +228,57 @@ TEST(math_funciton, set_constant) { } delete ctx; } + +template +void GemmWarpTest(int m, int n, int k, T alpha, T beta) { + paddle::framework::Tensor mat_a; + paddle::framework::Tensor mat_b; + paddle::framework::Tensor mat_c_ref; + paddle::framework::Tensor mat_c_mkl; + auto* cpu_place = new paddle::platform::CPUPlace(); + + T* A = mat_a.mutable_data({m, k}, *cpu_place); + T* B = mat_b.mutable_data({k, n}, *cpu_place); + T* CREF = mat_c_ref.mutable_data({m, n}, *cpu_place); + T* CMKL = mat_c_mkl.mutable_data({m, n}, *cpu_place); + + ASSERT_EQ(mat_c_mkl.numel(), mat_c_ref.numel()); + for (int i = 0; i < mat_a.numel(); ++i) { + A[i] = static_cast(i); + } + for (int i = 0; i < mat_b.numel(); ++i) { + B[i] = static_cast(i + 1); + } + for (int i = 0; i < mat_c_ref.numel(); ++i) { + CREF[i] = static_cast(i + 2); + CMKL[i] = CREF[i]; + } + + // this would call gemm_warp + paddle::platform::CPUDeviceContext context(*cpu_place); + GetBlas(context).GEMM(CblasNoTrans, CblasNoTrans, m, n, k, alpha, A, B, + beta, CREF); + + // lda,ldb,ldc follow RowMajor + int lda = k; + int ldb = n; + int ldc = n; + paddle::operators::math::CBlas::GEMM(CblasRowMajor, CblasNoTrans, + CblasNoTrans, m, n, k, alpha, A, lda, + B, ldb, beta, CMKL, ldc); + + for (int i = 0; i < mat_c_mkl.numel(); ++i) { + EXPECT_FLOAT_EQ(CREF[i], CMKL[i]); + } +} + +TEST(math_function, gemm_warp) { + GemmWarpTest(3, 2, 5, 1.f, 0.f); + GemmWarpTest(3, 2, 5, 2.f, 1.f); + GemmWarpTest(8, 5, 6, 1.f, 0.f); + GemmWarpTest(8, 5, 6, 2.f, 1.f); + GemmWarpTest(3, 2, 5, 1.0, 0.0); + GemmWarpTest(3, 2, 5, 2.0, 1.0); + GemmWarpTest(8, 5, 6, 1.0, 0.0); + GemmWarpTest(8, 5, 6, 2.0, 1.0); +}