提交 eec9c9cb 编写于 作者: Y Yihua Xu 提交者: whs

Fix jit tls issue (#21151)

上级 a9d4eed3
...@@ -22,6 +22,12 @@ namespace paddle { ...@@ -22,6 +22,12 @@ namespace paddle {
namespace operators { namespace operators {
namespace jit { namespace jit {
std::unordered_map<std::string, std::shared_ptr<void>>& GetFuncCacheMap() {
static thread_local std::unordered_map<std::string, std::shared_ptr<void>>
g_func_cache_map;
return g_func_cache_map;
}
#define ONE_CASE(key) \ #define ONE_CASE(key) \
case key: \ case key: \
return #key return #key
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <utility> // for std::move #include <utility> // for std::move
...@@ -175,13 +176,25 @@ typename KernelTuple::func_type GetDefaultBestFunc( ...@@ -175,13 +176,25 @@ typename KernelTuple::func_type GetDefaultBestFunc(
return funcs[0]; return funcs[0];
} }
extern std::unordered_map<std::string, std::shared_ptr<void>>&
GetFuncCacheMap();
template <typename KernelTuple, typename PlaceType> template <typename KernelTuple, typename PlaceType>
class KernelFuncs { class KernelFuncs {
public: public:
KernelFuncs() = default; KernelFuncs() = default;
static KernelFuncs& Cache() { static KernelFuncs& Cache() {
static thread_local KernelFuncs<KernelTuple, PlaceType> g_func_cache; auto& func_cache_map = GetFuncCacheMap();
return g_func_cache; std::string key = typeid(KernelFuncs<KernelTuple, PlaceType>).name();
auto iter = func_cache_map.find(key);
if (iter != func_cache_map.end()) {
return *(KernelFuncs<KernelTuple, PlaceType>*)(iter->second.get());
} else {
std::shared_ptr<void> cache =
std::make_shared<KernelFuncs<KernelTuple, PlaceType>>();
func_cache_map.emplace(key, cache);
return *(KernelFuncs<KernelTuple, PlaceType>*)(cache.get());
}
} }
// the exposed interface to use // the exposed interface to use
......
...@@ -21,6 +21,12 @@ namespace paddle { ...@@ -21,6 +21,12 @@ namespace paddle {
namespace operators { namespace operators {
namespace jit { namespace jit {
std::unordered_map<std::string, std::shared_ptr<void>>& GetJITCodesMap() {
static thread_local std::unordered_map<std::string, std::shared_ptr<void>>
g_jit_codes_map;
return g_jit_codes_map;
}
JitCodeCreatorPool& JitCodeCreatorPool::Instance() { JitCodeCreatorPool& JitCodeCreatorPool::Instance() {
static JitCodeCreatorPool g_creator_pool; static JitCodeCreatorPool g_creator_pool;
return g_creator_pool; return g_creator_pool;
......
...@@ -28,6 +28,8 @@ namespace paddle { ...@@ -28,6 +28,8 @@ namespace paddle {
namespace operators { namespace operators {
namespace jit { namespace jit {
extern std::unordered_map<std::string, std::shared_ptr<void>>& GetJITCodesMap();
template <KernelType KT> template <KernelType KT>
class JitCodePool { class JitCodePool {
typedef std::unique_ptr<GenBase> GenBasePtr; typedef std::unique_ptr<GenBase> GenBasePtr;
...@@ -36,8 +38,16 @@ class JitCodePool { ...@@ -36,8 +38,16 @@ class JitCodePool {
public: public:
JitCodePool() = default; JitCodePool() = default;
static JitCodePool& Instance() { static JitCodePool& Instance() {
static thread_local JitCodePool<KT> g_jit_codes; auto& jit_codes_map = GetJITCodesMap();
return g_jit_codes; std::string key = typeid(JitCodePool<KT>).name();
auto iter = jit_codes_map.find(key);
if (iter != jit_codes_map.end()) {
return *(JitCodePool<KT>*)(iter->second.get());
} else {
std::shared_ptr<void> cache = std::make_shared<JitCodePool<KT>>();
jit_codes_map.emplace(key, cache);
return *(JitCodePool<KT>*)(cache.get());
}
} }
const JitCodeMap& AllKernels() { return codes_; } const JitCodeMap& AllKernels() { return codes_; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册