diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc
index 65241b270a2175a869dc85a5075e854e3a399b12..8dab16c284a6f483dce2d8b1dc729391e2e76b8d 100644
--- a/paddle/fluid/operators/jit/benchmark.cc
+++ b/paddle/fluid/operators/jit/benchmark.cc
@@ -213,7 +213,7 @@ void BenchSeqPoolKernel() {
 template <jit::KernelType KT, typename T, typename PlaceType>
 void BenchMatMulKernel() {
   for (int m : {1, 2, 3, 4}) {
-    for (int n : {1, 2, 3, 4}) {
+    for (int n : TestSizes()) {
       for (int k : TestSizes()) {
         std::vector<T> a(m * k), b(k * n), c(m * n);
         RandomVec<T>(m * k, a.data(), -2.f, 2.f);
diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
index f5fd1b3d240951b73a7c87a39f2343a4473fcf31..7c6a75d35f654b6c2c46b4d498e401db92237341 100644
--- a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
@@ -3,7 +3,7 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
 set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE)
 
 # use mkl kernels by name and type
-# USE_JITKERNEL_MORE(kMatMul, mkl)
+USE_JITKERNEL_MORE(kMatMul, mkl)
 USE_JITKERNEL_MORE(kVMul, mkl)
 USE_JITKERNEL_MORE(kVAdd, mkl)
 USE_JITKERNEL_MORE(kVScal, mkl)
diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc
index 5a499ac2c02aa70d2824f0d3be618e083ba10334..5b20ae4da96f262657e487cd231b5f319bf92641 100644
--- a/paddle/fluid/operators/jit/more/mkl/mkl.cc
+++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc
@@ -24,6 +24,20 @@ namespace jit {
 namespace more {
 namespace mkl {
 
+template <>
+void MatMul<float>(const float* a, const float* b, float* c, int m, int n,
+                   int k) {
+  platform::dynload::cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m,
+                                 n, k, 1.f, a, k, b, n, 0.f, c, n);
+}
+
+template <>
+void MatMul<double>(const double* a, const double* b, double* c, int m, int n,
+                    int k) {
+  platform::dynload::cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m,
+                                 n, k, 1.0, a, k, b, n, 0.0, c, n);
+}
+
 template <>
 void VMul<float>(const float* x, const float* y, float* z, int n) {
   platform::dynload::vsMul(n, x, y, z);
@@ -93,6 +107,11 @@ void VAXPY(double a, const double* x, double* y, int n) {
 }
 
 // TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
+template <>
+bool MatMulKernel<float>::UseMe(const int& d) const {
+  return platform::MayIUse(platform::avx);
+}
+
 template <>
 bool VMulKernel<float>::UseMe(const int& d) const {
   return platform::MayIUse(platform::avx512f) && d > 512;
@@ -139,6 +158,7 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
     return true;                                                     \
   }
 
+AWALYS_USE_ME_WITH_DOUBLE(MatMul);
 AWALYS_USE_ME_WITH_DOUBLE(VMul);
 AWALYS_USE_ME_WITH_DOUBLE(VAdd);
 AWALYS_USE_ME_WITH_DOUBLE(VScal);
@@ -159,6 +179,7 @@ namespace mkl = paddle::operators::jit::more::mkl;
   REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>,        \
                           mkl::func##Kernel<double>)
 
+REGISTER_MKL_KERNEL(kMatMul, MatMul);
 REGISTER_MKL_KERNEL(kVMul, VMul);
 REGISTER_MKL_KERNEL(kVAdd, VAdd);
 REGISTER_MKL_KERNEL(kVScal, VScal);
diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h
index 0a3816db24ccd0820cb259b40044e1f5b66665f7..314ef73d8a581eaaf16ff3b8a58189d209b1735f 100644
--- a/paddle/fluid/operators/jit/more/mkl/mkl.h
+++ b/paddle/fluid/operators/jit/more/mkl/mkl.h
@@ -24,6 +24,9 @@ namespace jit {
 namespace more {
 namespace mkl {
 
+template <typename T>
+void MatMul(const T* a, const T* b, T* c, int m, int n, int k);
+
 template <typename T>
 void VMul(const T* x, const T* y, T* z, int n);
 
@@ -93,6 +96,9 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
     const char* ImplType() const override { return "MKL"; }          \
   }
 
+// ABCMNK
+DECLARE_MKL_KERNEL(MatMul, MatMulTuples);
+
 // XYZN
 DECLARE_MKL_KERNEL(VMul, XYZNTuples);
 DECLARE_MKL_KERNEL(VAdd, XYZNTuples);
diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h
index cbf799cbd63157f5ec7bdca452d832ea32ff95ed..225319c059e581dd83049c5bf9b93c4ed516de0a 100644
--- a/paddle/fluid/operators/jit/refer/refer.h
+++ b/paddle/fluid/operators/jit/refer/refer.h
@@ -356,7 +356,20 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
 
 // A(M,K) * B(K,N) = C(M,N)
 template <typename T>
-void MatMul(const T* A, const T* B, T* C, int M, int N, int K) {}
+void MatMul(const T* A, const T* B, T* C, int M, int N, int K) {
+  for (int m = 0; m < M; ++m) {
+    const T* pa = A + m * K;
+    T* pc = C + m * N;
+    for (int n = 0; n < N; ++n) {
+      const T* pb = B + n;
+      T sum = static_cast<T>(0);
+      for (int k = 0; k < K; ++k) {
+        sum += (pa[k] * pb[k * N]);
+      }
+      *(pc + n) = sum;
+    }
+  }
+}
 
 #define DECLARE_REFER_KERNEL(name, tuples)      \
   template <typename T>                         \
diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc
index e6a9690a47421be4dbc2832bb7267c980e85d0f2..1246ee7c24e0de1363fb35351d15fd4bb912c9e8 100644
--- a/paddle/fluid/operators/jit/test.cc
+++ b/paddle/fluid/operators/jit/test.cc
@@ -230,7 +230,8 @@ struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
 };
 
 template <typename T>
-struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>> {
+struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>,
+                         std::vector<T>, int, int, int> {
   void operator()(const typename jit::MatMulTuples<T>::func_type tgt,
                   const std::vector<T>& a, const std::vector<T>& b,
                   const std::vector<T>& cref, int m, int n, int k) {
@@ -486,8 +487,8 @@ void TestMatMulKernel() {
     auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>();
     EXPECT_TRUE(ref != nullptr);
     std::vector<T> a(m * k), b(k * n), c(m * n);
-    RandomVec<T>(m * k, a.data(), -2.f, 2.f);
-    RandomVec<T>(k * n, b.data(), -2.f, 2.f);
+    RandomVec<T>(m * k, a.data(), -0.2f, 0.2f);
+    RandomVec<T>(k * n, b.data(), -0.2f, 0.2f);
     const T* a_data = a.data();
     const T* b_data = b.data();
     T* c_data = c.data();