From bc0df6a9487f0aa877647b2a15789fe3eb206616 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 12 Dec 2018 07:54:10 +0000 Subject: [PATCH] make typename tuples --- paddle/fluid/operators/jit/benchmark.cc | 15 +++++--------- paddle/fluid/operators/jit/helper.h | 24 ++++++++++++----------- paddle/fluid/operators/jit/kernel_base.h | 15 +++++++++----- paddle/fluid/operators/jit/more/mkl/mkl.h | 3 +-- paddle/fluid/operators/jit/refer/refer.h | 3 +-- paddle/fluid/operators/jit/test.cc | 15 +++++--------- 6 files changed, 35 insertions(+), 40 deletions(-) diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 5a276172c3..ef7ccc64ad 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -94,8 +94,7 @@ int main(int argc, char* argv[]) { RandomVec(d, x.data()); RandomVec(d, y.data()); // refer - auto refer = jit::GetRefer::func_type, - jit::VMulTuples::attr_type>(); + auto refer = jit::GetRefer>(); if (refer) { auto res = BenchTartgetFunc::func_type>(refer, x, y, z); @@ -103,8 +102,7 @@ int main(int argc, char* argv[]) { } // test jitcode - auto jitcode = jit::GetJitCode::func_type, - jit::VMulTuples::attr_type, PlaceType>(d); + auto jitcode = jit::GetJitCode, PlaceType>(d); if (jitcode) { auto res = BenchTartgetFunc::func_type>(jitcode, x, y, z); @@ -118,10 +116,8 @@ int main(int argc, char* argv[]) { if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { - auto i = - dynamic_cast::func_type, - jit::VMulTuples::attr_type>*>( - impl.get()); + auto i = dynamic_cast>*>( + impl.get()); if (i && i->UseMe(d)) { auto more = i->GetFunc(); auto res = @@ -132,8 +128,7 @@ int main(int argc, char* argv[]) { } // Test result from Get function - auto tgt = jit::Get::func_type, - jit::VMulTuples::attr_type, PlaceType>(d); + auto tgt = jit::Get, PlaceType>(d); if (!tgt) { LOG(ERROR) << "Target can not be empty!"; } diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index 16cd18e2cc..11bbd6a56c 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -32,9 +32,11 @@ namespace jit { #define SIGMOID_THRESHOLD_MAX 13.0 #define EXP_MAX_INPUT 40.0 -template -inline Func GetJitCode(Attr attr) { +template +inline typename KernelTuples::func_type GetJitCode( + typename KernelTuples::attr_type attr) { + using Func = typename KernelTuples::func_type; + using Attr = typename KernelTuples::attr_type; size_t key = JitCodeKey(attr); auto& codes = JitCodePool().Instance(); if (codes.Has(key)) { @@ -65,8 +67,8 @@ inline Func GetJitCode(Attr attr) { // Refer code do not related with attr, which is just for cast // Refer is always on CPUPlace -template -inline Func GetRefer() { +template +inline typename KernelTuples::func_type GetRefer() { auto& ref_pool = ReferKernelPool().Instance().AllKernels(); KernelKey kkey(KT, platform::CPUPlace()); auto ref_iter = ref_pool.find(kkey); @@ -74,7 +76,7 @@ inline Func GetRefer() { "Every Kernel should have reference function."); auto& ref_impls = ref_iter->second; for (auto& impl : ref_impls) { - auto i = dynamic_cast*>(impl.get()); + auto i = dynamic_cast*>(impl.get()); if (i) { return i->GetFunc(); } @@ -82,10 +84,10 @@ inline Func GetRefer() { return nullptr; } -template -Func Get(Attr attr) { - auto jitfunc = GetJitCode(attr); +typename KernelTuples::func_type Get(typename KernelTuples::attr_type attr) { + auto jitfunc = GetJitCode(attr); if (jitfunc) { return jitfunc; } @@ -97,7 +99,7 @@ Func Get(Attr attr) { if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { - auto i = dynamic_cast*>(impl.get()); + auto i = dynamic_cast*>(impl.get()); if (i && i->UseMe(attr)) { return i->GetFunc(); } @@ -105,7 +107,7 @@ Func Get(Attr attr) { } // The last implementation should be reference function on CPUPlace. - return GetRefer(); + return GetRefer(); } } // namespace jit diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index df7be6ab8e..84f0308898 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -36,10 +36,13 @@ class Kernel { DISABLE_COPY_AND_ASSIGN(Kernel); }; -template +template class KernelImpl : public Kernel { + using T = typename KernelTuples::data_type; + using Func = typename KernelTuples::func_type; + using Attr = typename KernelTuples::attr_type; + public: - using ELEMENT_TYPE = T; virtual Func GetFunc() const { return func; } virtual bool UseMe(Attr attr) const = 0; @@ -47,11 +50,13 @@ class KernelImpl : public Kernel { Func func{nullptr}; }; -template -class ReferKernel : public KernelImpl { +template +class ReferKernel : public KernelImpl { public: // Refer code can always be used - bool UseMe(Attr attr) const override { return true; } + bool UseMe(typename KernelTuples::attr_type attr) const override { + return true; + } }; } // namespace jit diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h index c0f738cceb..56469b054d 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.h +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -28,8 +28,7 @@ template void VMul(const T* x, const T* y, T* z, int n); template -class VMulKernel : public KernelImpl::func_type, - typename VMulTuples::attr_type> { +class VMulKernel : public KernelImpl> { public: VMulKernel() { this->func = VMul; } bool UseMe(int d) const override { diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index 97aa5de8fc..99d1cbd43e 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -29,8 +29,7 @@ void VMul(const T* x, const T* y, T* z, int n) { } template -class VMulKernel : public ReferKernel::func_type, - typename VMulTuples::attr_type> { +class VMulKernel : public ReferKernel> { public: VMulKernel() { this->func = VMul; } }; diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 1ee6ce6b13..4d7970414f 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -89,8 +89,7 @@ TEST(JitKernel, vmul) { namespace jit = paddle::operators::jit; const auto KT = jit::vmul; for (int d : TestSizes()) { - auto ref = jit::GetRefer::func_type, - jit::VMulTuples::attr_type>(); + auto ref = jit::GetRefer>(); EXPECT_TRUE(ref != nullptr); std::vector x(d), y(d), zref(d); @@ -115,8 +114,7 @@ TEST(JitKernel, vmul) { ExpectEQ(yinp_data, zref_data, d); // test jitcode - auto jitcode = jit::GetJitCode::func_type, - jit::VMulTuples::attr_type, PlaceType>(d); + auto jitcode = jit::GetJitCode, PlaceType>(d); if (jitcode) { VLOG(10) << "Test jitcode, size: " << d; TestTartgetFunc::func_type>(jitcode, x, y, zref); @@ -129,10 +127,8 @@ TEST(JitKernel, vmul) { if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { - auto i = - dynamic_cast::func_type, - jit::VMulTuples::attr_type>*>( - impl.get()); + auto i = dynamic_cast>*>( + impl.get()); if (i && i->UseMe(d)) { auto more = i->GetFunc(); VLOG(10) << "Test More Kernel, size: " << d; @@ -142,8 +138,7 @@ TEST(JitKernel, vmul) { } // Test result from Get function VLOG(10) << "Test Get function, size: " << d; - auto tgt = jit::Get::func_type, - jit::VMulTuples::attr_type, PlaceType>(d); + auto tgt = jit::Get, PlaceType>(d); TestTartgetFunc::func_type>(tgt, x, y, zref); } } -- GitLab