diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 4b4ce07fa78b97e636173566fa104cb8a18c914e..65241b270a2175a869dc85a5075e854e3a399b12 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -210,6 +210,24 @@ void BenchSeqPoolKernel() { } } +template +void BenchMatMulKernel() { + for (int m : {1, 2, 3, 4}) { + for (int n : {1, 2, 3, 4}) { + for (int k : TestSizes()) { + std::vector a(m * k), b(k * n), c(m * n); + RandomVec(m * k, a.data(), -2.f, 2.f); + RandomVec(k * n, b.data(), -2.f, 2.f); + const T* a_data = a.data(); + const T* b_data = b.data(); + T* c_data = c.data(); + BenchAllImpls, PlaceType>(k, a_data, b_data, + c_data, m, n, k); + } + } + } +} + // Benchmark all jit kernels including jitcode, mkl and refer. // To use this tool, run command: ./benchmark [options...] // Options: @@ -251,4 +269,7 @@ int main(int argc, char* argv[]) { // seq pool function BenchSeqPoolKernel(); + + // matmul + BenchMatMulKernel(); } diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc index 7d02590f2e5d82b5105132d7af716f14c661d067..2465199f430b5c13a0615b6868951be2ac5996d2 100644 --- a/paddle/fluid/operators/jit/helper.cc +++ b/paddle/fluid/operators/jit/helper.cc @@ -47,6 +47,7 @@ const char* to_string(KernelType kt) { ONE_CASE(kLayerNorm); ONE_CASE(kNCHW16CMulNC); ONE_CASE(kSeqPool); + ONE_CASE(kMatMul); default: PADDLE_THROW("Not support type: %d, or forget to add it.", kt); return "NOT JITKernel"; diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index 2a7697a6f253dcc2b8143d9f14a80a1cfd45996d..69112c0ee9e2232f604e4821f3852bfe646bdac4 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -42,6 +42,7 @@ typedef enum { kLayerNorm, kNCHW16CMulNC, kSeqPool, + kMatMul, } KernelType; typedef enum { @@ -135,6 +136,13 @@ struct SeqPoolTuples { typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*); }; +template +struct MatMulTuples { + typedef T data_type; + typedef int attr_type; + typedef void (*func_type)(const T*, const T*, T*, int, int, int); +}; + template struct CRFDecodingTuples { typedef T data_type; diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt index f5ed2f0572176e42b774259c2b8fe9713d989417..f5fd1b3d240951b73a7c87a39f2343a4473fcf31 100644 --- a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt +++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt @@ -3,6 +3,7 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml) set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE) # use mkl kernels by name and type +# USE_JITKERNEL_MORE(kMatMul, mkl) USE_JITKERNEL_MORE(kVMul, mkl) USE_JITKERNEL_MORE(kVAdd, mkl) USE_JITKERNEL_MORE(kVScal, mkl) diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt index 0f626bb3bfd2851e3fb6ad8265169f9bb9860851..9a7e80740faa0245332bab2e046476c9014ef992 100644 --- a/paddle/fluid/operators/jit/refer/CMakeLists.txt +++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt @@ -27,3 +27,4 @@ USE_JITKERNEL_REFER(kCRFDecoding) USE_JITKERNEL_REFER(kLayerNorm) USE_JITKERNEL_REFER(kNCHW16CMulNC) USE_JITKERNEL_REFER(kSeqPool) +USE_JITKERNEL_REFER(kMatMul) diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc index 85381daa47484a4053326f04e12d583543a423e0..1b8dd0e31517f9e37ef3665af48b6707a56a9425 100644 --- a/paddle/fluid/operators/jit/refer/refer.cc +++ b/paddle/fluid/operators/jit/refer/refer.cc @@ -49,4 +49,6 @@ REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC); REGISTER_REFER_KERNEL(kSeqPool, SeqPool); +REGISTER_REFER_KERNEL(kMatMul, MatMul); + #undef REGISTER_REFER_KERNEL diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index b4e9c8dd107ee844544165b1719d38754ae976bc..cbf799cbd63157f5ec7bdca452d832ea32ff95ed 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -354,6 +354,10 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { } } +// A(M,K) * B(K,N) = C(M,N) +template +void MatMul(const T* A, const T* B, T* C, int M, int N, int K) {} + #define DECLARE_REFER_KERNEL(name, tuples) \ template \ class name##Kernel : public ReferKernel> { \ @@ -394,6 +398,8 @@ DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples); DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples); +DECLARE_REFER_KERNEL(MatMul, MatMulTuples); + #undef DECLARE_REFER_KERNEL } // namespace refer diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 30291bfef3bc96fe2e687e5be6d782eee89496aa..e6a9690a47421be4dbc2832bb7267c980e85d0f2 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -229,6 +229,25 @@ struct TestFuncWithRefer, std::vector, } }; +template +struct TestFuncWithRefer, std::vector, std::vector> { + void operator()(const typename jit::MatMulTuples::func_type tgt, + const std::vector& a, const std::vector& b, + const std::vector& cref, int m, int n, int k) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(a.size(), static_cast(m * k)); + EXPECT_EQ(b.size(), static_cast(k * n)); + EXPECT_EQ(cref.size(), static_cast(m * n)); + std::vector c(cref.size()); + const T* a_data = a.data(); + const T* b_data = b.data(); + const T* cref_data = cref.data(); + T* c_data = c.data(); + tgt(a_data, b_data, c_data, m, n, k); + ExpectEQ(c_data, cref_data, m * n); + } +}; + template void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { @@ -458,6 +477,28 @@ void TestSeqPoolKernel() { } } +template +void TestMatMulKernel() { + VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); + for (int m : {1, 2, 3, 4}) { + for (int n : {1, 2, 3, 4}) { + for (int k : TestSizes()) { + auto ref = jit::GetRefer>(); + EXPECT_TRUE(ref != nullptr); + std::vector a(m * k), b(k * n), c(m * n); + RandomVec(m * k, a.data(), -2.f, 2.f); + RandomVec(k * n, b.data(), -2.f, 2.f); + const T* a_data = a.data(); + const T* b_data = b.data(); + T* c_data = c.data(); + ref(a_data, b_data, c_data, m, n, k); + TestAllImpls, PlaceType, std::vector, + std::vector, std::vector>(k, a, b, c, m, n, k); + } + } + } +} + template void TestNCHW16CMulNCKernel() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); @@ -618,6 +659,12 @@ TEST(JITKernel, kSeqPool) { TestSeqPoolKernel(); } +TEST(JITKernel, kMatMul) { + namespace jit = paddle::operators::jit; + TestMatMulKernel(); + TestMatMulKernel(); +} + TEST(JITKernel, kNCHW16CMulNC) { namespace jit = paddle::operators::jit; TestNCHW16CMulNCKernel