diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.cc b/paddle/fluid/operators/ngraph/ngraph_engine.cc
index 7d78c61739a9d1ba4577079ea48ea4d3467f3fd8..720273a6de02c3f7955c919c5227d3ce46ff6b70 100644
--- a/paddle/fluid/operators/ngraph/ngraph_engine.cc
+++ b/paddle/fluid/operators/ngraph/ngraph_engine.cc
@@ -77,11 +77,6 @@ framework::Variable* NgraphEngine::pre_var_ptr = nullptr;
 const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr;
 bool NgraphEngine::is_training = false;
 
-std::unordered_map<std::string, EngineCache> NgraphEngine::engine_cache = {};
-std::unordered_map<std::string,
-                   std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>
-    NgraphEngine::t_in_cache_ = {};
-
 std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
     ngraph::runtime::Backend::create("CPU");
 
@@ -453,6 +448,9 @@ std::shared_ptr<ngraph::Function> NgraphEngine::BuildNgFunction(
 }
 
 void NgraphEngine::ClearNgCache() {
+  auto& engine_cache = main_engine_cache::fetch();
+  auto& t_in_cache_ = main_t_in_cache::fetch();
+
   auto it = engine_cache.begin();
   while (it != engine_cache.end()) {
     auto ng_engine = it->second;
@@ -494,6 +492,8 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
                      std::to_string(interval[1]) + engine_key;
   func_cache_key_ = std::to_string(std::hash<std::string>()(func_cache_key_));
 
+  auto& engine_cache = main_engine_cache::fetch();
+
   if (engine_cache.find(func_cache_key_) != engine_cache.end()) {
     if (engine_cache[func_cache_key_].persistables.size() == 0) {
       ClearNgCache();
@@ -533,6 +533,9 @@ void NgraphEngine::Run(const framework::Scope& scope,
   const std::vector<std::string>* p_var_out;
   bool is_test;
 
+  auto& engine_cache = main_engine_cache::fetch();
+  auto& t_in_cache_ = main_t_in_cache::fetch();
+
   PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(),
                  "Cannot find cached data to run ngraph function");
   ng_handle = engine_cache[func_cache_key_].ngraph_handle;
diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.h b/paddle/fluid/operators/ngraph/ngraph_engine.h
index 7fa443a5d49b17d116895bdd3227561fb3f8515a..c60a5ad4eee5f0d886f8f919f97f453032a9a9b3 100644
--- a/paddle/fluid/operators/ngraph/ngraph_engine.h
+++ b/paddle/fluid/operators/ngraph/ngraph_engine.h
@@ -14,11 +14,13 @@ limitations under the License. */
 
 #pragma once
 
+#include <list>
 #include <memory>
 #include <set>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
+#include <utility>
 #include <vector>
 
 #include "paddle/fluid/framework/operator.h"
@@ -40,6 +42,82 @@ struct EngineCache {
   bool is_test = true;
 };
 
+template <typename T, typename Engine>
+class NgraphThreadCache {
+ public:
+  typedef decltype(Engine::getMutex()) mutex_type;
+  typedef std::lock_guard<mutex_type> guard_type;
+  typedef T& ref_type;
+  enum class type_of_thread { unknown, forward, backward };
+
+  template <typename S>
+  struct MetaInfo {
+    std::thread::id owner_tid;   // owner of the cache, future use;
+    type_of_thread worker_type;  // future use
+    S real_content;
+    MetaInfo()
+        : owner_tid{std::this_thread::get_id()},
+          worker_type{type_of_thread::unknown} {}
+  };
+
+  typedef std::unique_ptr<MetaInfo<T>> content_type;
+  typedef std::list<content_type> storage_type;
+
+ protected:
+  static storage_type l;
+  static mutex_type getMutex() { return Engine::getMutex(); }
+  static void remove_from_list(const T* raw_ptr) {
+    guard_type guard(getMutex());
+    l.remove_if([raw_ptr](const content_type& sh) {
+      return &(sh->real_content) == raw_ptr;
+    });
+  }
+
+  template <typename TRaw>
+  struct TLSDescriptor {
+    TRaw* raw_ptr;
+    TLSDescriptor() : raw_ptr{nullptr} {}
+    ~TLSDescriptor() {
+      // if thread die
+      NgraphThreadCache::remove_from_list(raw_ptr);
+
+      /* TODO : Parallel executor swap */
+      // FastMultiThreadCache::keep_alive_for_backward_thread(raw_ptr);
+    }
+  };
+
+ public:
+  NgraphThreadCache() = delete;
+  NgraphThreadCache(const NgraphThreadCache& copy) = delete;
+
+  static T& fetch() {
+    thread_local TLSDescriptor<T> tls;
+    if (!tls.raw_ptr) {
+      using elem_type = typename content_type::element_type;
+      content_type _p(new elem_type());
+      if (!_p) PADDLE_THROW("Cannot alloc memory for thread-cache ");
+      guard_type guard(getMutex());
+      l.push_back(std::move(_p));
+      tls.raw_ptr = &l.back()->real_content;
+    }
+    return *(tls.raw_ptr);
+  }
+  auto getSize() -> decltype(l.size()) {
+    guard_type guard(getMutex());
+    return l.size();
+  }
+
+  template <typename F>
+  void for_each_cache(F f) {
+    guard_type guard(getMutex());
+    std::for_each(l.begin(), l.end(), f);
+  }
+};
+
+template <typename T, typename Engine>
+typename NgraphThreadCache<T, Engine>::storage_type
+    NgraphThreadCache<T, Engine>::l;
+
 // perform graph build through bridge and execute computation
 class NgraphEngine {
  public:
@@ -57,11 +135,20 @@ class NgraphEngine {
       const framework::BlockDesc& prog,
       std::vector<std::unique_ptr<framework::OperatorBase>>* ops);
 
+  static std::recursive_mutex& getMutex() {
+    static std::recursive_mutex mx;
+    return mx;
+  }
+
  private:
-  static std::unordered_map<std::string, EngineCache> engine_cache;
-  static std::unordered_map<
-      std::string, std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>
-      t_in_cache_;
+  template <typename T>
+  using ThCache =
+      NgraphThreadCache<std::unordered_map<std::string, T>, NgraphEngine>;
+
+  using main_engine_cache = ThCache<EngineCache>;
+  using main_t_in_cache =
+      ThCache<std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>;
+
   static framework::Variable* pre_var_ptr;
 
   const framework::Scope& scope_;
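
Note on the pattern: the patch swaps the two shared static maps for per-thread copies handed out by `NgraphThreadCache::fetch()`, with `NgraphEngine::getMutex()` guarding only the bookkeeping list of live caches. Below is a minimal standalone sketch of that same thread-local-cache idea, assuming a simplified mutex provider; `DemoEngine`, `ThreadCache`, `DemoCache`, and the worker threads are illustrative stand-ins, not code from the patch.

```cpp
#include <iostream>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>

// Stand-in for NgraphEngine as the provider of the shared recursive mutex.
struct DemoEngine {
  static std::recursive_mutex& getMutex() {
    static std::recursive_mutex mx;
    return mx;
  }
};

template <typename T, typename Engine>
class ThreadCache {
 public:
  // Each thread gets its own T; a global list keeps all per-thread
  // instances alive and enumerable under the shared mutex.
  static T& fetch() {
    thread_local Holder tls;
    if (!tls.raw_ptr) {
      std::lock_guard<std::recursive_mutex> guard(Engine::getMutex());
      storage().push_back(std::unique_ptr<T>(new T()));
      tls.raw_ptr = storage().back().get();
    }
    return *tls.raw_ptr;
  }

 private:
  // Destructor runs at thread exit and unregisters that thread's entry,
  // mirroring TLSDescriptor in the patch.
  struct Holder {
    T* raw_ptr{nullptr};
    ~Holder() {
      if (!raw_ptr) return;
      std::lock_guard<std::recursive_mutex> guard(Engine::getMutex());
      storage().remove_if(
          [this](const std::unique_ptr<T>& p) { return p.get() == raw_ptr; });
    }
  };
  // Function-local static sidesteps the out-of-line member definition
  // that the patch needs for NgraphThreadCache::l.
  static std::list<std::unique_ptr<T>>& storage() {
    static std::list<std::unique_ptr<T>> l;
    return l;
  }
};

using DemoCache =
    ThreadCache<std::unordered_map<std::string, int>, DemoEngine>;

int main() {
  // Each worker mutates its own private map, so there is no data race
  // and no locking on the hot path after a thread's first fetch().
  auto worker = [](int id) { DemoCache::fetch()["hits"] = id; };
  std::thread t1(worker, 1);
  std::thread t2(worker, 2);
  t1.join();
  t2.join();
  std::cout << DemoCache::fetch()["hits"] << std::endl;  // main's map: 0
  return 0;
}
```

The shared recursive mutex is taken only to register or unregister a thread's entry; once `fetch()` has run on a thread, that thread reads and writes its map without synchronization, which is the property the patch relies on so concurrent executors no longer race on `engine_cache` / `t_in_cache_`.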