From eec9c9cbe7e841a9b3fef301d8dc9518d6db2452 Mon Sep 17 00:00:00 2001 From: Yihua Xu Date: Fri, 15 Nov 2019 13:54:38 +0800 Subject: [PATCH] Fix jit tls issue (#21151) --- paddle/fluid/operators/jit/helper.cc | 6 ++++++ paddle/fluid/operators/jit/helper.h | 17 +++++++++++++++-- paddle/fluid/operators/jit/kernel_pool.cc | 6 ++++++ paddle/fluid/operators/jit/kernel_pool.h | 14 ++++++++++++-- 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc index f868c847bd8..992228adb50 100644 --- a/paddle/fluid/operators/jit/helper.cc +++ b/paddle/fluid/operators/jit/helper.cc @@ -22,6 +22,12 @@ namespace paddle { namespace operators { namespace jit { +std::unordered_map>& GetFuncCacheMap() { + static thread_local std::unordered_map> + g_func_cache_map; + return g_func_cache_map; +} + #define ONE_CASE(key) \ case key: \ return #key diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index 1ac5318d461..9a2447a98fb 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include // for std::move @@ -175,13 +176,25 @@ typename KernelTuple::func_type GetDefaultBestFunc( return funcs[0]; } +extern std::unordered_map>& +GetFuncCacheMap(); + template class KernelFuncs { public: KernelFuncs() = default; static KernelFuncs& Cache() { - static thread_local KernelFuncs g_func_cache; - return g_func_cache; + auto& func_cache_map = GetFuncCacheMap(); + std::string key = typeid(KernelFuncs).name(); + auto iter = func_cache_map.find(key); + if (iter != func_cache_map.end()) { + return *(KernelFuncs*)(iter->second.get()); + } else { + std::shared_ptr cache = + std::make_shared>(); + func_cache_map.emplace(key, cache); + return *(KernelFuncs*)(cache.get()); + } } // the exposed interface to use diff --git a/paddle/fluid/operators/jit/kernel_pool.cc b/paddle/fluid/operators/jit/kernel_pool.cc index bc98c644fbe..8eef8e47474 100644 --- a/paddle/fluid/operators/jit/kernel_pool.cc +++ b/paddle/fluid/operators/jit/kernel_pool.cc @@ -21,6 +21,12 @@ namespace paddle { namespace operators { namespace jit { +std::unordered_map>& GetJITCodesMap() { + static thread_local std::unordered_map> + g_jit_codes_map; + return g_jit_codes_map; +} + JitCodeCreatorPool& JitCodeCreatorPool::Instance() { static JitCodeCreatorPool g_creator_pool; return g_creator_pool; diff --git a/paddle/fluid/operators/jit/kernel_pool.h b/paddle/fluid/operators/jit/kernel_pool.h index 04710a54ac9..548c8704126 100644 --- a/paddle/fluid/operators/jit/kernel_pool.h +++ b/paddle/fluid/operators/jit/kernel_pool.h @@ -28,6 +28,8 @@ namespace paddle { namespace operators { namespace jit { +extern std::unordered_map>& GetJITCodesMap(); + template class JitCodePool { typedef std::unique_ptr GenBasePtr; @@ -36,8 +38,16 @@ class JitCodePool { public: JitCodePool() = default; static JitCodePool& Instance() { - static thread_local JitCodePool g_jit_codes; - return g_jit_codes; + auto& jit_codes_map = GetJITCodesMap(); + std::string key = typeid(JitCodePool).name(); + auto iter = jit_codes_map.find(key); + if (iter != jit_codes_map.end()) { + return *(JitCodePool*)(iter->second.get()); + } else { + std::shared_ptr cache = std::make_shared>(); + jit_codes_map.emplace(key, cache); + return *(JitCodePool*)(cache.get()); + } } const JitCodeMap& AllKernels() { return codes_; } -- GitLab