diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index 124587b1430359ebfea9dac4a740ca98559f93b4..053e5ed07983d81521fbefed4040c8876a5447c4 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -33,8 +33,11 @@ namespace jit { #define EXP_MAX_INPUT 40.0 template -inline typename KernelTuples::func_type GetJitCode( - typename KernelTuples::attr_type attr) { +inline typename std::enable_if< + std::is_same::value && + std::is_same::value, + typename KernelTuples::func_type>::type +GetJitCode(typename KernelTuples::attr_type attr) { using Func = typename KernelTuples::func_type; using Attr = typename KernelTuples::attr_type; size_t key = JitCodeKey(attr); @@ -45,21 +48,19 @@ inline typename KernelTuples::func_type GetJitCode( // creator is not related with attr, so can use KernelKey as key KernelKey kkey(KT, PlaceType()); - if (std::is_same::value) { - // pool: (KernelKey(type, place), vector) - auto& creator_map = JitCodeCreatorPool().Instance().AllCreators(); - auto iter = creator_map.find(kkey); - if (iter != creator_map.end()) { - auto& creators = iter->second; - for (auto& cur : creators) { - auto i = dynamic_cast*>(cur.get()); - if (i && i->UseMe(attr)) { - auto p = i->CreateJitCode(attr); - if (p) { - auto f = p->template getCode(); - codes.Insert(key, std::move(p)); - return f; - } + // pool: (KernelKey(type, place), vector) + auto& creator_map = JitCodeCreatorPool().Instance().AllCreators(); + auto iter = creator_map.find(kkey); + if (iter != creator_map.end()) { + auto& creators = iter->second; + for (auto& cur : creators) { + auto i = dynamic_cast*>(cur.get()); + if (i && i->UseMe(attr)) { + auto p = i->CreateJitCode(attr); + if (p) { + auto f = p->template getCode(); + codes.Insert(key, std::move(p)); + return f; } } } @@ -67,6 +68,15 @@ inline typename KernelTuples::func_type GetJitCode( return nullptr; } +template +inline typename std::enable_if< + !std::is_same::value || + !std::is_same::value, + typename KernelTuples::func_type>::type +GetJitCode(typename KernelTuples::attr_type attr) { + return nullptr; +} + // Refer code do not related with attr, which is just for cast // Refer is always on CPUPlace template diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index dba7e754eaece357fba0bd8f9f5795bace31cdce..9ceca24079f7ab8dc3e7f5afb7d9bee84c9e954d 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -48,13 +48,13 @@ void ExpectEQ(const T* target, const T* refer, int n) { std::vector TestSizes() { std::vector s; - for (int i = 1; i < 10; ++i) { + for (int i = 1; i < 32; ++i) { s.push_back(i); } - // // test some large size - // s.push_back(100); - // s.push_back(1000); - // s.push_back(2000); + // test some large size + s.push_back(100); + s.push_back(1000); + s.push_back(2000); return s; } @@ -148,8 +148,7 @@ void TestXYZNKernel() { TEST(JITKernel, vmul) { namespace jit = paddle::operators::jit; TestXYZNKernel(); - // TODO(TJ): fix double issue - // TestXYZNKernel(); + TestXYZNKernel(); } TEST(JITKernel, vadd) {