提交 e53f517a 编写于 作者: P pawelpiotrowicz 提交者: Tao Luo

fix for multithreading test_analyzer_image_classification --num_threads=X (#18265)

test=develop
上级 65d98752
......@@ -77,11 +77,6 @@ framework::Variable* NgraphEngine::pre_var_ptr = nullptr;
const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr;
bool NgraphEngine::is_training = false;
std::unordered_map<std::string, EngineCache> NgraphEngine::engine_cache = {};
std::unordered_map<std::string,
std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>
NgraphEngine::t_in_cache_ = {};
std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
ngraph::runtime::Backend::create("CPU");
......@@ -453,6 +448,9 @@ std::shared_ptr<ngraph::Function> NgraphEngine::BuildNgFunction(
}
void NgraphEngine::ClearNgCache() {
auto& engine_cache = main_engine_cache::fetch();
auto& t_in_cache_ = main_t_in_cache::fetch();
auto it = engine_cache.begin();
while (it != engine_cache.end()) {
auto ng_engine = it->second;
......@@ -494,6 +492,8 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
std::to_string(interval[1]) + engine_key;
func_cache_key_ = std::to_string(std::hash<std::string>()(func_cache_key_));
auto& engine_cache = main_engine_cache::fetch();
if (engine_cache.find(func_cache_key_) != engine_cache.end()) {
if (engine_cache[func_cache_key_].persistables.size() == 0) {
ClearNgCache();
......@@ -533,6 +533,9 @@ void NgraphEngine::Run(const framework::Scope& scope,
const std::vector<std::string>* p_var_out;
bool is_test;
auto& engine_cache = main_engine_cache::fetch();
auto& t_in_cache_ = main_t_in_cache::fetch();
PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(),
"Cannot find cached data to run ngraph function");
ng_handle = engine_cache[func_cache_key_].ngraph_handle;
......
......@@ -14,11 +14,13 @@ limitations under the License. */
#pragma once
#include <list>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/operator.h"
......@@ -40,6 +42,82 @@ struct EngineCache {
bool is_test = true;
};
template <class T, class Engine, int separator = 0>
class NgraphThreadCache {
public:
typedef decltype(Engine::getMutex()) mutex_type;
typedef std::lock_guard<mutex_type> guard_type;
typedef T& ref_type;
enum class type_of_thread { unknown, forward, backward };
template <class S>
struct MetaInfo {
std::thread::id owner_tid; // owner of the cache, future use;
type_of_thread worker_type; // future use
S real_content;
MetaInfo()
: owner_tid{std::this_thread::get_id()},
worker_type{type_of_thread::unknown} {}
};
typedef std::unique_ptr<MetaInfo<T>> content_type;
typedef std::list<content_type> storage_type;
protected:
static storage_type l;
static mutex_type getMutex() { return Engine::getMutex(); }
static void remove_from_list(const T* raw_ptr) {
guard_type guard(getMutex());
l.remove_if([raw_ptr](const content_type& sh) {
return &(sh->real_content) == raw_ptr;
});
}
template <class TRaw>
struct TLSDescriptor {
TRaw* raw_ptr;
TLSDescriptor() : raw_ptr{nullptr} {}
~TLSDescriptor() {
// if thread die
NgraphThreadCache::remove_from_list(raw_ptr);
/* TODO : Parallel executor swap */
// FastMultiThreadCache::keep_alive_for_backward_thread(raw_ptr);
}
};
public:
NgraphThreadCache() = delete;
NgraphThreadCache(const NgraphThreadCache& copy) = delete;
static T& fetch() {
thread_local TLSDescriptor<T> tls;
if (!tls.raw_ptr) {
using elem_type = typename content_type::element_type;
content_type _p(new elem_type());
if (!_p) PADDLE_THROW("Cannot alloc memory for thread-cache ");
guard_type guard(getMutex());
l.push_back(std::move(_p));
tls.raw_ptr = &l.back()->real_content;
}
return *(tls.raw_ptr);
}
auto getSize() -> decltype(l.size()) {
guard_type guard(getMutex());
return l.size();
}
template <class F>
void for_each_cache(F f) {
guard_type guard(getMutex());
std::for_each(l.begin(), l.end(), f);
}
};
template <class T, class Engine, int separator>
typename NgraphThreadCache<T, Engine, separator>::storage_type
NgraphThreadCache<T, Engine, separator>::l;
// perform graph build through bridge and execute computation
class NgraphEngine {
public:
......@@ -57,11 +135,20 @@ class NgraphEngine {
const framework::BlockDesc& prog,
std::vector<std::unique_ptr<framework::OperatorBase>>* ops);
static std::recursive_mutex& getMutex() {
static std::recursive_mutex mx;
return mx;
}
private:
static std::unordered_map<std::string, EngineCache> engine_cache;
static std::unordered_map<
std::string, std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>
t_in_cache_;
template <class T>
using ThCache =
NgraphThreadCache<std::unordered_map<std::string, T>, NgraphEngine>;
using main_engine_cache = ThCache<EngineCache>;
using main_t_in_cache =
ThCache<std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>;
static framework::Variable* pre_var_ptr;
const framework::Scope& scope_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册