提交 6bc1aaaa 编写于 作者: T tensor-tang

Refine the ColMajor replacement: compute the RowMajor lda/ldb/ldc once and share them between the LIBXSMM and CBLAS GEMM paths

上级 c3862a75
...@@ -168,6 +168,9 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA, ...@@ -168,6 +168,9 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB, int M, CBLAS_TRANSPOSE transB, int M,
int N, int K, T alpha, const T *A, int N, int K, T alpha, const T *A,
const T *B, T beta, T *C) const { const T *B, T beta, T *C) const {
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
#ifdef PADDLE_WITH_LIBXSMM #ifdef PADDLE_WITH_LIBXSMM
if (M * N * K < 128 * 128 * 128 && transA == CblasNoTrans && if (M * N * K < 128 * 128 * 128 && transA == CblasNoTrans &&
transB == CblasNoTrans) { transB == CblasNoTrans) {
...@@ -175,16 +178,10 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA, ...@@ -175,16 +178,10 @@ void Blas<platform::CPUDeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
// Note: SMM uses ColMajor // Note: SMM uses ColMajor
const char transa = 'N'; const char transa = 'N';
const char transb = 'N'; const char transb = 'N';
const int lda = M;
const int ldb = K;
const int ldc = M;
CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda, CBlas<T>::SMM_GEMM(&transa, &transb, &N, &M, &K, &alpha, B, &ldb, A, &lda,
&beta, C, &ldc); &beta, C, &ldc);
} else { } else {
#endif #endif
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B,
ldb, beta, C, ldc); ldb, beta, C, ldc);
#ifdef PADDLE_WITH_LIBXSMM #ifdef PADDLE_WITH_LIBXSMM
......
...@@ -75,26 +75,25 @@ void MklSmmCompare(int m, int n, int k) { ...@@ -75,26 +75,25 @@ void MklSmmCompare(int m, int n, int k) {
for (int i = 0; i < mat_b.numel(); ++i) { for (int i = 0; i < mat_b.numel(); ++i) {
B[i] = static_cast<T>(i); B[i] = static_cast<T>(i);
} }
// lda, ldb, and ldc follow the RowMajor convention
int lda = k;
int ldb = n;
int ldc = n;
auto smm = [&, m, n, k, alpha, beta]() { auto smm = [&, m, n, k, lda, ldb, ldc, alpha, beta]() {
const char transa = 'N'; const char transa = 'N';
const char transb = 'N'; const char transb = 'N';
const int lda = m;
const int ldb = k;
const int ldc = m;
paddle::operators::math::CBlas<T>::SMM_GEMM(&transa, &transb, &n, &m, &k, paddle::operators::math::CBlas<T>::SMM_GEMM(&transa, &transb, &n, &m, &k,
&alpha, B, &ldb, A, &lda, &beta, &alpha, B, &ldb, A, &lda, &beta,
CSMM, &ldc); CSMM, &ldc);
}; };
auto mkl = [&, m, n, k, alpha, beta]() { auto mkl = [&, m, n, k, lda, ldb, ldc, alpha, beta]() {
int lda = k;
int ldb = n;
int ldc = n;
paddle::operators::math::CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans, paddle::operators::math::CBlas<T>::GEMM(CblasRowMajor, CblasNoTrans,
CblasNoTrans, m, n, k, alpha, A, CblasNoTrans, m, n, k, alpha, A,
lda, B, ldb, beta, CMKL, ldc); lda, B, ldb, beta, CMKL, ldc);
}; };
smm(); smm();
mkl(); mkl();
ASSERT_EQ(mat_c_mkl.numel(), mat_c_smm.numel()); ASSERT_EQ(mat_c_mkl.numel(), mat_c_smm.numel());
...@@ -105,6 +104,8 @@ void MklSmmCompare(int m, int n, int k) { ...@@ -105,6 +104,8 @@ void MklSmmCompare(int m, int n, int k) {
TEST(math_function, gemm_mkl_vs_smm) { TEST(math_function, gemm_mkl_vs_smm) {
MklSmmCompare<float>(1, 2, 3); MklSmmCompare<float>(1, 2, 3);
MklSmmCompare<double>(1, 2, 3); MklSmmCompare<double>(1, 2, 3);
MklSmmCompare<float>(3, 2, 1);
MklSmmCompare<double>(3, 2, 1);
MklSmmCompare<float>(3, 8, 5); MklSmmCompare<float>(3, 8, 5);
MklSmmCompare<double>(3, 8, 5); MklSmmCompare<double>(3, 8, 5);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册