提交 e53f517a 编写于 作者: P pawelpiotrowicz 提交者: Tao Luo

fix for multithreading test_analyzer_image_classification --num_threads=X (#18265)

test=develop
上级 65d98752
...@@ -77,11 +77,6 @@ framework::Variable* NgraphEngine::pre_var_ptr = nullptr; ...@@ -77,11 +77,6 @@ framework::Variable* NgraphEngine::pre_var_ptr = nullptr;
const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr; const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr;
bool NgraphEngine::is_training = false; bool NgraphEngine::is_training = false;
std::unordered_map<std::string, EngineCache> NgraphEngine::engine_cache = {};
std::unordered_map<std::string,
std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>
NgraphEngine::t_in_cache_ = {};
std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ = std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
ngraph::runtime::Backend::create("CPU"); ngraph::runtime::Backend::create("CPU");
...@@ -453,6 +448,9 @@ std::shared_ptr<ngraph::Function> NgraphEngine::BuildNgFunction( ...@@ -453,6 +448,9 @@ std::shared_ptr<ngraph::Function> NgraphEngine::BuildNgFunction(
} }
void NgraphEngine::ClearNgCache() { void NgraphEngine::ClearNgCache() {
auto& engine_cache = main_engine_cache::fetch();
auto& t_in_cache_ = main_t_in_cache::fetch();
auto it = engine_cache.begin(); auto it = engine_cache.begin();
while (it != engine_cache.end()) { while (it != engine_cache.end()) {
auto ng_engine = it->second; auto ng_engine = it->second;
...@@ -494,6 +492,8 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) { ...@@ -494,6 +492,8 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
std::to_string(interval[1]) + engine_key; std::to_string(interval[1]) + engine_key;
func_cache_key_ = std::to_string(std::hash<std::string>()(func_cache_key_)); func_cache_key_ = std::to_string(std::hash<std::string>()(func_cache_key_));
auto& engine_cache = main_engine_cache::fetch();
if (engine_cache.find(func_cache_key_) != engine_cache.end()) { if (engine_cache.find(func_cache_key_) != engine_cache.end()) {
if (engine_cache[func_cache_key_].persistables.size() == 0) { if (engine_cache[func_cache_key_].persistables.size() == 0) {
ClearNgCache(); ClearNgCache();
...@@ -533,6 +533,9 @@ void NgraphEngine::Run(const framework::Scope& scope, ...@@ -533,6 +533,9 @@ void NgraphEngine::Run(const framework::Scope& scope,
const std::vector<std::string>* p_var_out; const std::vector<std::string>* p_var_out;
bool is_test; bool is_test;
auto& engine_cache = main_engine_cache::fetch();
auto& t_in_cache_ = main_t_in_cache::fetch();
PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(), PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(),
"Cannot find cached data to run ngraph function"); "Cannot find cached data to run ngraph function");
ng_handle = engine_cache[func_cache_key_].ngraph_handle; ng_handle = engine_cache[func_cache_key_].ngraph_handle;
......
...@@ -14,11 +14,13 @@ limitations under the License. */ ...@@ -14,11 +14,13 @@ limitations under the License. */
#pragma once #pragma once
#include <list>
#include <memory> #include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
...@@ -40,6 +42,82 @@ struct EngineCache { ...@@ -40,6 +42,82 @@ struct EngineCache {
bool is_test = true; bool is_test = true;
}; };
// Thread-local cache: each thread that calls fetch() receives its own T
// instance. Entries are registered in a shared, mutex-protected list (so the
// engine can inspect/clear them across threads) and unregistered automatically
// when the owning thread exits. `Engine` supplies the shared mutex via a
// static getMutex(); `separator` only distinguishes otherwise-identical
// instantiations so they get independent storage.
template <class T, class Engine, int separator = 0>
class NgraphThreadCache {
 public:
  typedef decltype(Engine::getMutex()) mutex_type;
  typedef std::lock_guard<mutex_type> guard_type;
  typedef T& ref_type;
  enum class type_of_thread { unknown, forward, backward };

  // Per-thread payload plus bookkeeping about the owning thread.
  template <class S>
  struct MetaInfo {
    std::thread::id owner_tid;   // owner of the cache, future use;
    type_of_thread worker_type;  // future use
    S real_content;
    MetaInfo()
        : owner_tid{std::this_thread::get_id()},
          worker_type{type_of_thread::unknown} {}
  };

  typedef std::unique_ptr<MetaInfo<T>> content_type;
  typedef std::list<content_type> storage_type;

 protected:
  // Shared registry of every thread's cache entry; guarded by getMutex().
  static storage_type l;

  static mutex_type getMutex() { return Engine::getMutex(); }

  // Remove the registry entry whose payload is `raw_ptr`; invoked from the
  // TLS destructor when a thread dies.
  static void remove_from_list(const T* raw_ptr) {
    guard_type guard(getMutex());
    l.remove_if([raw_ptr](const content_type& sh) {
      return &(sh->real_content) == raw_ptr;
    });
  }

  // Thread-local handle: holds a non-owning pointer into the registry and
  // unregisters it when the thread exits.
  template <class TRaw>
  struct TLSDescriptor {
    TRaw* raw_ptr;
    TLSDescriptor() : raw_ptr{nullptr} {}
    ~TLSDescriptor() {
      // if thread die
      NgraphThreadCache::remove_from_list(raw_ptr);
      /* TODO : Parallel executor swap */
      // FastMultiThreadCache::keep_alive_for_backward_thread(raw_ptr);
    }
  };

 public:
  NgraphThreadCache() = delete;
  NgraphThreadCache(const NgraphThreadCache& copy) = delete;

  // Return the calling thread's cache instance, creating and registering it
  // on first use. The reference stays valid for the thread's lifetime.
  static T& fetch() {
    thread_local TLSDescriptor<T> tls;
    if (!tls.raw_ptr) {
      using elem_type = typename content_type::element_type;
      // Plain `new` throws std::bad_alloc on failure and never yields null,
      // so no null check is needed here (the old `if (!_p)` was dead code).
      content_type _p(new elem_type());
      guard_type guard(getMutex());
      l.push_back(std::move(_p));
      tls.raw_ptr = &l.back()->real_content;
    }
    return *(tls.raw_ptr);
  }

  // Number of live per-thread entries. Static because the class is not
  // instantiable (all constructors deleted): as a non-static member this
  // could never have been called.
  static auto getSize() -> decltype(l.size()) {
    guard_type guard(getMutex());
    return l.size();
  }

  // Apply `f` to every registry entry while holding the lock (static for the
  // same reason as getSize()).
  template <class F>
  static void for_each_cache(F f) {
    guard_type guard(getMutex());
    std::for_each(l.begin(), l.end(), f);
  }
};

template <class T, class Engine, int separator>
typename NgraphThreadCache<T, Engine, separator>::storage_type
    NgraphThreadCache<T, Engine, separator>::l;
// perform graph build through bridge and execute computation // perform graph build through bridge and execute computation
class NgraphEngine { class NgraphEngine {
public: public:
...@@ -57,11 +135,20 @@ class NgraphEngine { ...@@ -57,11 +135,20 @@ class NgraphEngine {
const framework::BlockDesc& prog, const framework::BlockDesc& prog,
std::vector<std::unique_ptr<framework::OperatorBase>>* ops); std::vector<std::unique_ptr<framework::OperatorBase>>* ops);
static std::recursive_mutex& getMutex() {
static std::recursive_mutex mx;
return mx;
}
private: private:
static std::unordered_map<std::string, EngineCache> engine_cache; template <class T>
static std::unordered_map< using ThCache =
std::string, std::vector<std::shared_ptr<ngraph::runtime::Tensor>>> NgraphThreadCache<std::unordered_map<std::string, T>, NgraphEngine>;
t_in_cache_;
using main_engine_cache = ThCache<EngineCache>;
using main_t_in_cache =
ThCache<std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>;
static framework::Variable* pre_var_ptr; static framework::Variable* pre_var_ptr;
const framework::Scope& scope_; const framework::Scope& scope_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册