diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc index f868c847bd80e874da2d2babde58129122e0bc70..992228adb50c8aeea2f4368a5a696787ff384025 100644 --- a/paddle/fluid/operators/jit/helper.cc +++ b/paddle/fluid/operators/jit/helper.cc @@ -22,6 +22,12 @@ namespace paddle { namespace operators { namespace jit { +std::unordered_map>& GetFuncCacheMap() { + static thread_local std::unordered_map> + g_func_cache_map; + return g_func_cache_map; +} + #define ONE_CASE(key) \ case key: \ return #key diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index 1ac5318d461c2e8bc4f43569602a88f95a76befb..9a2447a98fb7bd6337967219437c9d381a8e672b 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include // for std::move @@ -175,13 +176,25 @@ typename KernelTuple::func_type GetDefaultBestFunc( return funcs[0]; } +extern std::unordered_map>& +GetFuncCacheMap(); + template class KernelFuncs { public: KernelFuncs() = default; static KernelFuncs& Cache() { - static thread_local KernelFuncs g_func_cache; - return g_func_cache; + auto& func_cache_map = GetFuncCacheMap(); + std::string key = typeid(KernelFuncs).name(); + auto iter = func_cache_map.find(key); + if (iter != func_cache_map.end()) { + return *(KernelFuncs*)(iter->second.get()); + } else { + std::shared_ptr cache = + std::make_shared>(); + func_cache_map.emplace(key, cache); + return *(KernelFuncs*)(cache.get()); + } } // the exposed interface to use diff --git a/paddle/fluid/operators/jit/kernel_pool.cc b/paddle/fluid/operators/jit/kernel_pool.cc index bc98c644fbee2cd54faf4dc9fe151b8be131bd7b..8eef8e474748bfd58dad8b1add91804e48d394ad 100644 --- a/paddle/fluid/operators/jit/kernel_pool.cc +++ b/paddle/fluid/operators/jit/kernel_pool.cc @@ -21,6 +21,12 @@ namespace paddle { namespace operators { namespace jit { +std::unordered_map>& GetJITCodesMap() { + static thread_local std::unordered_map> + g_jit_codes_map; + return g_jit_codes_map; +} + JitCodeCreatorPool& JitCodeCreatorPool::Instance() { static JitCodeCreatorPool g_creator_pool; return g_creator_pool; diff --git a/paddle/fluid/operators/jit/kernel_pool.h b/paddle/fluid/operators/jit/kernel_pool.h index 04710a54ac9ddf2ecb8f6a1f2ca33ef158d2d73f..548c87041268a2a24e6c291b6ba8f5518620770a 100644 --- a/paddle/fluid/operators/jit/kernel_pool.h +++ b/paddle/fluid/operators/jit/kernel_pool.h @@ -28,6 +28,8 @@ namespace paddle { namespace operators { namespace jit { +extern std::unordered_map>& GetJITCodesMap(); + template class JitCodePool { typedef std::unique_ptr GenBasePtr; @@ -36,8 +38,16 @@ class JitCodePool { public: JitCodePool() = default; static JitCodePool& Instance() { - static thread_local JitCodePool g_jit_codes; - return g_jit_codes; + auto& jit_codes_map = GetJITCodesMap(); + std::string key = typeid(JitCodePool).name(); + auto iter = jit_codes_map.find(key); + if (iter != jit_codes_map.end()) { + return *(JitCodePool*)(iter->second.get()); + } else { + std::shared_ptr cache = std::make_shared>(); + jit_codes_map.emplace(key, cache); + return *(JitCodePool*)(cache.get()); + } } const JitCodeMap& AllKernels() { return codes_; }