diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 5a276172c399a78ae95ce61b30219e3cec809526..ef7ccc64adf2a93c1ac8e2a00ffb3299c10d3647 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -94,8 +94,7 @@ int main(int argc, char* argv[]) { RandomVec(d, x.data()); RandomVec(d, y.data()); // refer - auto refer = jit::GetRefer::func_type, - jit::VMulTuples::attr_type>(); + auto refer = jit::GetRefer>(); if (refer) { auto res = BenchTartgetFunc::func_type>(refer, x, y, z); @@ -103,8 +102,7 @@ int main(int argc, char* argv[]) { } // test jitcode - auto jitcode = jit::GetJitCode::func_type, - jit::VMulTuples::attr_type, PlaceType>(d); + auto jitcode = jit::GetJitCode, PlaceType>(d); if (jitcode) { auto res = BenchTartgetFunc::func_type>(jitcode, x, y, z); @@ -118,10 +116,8 @@ int main(int argc, char* argv[]) { if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { - auto i = - dynamic_cast::func_type, - jit::VMulTuples::attr_type>*>( - impl.get()); + auto i = dynamic_cast>*>( + impl.get()); if (i && i->UseMe(d)) { auto more = i->GetFunc(); auto res = @@ -132,8 +128,7 @@ int main(int argc, char* argv[]) { } // Test result from Get function - auto tgt = jit::Get::func_type, - jit::VMulTuples::attr_type, PlaceType>(d); + auto tgt = jit::Get, PlaceType>(d); if (!tgt) { LOG(ERROR) << "Target can not be empty!"; } diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index 16cd18e2ccc1ff38a7b5c3ba2808d634587bd29d..11bbd6a56cf4d1503f9739ee55971bf70b7ee883 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -32,9 +32,11 @@ namespace jit { #define SIGMOID_THRESHOLD_MAX 13.0 #define EXP_MAX_INPUT 40.0 -template -inline Func GetJitCode(Attr attr) { +template +inline typename KernelTuples::func_type GetJitCode( + typename KernelTuples::attr_type attr) { + using Func = typename KernelTuples::func_type; + using Attr = typename KernelTuples::attr_type; size_t key = JitCodeKey(attr); auto& codes = JitCodePool().Instance(); if (codes.Has(key)) { @@ -65,8 +67,8 @@ inline Func GetJitCode(Attr attr) { // Refer code do not related with attr, which is just for cast // Refer is always on CPUPlace -template -inline Func GetRefer() { +template +inline typename KernelTuples::func_type GetRefer() { auto& ref_pool = ReferKernelPool().Instance().AllKernels(); KernelKey kkey(KT, platform::CPUPlace()); auto ref_iter = ref_pool.find(kkey); @@ -74,7 +76,7 @@ inline Func GetRefer() { "Every Kernel should have reference function."); auto& ref_impls = ref_iter->second; for (auto& impl : ref_impls) { - auto i = dynamic_cast*>(impl.get()); + auto i = dynamic_cast*>(impl.get()); if (i) { return i->GetFunc(); } @@ -82,10 +84,10 @@ inline Func GetRefer() { return nullptr; } -template -Func Get(Attr attr) { - auto jitfunc = GetJitCode(attr); +typename KernelTuples::func_type Get(typename KernelTuples::attr_type attr) { + auto jitfunc = GetJitCode(attr); if (jitfunc) { return jitfunc; } @@ -97,7 +99,7 @@ Func Get(Attr attr) { if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { - auto i = dynamic_cast*>(impl.get()); + auto i = dynamic_cast*>(impl.get()); if (i && i->UseMe(attr)) { return i->GetFunc(); } @@ -105,7 +107,7 @@ Func Get(Attr attr) { } // The last implementation should be reference function on CPUPlace. - return GetRefer(); + return GetRefer(); } } // namespace jit diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index df7be6ab8ec38ff8c48d2b1014a4ca3fc459cde5..84f030889859c0d54b544b8a6be44c0469e806b8 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -36,10 +36,13 @@ class Kernel { DISABLE_COPY_AND_ASSIGN(Kernel); }; -template +template class KernelImpl : public Kernel { + using T = typename KernelTuples::data_type; + using Func = typename KernelTuples::func_type; + using Attr = typename KernelTuples::attr_type; + public: - using ELEMENT_TYPE = T; virtual Func GetFunc() const { return func; } virtual bool UseMe(Attr attr) const = 0; @@ -47,11 +50,13 @@ class KernelImpl : public Kernel { Func func{nullptr}; }; -template -class ReferKernel : public KernelImpl { +template +class ReferKernel : public KernelImpl { public: // Refer code can always be used - bool UseMe(Attr attr) const override { return true; } + bool UseMe(typename KernelTuples::attr_type attr) const override { + return true; + } }; } // namespace jit diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h index c0f738ccebefd0aa6a7289c20311bffdb6d40f42..56469b054de4cbc864e4f15dd29615b78d4dfdf3 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.h +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -28,8 +28,7 @@ template void VMul(const T* x, const T* y, T* z, int n); template -class VMulKernel : public KernelImpl::func_type, - typename VMulTuples::attr_type> { +class VMulKernel : public KernelImpl> { public: VMulKernel() { this->func = VMul; } bool UseMe(int d) const override { diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index 97aa5de8fcf8370515563d91f9b89f974cc28516..99d1cbd43ec04e574c914802a0b327a54ab7b21f 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -29,8 +29,7 @@ void VMul(const T* x, const T* y, T* z, int n) { } template -class VMulKernel : public ReferKernel::func_type, - typename VMulTuples::attr_type> { +class VMulKernel : public ReferKernel> { public: VMulKernel() { this->func = VMul; } }; diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 1ee6ce6b13bd001750ef999fc7a2e3d271f945fa..4d7970414ff71154fbf8cdf094f680d40e2518f7 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -89,8 +89,7 @@ TEST(JitKernel, vmul) { namespace jit = paddle::operators::jit; const auto KT = jit::vmul; for (int d : TestSizes()) { - auto ref = jit::GetRefer::func_type, - jit::VMulTuples::attr_type>(); + auto ref = jit::GetRefer>(); EXPECT_TRUE(ref != nullptr); std::vector x(d), y(d), zref(d); @@ -115,8 +114,7 @@ TEST(JitKernel, vmul) { ExpectEQ(yinp_data, zref_data, d); // test jitcode - auto jitcode = jit::GetJitCode::func_type, - jit::VMulTuples::attr_type, PlaceType>(d); + auto jitcode = jit::GetJitCode, PlaceType>(d); if (jitcode) { VLOG(10) << "Test jitcode, size: " << d; TestTartgetFunc::func_type>(jitcode, x, y, zref); @@ -129,10 +127,8 @@ TEST(JitKernel, vmul) { if (iter != pool.end()) { auto& impls = iter->second; for (auto& impl : impls) { - auto i = - dynamic_cast::func_type, - jit::VMulTuples::attr_type>*>( - impl.get()); + auto i = dynamic_cast>*>( + impl.get()); if (i && i->UseMe(d)) { auto more = i->GetFunc(); VLOG(10) << "Test More Kernel, size: " << d; @@ -142,8 +138,7 @@ TEST(JitKernel, vmul) { } // Test result from Get function VLOG(10) << "Test Get function, size: " << d; - auto tgt = jit::Get::func_type, - jit::VMulTuples::attr_type, PlaceType>(d); + auto tgt = jit::Get, PlaceType>(d); TestTartgetFunc::func_type>(tgt, x, y, zref); } }