提交 eec9c9cb 编写于 作者: Y Yihua Xu 提交者: whs

Fix jit tls issue (#21151)

上级 a9d4eed3
......@@ -22,6 +22,12 @@ namespace paddle {
namespace operators {
namespace jit {
std::unordered_map<std::string, std::shared_ptr<void>>& GetFuncCacheMap() {
static thread_local std::unordered_map<std::string, std::shared_ptr<void>>
g_func_cache_map;
return g_func_cache_map;
}
#define ONE_CASE(key) \
case key: \
return #key
......
......@@ -15,6 +15,7 @@
#pragma once
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility> // for std::move
......@@ -175,13 +176,25 @@ typename KernelTuple::func_type GetDefaultBestFunc(
return funcs[0];
}
extern std::unordered_map<std::string, std::shared_ptr<void>>&
GetFuncCacheMap();
template <typename KernelTuple, typename PlaceType>
class KernelFuncs {
public:
KernelFuncs() = default;
static KernelFuncs& Cache() {
static thread_local KernelFuncs<KernelTuple, PlaceType> g_func_cache;
return g_func_cache;
auto& func_cache_map = GetFuncCacheMap();
std::string key = typeid(KernelFuncs<KernelTuple, PlaceType>).name();
auto iter = func_cache_map.find(key);
if (iter != func_cache_map.end()) {
return *(KernelFuncs<KernelTuple, PlaceType>*)(iter->second.get());
} else {
std::shared_ptr<void> cache =
std::make_shared<KernelFuncs<KernelTuple, PlaceType>>();
func_cache_map.emplace(key, cache);
return *(KernelFuncs<KernelTuple, PlaceType>*)(cache.get());
}
}
// the exposed interface to use
......
......@@ -21,6 +21,12 @@ namespace paddle {
namespace operators {
namespace jit {
std::unordered_map<std::string, std::shared_ptr<void>>& GetJITCodesMap() {
static thread_local std::unordered_map<std::string, std::shared_ptr<void>>
g_jit_codes_map;
return g_jit_codes_map;
}
JitCodeCreatorPool& JitCodeCreatorPool::Instance() {
static JitCodeCreatorPool g_creator_pool;
return g_creator_pool;
......
......@@ -28,6 +28,8 @@ namespace paddle {
namespace operators {
namespace jit {
extern std::unordered_map<std::string, std::shared_ptr<void>>& GetJITCodesMap();
template <KernelType KT>
class JitCodePool {
typedef std::unique_ptr<GenBase> GenBasePtr;
......@@ -36,8 +38,16 @@ class JitCodePool {
public:
JitCodePool() = default;
static JitCodePool& Instance() {
static thread_local JitCodePool<KT> g_jit_codes;
return g_jit_codes;
auto& jit_codes_map = GetJITCodesMap();
std::string key = typeid(JitCodePool<KT>).name();
auto iter = jit_codes_map.find(key);
if (iter != jit_codes_map.end()) {
return *(JitCodePool<KT>*)(iter->second.get());
} else {
std::shared_ptr<void> cache = std::make_shared<JitCodePool<KT>>();
jit_codes_map.emplace(key, cache);
return *(JitCodePool<KT>*)(cache.get());
}
}
const JitCodeMap& AllKernels() { return codes_; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册