提交 21516e5c 编写于 作者: T tensor-tang

add unit test of smm

上级 c3941745
...@@ -54,8 +54,63 @@ TEST(math_function, gemm_notrans_cblas) { ...@@ -54,8 +54,63 @@ TEST(math_function, gemm_notrans_cblas) {
EXPECT_EQ(input3_ptr[6], 86); EXPECT_EQ(input3_ptr[6], 86);
EXPECT_EQ(input3_ptr[7], 99); EXPECT_EQ(input3_ptr[7], 99);
} }
#ifdef PADDLE_WITH_LIBXSMM
// Verifies that the LIBXSMM small-matrix GEMM (CBlas<T>::SMM_GEMM) and the
// regular CBLAS GEMM (CBlas<T>::GEMM) compute the same product C = A * B.
//
// Template parameter:
//   T - element type of the matrices (instantiated with float and double).
// Parameters:
//   m - number of rows of A and C.
//   n - number of columns of B and C.
//   k - number of columns of A and rows of B.
//
// A (m x k) and B (k x n) are filled with 0, 1, 2, ... in memory order and are
// treated as row-major matrices by both code paths, so the two result buffers
// can be compared element by element.
template <typename T>
void MklSmmCompare(int m, int n, int k) {
  paddle::framework::Tensor mat_a;
  paddle::framework::Tensor mat_b;
  paddle::framework::Tensor mat_c_smm;
  paddle::framework::Tensor mat_c_mkl;
  // A stack object is enough here; the previous heap allocation was never
  // released.
  paddle::platform::CPUPlace cpu_place;

  T* A = mat_a.mutable_data<T>({m, k}, cpu_place);
  T* B = mat_b.mutable_data<T>({k, n}, cpu_place);
  T* CSMM = mat_c_smm.mutable_data<T>({m, n}, cpu_place);
  T* CMKL = mat_c_mkl.mutable_data<T>({m, n}, cpu_place);

  const T alpha = static_cast<T>(1);
  const T beta = static_cast<T>(0);
  for (int i = 0; i < mat_a.numel(); ++i) {
    A[i] = static_cast<T>(i);
  }
  for (int i = 0; i < mat_b.numel(); ++i) {
    B[i] = static_cast<T>(i);
  }

  // Row-major leading dimensions, shared by both code paths.
  const int lda = k;
  const int ldb = n;
  const int ldc = n;

  auto smm = [&, m, n, k, lda, ldb, ldc, alpha, beta]() {
    const char transa = 'N';
    const char transb = 'N';
    // NOTE(review): SMM_GEMM takes every dimension by pointer (Fortran-style
    // BLAS signature) and is assumed to be column-major, like LIBXSMM's gemm
    // -- confirm against the CBlas wrapper. A row-major buffer read as
    // column-major is the transpose, so the row-major product C = A * B is
    // obtained by asking for C^T = B^T * A^T: swap the operands and m/n and
    // keep the row-major leading dimensions.
    paddle::operators::math::CBlas<T>::SMM_GEMM(&transa, &transb, &n, &m, &k,
                                                &alpha, B, &ldb, A, &lda, &beta,
                                                CSMM, &ldc);
  };

  auto mkl = [&, m, n, k, lda, ldb, ldc, alpha, beta]() {
    // Reference result: plain row-major CBLAS GEMM.
    paddle::operators::math::CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans,
                                            CblasNoTrans, m, n, k, alpha, A,
                                            lda, B, ldb, beta, CMKL, ldc);
  };

  smm();
  mkl();

  ASSERT_EQ(mat_c_mkl.numel(), mat_c_smm.numel());
  for (int i = 0; i < mat_c_mkl.numel(); ++i) {
    EXPECT_FLOAT_EQ(CSMM[i], CMKL[i]);
  }
}
// Runs the SMM-vs-CBLAS GEMM comparison for float and double on each
// (m, n, k) shape listed in the table below.
TEST(math_function, gemm_mkl_vs_smm) {
  // Each row is one {m, n, k} problem size.
  const int shapes[][3] = {{1, 2, 3}, {3, 8, 5}};
  for (const auto& dims : shapes) {
    MklSmmCompare<float>(dims[0], dims[1], dims[2]);
    MklSmmCompare<double>(dims[0], dims[1], dims[2]);
  }
}
#endif
TEST(math_function, gemm_trans_clbas) { TEST(math_function, gemm_trans_cblas) {
paddle::framework::Tensor input1; paddle::framework::Tensor input1;
paddle::framework::Tensor input2; paddle::framework::Tensor input2;
paddle::framework::Tensor input3; paddle::framework::Tensor input3;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册