From c5623c87a32b19f308a380cba022aae73bba0cb2 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Sat, 12 Jan 2019 12:56:11 +0000
Subject: [PATCH] init jit matmul kernel

---
 paddle/fluid/operators/jit/benchmark.cc       | 21 +++++++++
 paddle/fluid/operators/jit/helper.cc          |  1 +
 paddle/fluid/operators/jit/kernel_base.h      |  8 ++++
 .../operators/jit/more/mkl/CMakeLists.txt     |  1 +
 .../fluid/operators/jit/refer/CMakeLists.txt  |  1 +
 paddle/fluid/operators/jit/refer/refer.cc     |  2 +
 paddle/fluid/operators/jit/refer/refer.h      |  6 +++
 paddle/fluid/operators/jit/test.cc            | 47 +++++++++++++++++++
 8 files changed, 87 insertions(+)

diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc
index 4b4ce07fa78..65241b270a2 100644
--- a/paddle/fluid/operators/jit/benchmark.cc
+++ b/paddle/fluid/operators/jit/benchmark.cc
@@ -210,6 +210,24 @@ void BenchSeqPoolKernel() {
   }
 }
 
+template <jit::KernelType KT, typename T, typename PlaceType>
+void BenchMatMulKernel() {
+  for (int m : {1, 2, 3, 4}) {
+    for (int n : {1, 2, 3, 4}) {
+      for (int k : TestSizes()) {
+        std::vector<T> a(m * k), b(k * n), c(m * n);
+        RandomVec<T>(m * k, a.data(), -2.f, 2.f);
+        RandomVec<T>(k * n, b.data(), -2.f, 2.f);
+        const T* a_data = a.data();
+        const T* b_data = b.data();
+        T* c_data = c.data();
+        BenchAllImpls<KT, jit::MatMulTuples<T>, PlaceType>(k, a_data, b_data,
+                                                           c_data, m, n, k);
+      }
+    }
+  }
+}
+
 // Benchmark all jit kernels including jitcode, mkl and refer.
 // To use this tool, run command: ./benchmark [options...]
 // Options:
@@ -251,4 +269,7 @@ int main(int argc, char* argv[]) {
 
   // seq pool function
   BenchSeqPoolKernel<jit::kSeqPool, T, PlaceType>();
+
+  // matmul
+  BenchMatMulKernel<jit::kMatMul, T, PlaceType>();
 }
diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc
index 7d02590f2e5..2465199f430 100644
--- a/paddle/fluid/operators/jit/helper.cc
+++ b/paddle/fluid/operators/jit/helper.cc
@@ -47,6 +47,7 @@ const char* to_string(KernelType kt) {
     ONE_CASE(kLayerNorm);
     ONE_CASE(kNCHW16CMulNC);
     ONE_CASE(kSeqPool);
+    ONE_CASE(kMatMul);
     default:
       PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
       return "NOT JITKernel";
diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h
index 2a7697a6f25..69112c0ee9e 100644
--- a/paddle/fluid/operators/jit/kernel_base.h
+++ b/paddle/fluid/operators/jit/kernel_base.h
@@ -42,6 +42,7 @@ typedef enum {
   kLayerNorm,
   kNCHW16CMulNC,
   kSeqPool,
+  kMatMul,
 } KernelType;
 
 typedef enum {
@@ -135,6 +136,13 @@ struct SeqPoolTuples {
   typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
 };
 
+template <typename T>
+struct MatMulTuples {
+  typedef T data_type;
+  typedef int attr_type;
+  typedef void (*func_type)(const T*, const T*, T*, int, int, int);
+};
+
 template <typename T>
 struct CRFDecodingTuples {
   typedef T data_type;
diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
index f5ed2f05721..f5fd1b3d240 100644
--- a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
@@ -3,6 +3,7 @@ cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
 set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE)
 
 # use mkl kernels by name and type
+# USE_JITKERNEL_MORE(kMatMul, mkl)
 USE_JITKERNEL_MORE(kVMul, mkl)
 USE_JITKERNEL_MORE(kVAdd, mkl)
 USE_JITKERNEL_MORE(kVScal, mkl)
diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt
index 0f626bb3bfd..9a7e80740fa 100644
--- a/paddle/fluid/operators/jit/refer/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt
@@ -27,3 +27,4 @@ USE_JITKERNEL_REFER(kCRFDecoding)
 USE_JITKERNEL_REFER(kLayerNorm)
 USE_JITKERNEL_REFER(kNCHW16CMulNC)
 USE_JITKERNEL_REFER(kSeqPool)
+USE_JITKERNEL_REFER(kMatMul)
diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc
index 85381daa474..1b8dd0e3151 100644
--- a/paddle/fluid/operators/jit/refer/refer.cc
+++ b/paddle/fluid/operators/jit/refer/refer.cc
@@ -49,4 +49,6 @@ REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
 
 REGISTER_REFER_KERNEL(kSeqPool, SeqPool);
 
+REGISTER_REFER_KERNEL(kMatMul, MatMul);
+
 #undef REGISTER_REFER_KERNEL
diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h
index b4e9c8dd107..cbf799cbd63 100644
--- a/paddle/fluid/operators/jit/refer/refer.h
+++ b/paddle/fluid/operators/jit/refer/refer.h
@@ -354,6 +354,10 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
   }
 }
 
+// A(M,K) * B(K,N) = C(M,N)
+template <typename T>
+void MatMul(const T* A, const T* B, T* C, int M, int N, int K) {}
+
 #define DECLARE_REFER_KERNEL(name, tuples)             \
   template <typename T>                                \
   class name##Kernel : public ReferKernel<tuples<T>> { \
@@ -394,6 +398,8 @@ DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
 
 DECLARE_REFER_KERNEL(SeqPool, SeqPoolTuples);
 
+DECLARE_REFER_KERNEL(MatMul, MatMulTuples);
+
 #undef DECLARE_REFER_KERNEL
 
 }  // namespace refer
diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc
index 30291bfef3b..e6a9690a474 100644
--- a/paddle/fluid/operators/jit/test.cc
+++ b/paddle/fluid/operators/jit/test.cc
@@ -229,6 +229,25 @@ struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>,
   }
 };
 
+template <typename T>
+struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>> {
+  void operator()(const typename jit::MatMulTuples<T>::func_type tgt,
+                  const std::vector<T>& a, const std::vector<T>& b,
+                  const std::vector<T>& cref, int m, int n, int k) {
+    EXPECT_TRUE(tgt != nullptr);
+    EXPECT_EQ(a.size(), static_cast<size_t>(m * k));
+    EXPECT_EQ(b.size(), static_cast<size_t>(k * n));
+    EXPECT_EQ(cref.size(), static_cast<size_t>(m * n));
+    std::vector<T> c(cref.size());
+    const T* a_data = a.data();
+    const T* b_data = b.data();
+    const T* cref_data = cref.data();
+    T* c_data = c.data();
+    tgt(a_data, b_data, c_data, m, n, k);
+    ExpectEQ<T>(c_data, cref_data, m * n);
+  }
+};
+
 template <jit::KernelType KT, typename KernelTuples, typename PlaceType,
           typename... Args>
 void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
@@ -458,6 +477,28 @@ void TestSeqPoolKernel() {
   }
 }
 
+template <jit::KernelType KT, typename T, typename PlaceType>
+void TestMatMulKernel() {
+  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+  for (int m : {1, 2, 3, 4}) {
+    for (int n : {1, 2, 3, 4}) {
+      for (int k : TestSizes()) {
+        auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>();
+        EXPECT_TRUE(ref != nullptr);
+        std::vector<T> a(m * k), b(k * n), c(m * n);
+        RandomVec<T>(m * k, a.data(), -2.f, 2.f);
+        RandomVec<T>(k * n, b.data(), -2.f, 2.f);
+        const T* a_data = a.data();
+        const T* b_data = b.data();
+        T* c_data = c.data();
+        ref(a_data, b_data, c_data, m, n, k);
+        TestAllImpls<KT, jit::MatMulTuples<T>, PlaceType, std::vector<T>,
+                     std::vector<T>, std::vector<T>>(k, a, b, c, m, n, k);
+      }
+    }
+  }
+}
+
 template <jit::KernelType KT, typename T, typename PlaceType>
 void TestNCHW16CMulNCKernel() {
   VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
@@ -618,6 +659,12 @@ TEST(JITKernel, kSeqPool) {
   TestSeqPoolKernel<jit::kSeqPool, double, paddle::platform::CPUPlace>();
 }
 
+TEST(JITKernel, kMatMul) {
+  namespace jit = paddle::operators::jit;
+  TestMatMulKernel<jit::kMatMul, float, paddle::platform::CPUPlace>();
+  TestMatMulKernel<jit::kMatMul, double, paddle::platform::CPUPlace>();
+}
+
 TEST(JITKernel, kNCHW16CMulNC) {
   namespace jit = paddle::operators::jit;
   TestNCHW16CMulNCKernel
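
Editor's note: in this init commit the refer::MatMul added to refer.h is registered but its body is left empty. As a point of reference only, a naive row-major implementation of A(M,K) * B(K,N) = C(M,N) matching the (const T*, const T*, T*, int, int, int) func_type declared in MatMulTuples could look like the sketch below; the name NaiveMatMul and its body are illustrative assumptions, not part of this patch.

// Editor's sketch (assumption, not from this commit): naive reference matmul.
// A is M x K, B is K x N, C is M x N, all dense row-major buffers.
template <typename T>
void NaiveMatMul(const T* A, const T* B, T* C, int M, int N, int K) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      T sum = static_cast<T>(0);
      for (int k = 0; k < K; ++k) {
        sum += A[m * K + k] * B[k * N + n];
      }
      C[m * N + n] = sum;
    }
  }
}

In the benchmark and test added above, only k is passed as the kernel attribute (MatMulTuples::attr_type is int), while m, n, and k all travel as regular call arguments.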