提交 6421c61a 编写于 作者: B baojun 提交者: Tao Luo

Update ngraph engine for multiple threading (#19155)

* update for multiple threading
test=develop

* remove PADDLE_ENFORCE test=develop
上级 e26411ce
...@@ -72,18 +72,14 @@ static std::map<ngraph::element::Type, framework::proto::VarType::Type> ...@@ -72,18 +72,14 @@ static std::map<ngraph::element::Type, framework::proto::VarType::Type>
{ngraph::element::boolean, framework::proto::VarType::BOOL}}; {ngraph::element::boolean, framework::proto::VarType::BOOL}};
std::vector<std::string> NgraphEngine::feed_vars = {}; std::vector<std::string> NgraphEngine::feed_vars = {};
std::vector<std::string> NgraphEngine::fetch_vars = {};
framework::Variable* NgraphEngine::pre_var_ptr = nullptr;
const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr;
bool NgraphEngine::is_training = false;
std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ = std::weak_ptr<ngraph::runtime::Backend> NgraphEngine::wp_backend_;
ngraph::runtime::Backend::create("CPU");
std::mutex NgraphEngine::ng_mutex_;
static std::vector<std::vector<int>> NgraphOpIntervals( static std::vector<std::vector<int>> NgraphOpIntervals(
std::vector<std::unique_ptr<framework::OperatorBase>>* ops) { std::vector<std::unique_ptr<framework::OperatorBase>>* ops) {
NgraphEngine::feed_vars.clear(); NgraphEngine::feed_vars.clear();
NgraphEngine::fetch_vars.clear();
std::vector<std::vector<int>> intervals; std::vector<std::vector<int>> intervals;
int size = ops->size(); int size = ops->size();
...@@ -118,11 +114,6 @@ static std::vector<std::vector<int>> NgraphOpIntervals( ...@@ -118,11 +114,6 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
int index = right; int index = right;
while (index < size && ops->at(index)->Type() == framework::kFetchOpType) { while (index < size && ops->at(index)->Type() == framework::kFetchOpType) {
for (auto& var_name_item : ops->at(index)->Inputs()) {
for (auto& var_name : var_name_item.second) {
NgraphEngine::fetch_vars.emplace_back(var_name);
}
}
++index; ++index;
} }
...@@ -167,16 +158,22 @@ static void SubstituteNgraphOp( ...@@ -167,16 +158,22 @@ static void SubstituteNgraphOp(
framework::OpRegistry::CreateOp(ng_op_desc)); framework::OpRegistry::CreateOp(ng_op_desc));
} }
std::string SerializedBlock(const std::vector<framework::OpDesc*>& op_descs) { std::string SerializedBlock(const framework::BlockDesc& bdesc) {
framework::proto::BlockDesc block_proto; framework::proto::BlockDesc block_proto;
framework::BlockDesc block_desc(nullptr, &block_proto); framework::BlockDesc block_desc(nullptr, &block_proto);
block_desc.Proto()->set_parent_idx(-1); block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0); block_desc.Proto()->set_idx(0);
for (auto* op_desc : op_descs) { for (auto& op_desc : bdesc.AllOps()) {
auto* op = block_desc.AppendOp(); auto* op = block_desc.AppendOp();
*op->Proto() = *op_desc->Proto(); *op->Proto() = *op_desc->Proto();
} }
auto* vars = block_desc.Proto()->mutable_vars();
for (auto& var_desc : bdesc.AllVars()) {
*vars->Add() = *var_desc->Proto();
}
return block_desc.Proto()->SerializeAsString(); return block_desc.Proto()->SerializeAsString();
} }
...@@ -213,12 +210,12 @@ std::string GenerateEngineKey(const std::vector<std::string>& engine_inputs, ...@@ -213,12 +210,12 @@ std::string GenerateEngineKey(const std::vector<std::string>& engine_inputs,
void NgraphEngine::FuseNgraphOps( void NgraphEngine::FuseNgraphOps(
const framework::BlockDesc& block_desc, const framework::BlockDesc& block_desc,
std::vector<std::unique_ptr<framework::OperatorBase>>* ops) { std::vector<std::unique_ptr<framework::OperatorBase>>* ops) {
NgraphEngine::p_bdesc = &block_desc;
auto intervals = NgraphOpIntervals(ops); auto intervals = NgraphOpIntervals(ops);
std::string serialized_block = SerializedBlock(block_desc);
std::string engine_key = std::string engine_key =
GenerateEngineKey(feed_vars, fetch_vars, ops->size()); std::to_string(std::hash<std::string>()(serialized_block));
for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) { for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) {
SubstituteNgraphOp(ops, engine_key, "", *it); SubstituteNgraphOp(ops, engine_key, serialized_block, *it);
} }
} }
...@@ -232,6 +229,20 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope, ...@@ -232,6 +229,20 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope,
var_node_map_ = std::make_shared< var_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>(); std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
std::lock_guard<std::mutex> lock(ng_mutex_);
if (!wp_backend_.lock()) {
try {
VLOG(3) << "ngraph creating CPU backend.";
backend_ = ngraph::runtime::Backend::create("CPU");
} catch (...) {
PADDLE_THROW("Unsupported nGraph backend");
}
wp_backend_ = backend_;
} else {
backend_ = wp_backend_.lock();
}
GetNgFunction(ctx); GetNgFunction(ctx);
} }
...@@ -239,24 +250,11 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) { ...@@ -239,24 +250,11 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
auto interval = ctx.Attr<std::vector<int>>("interval"); auto interval = ctx.Attr<std::vector<int>>("interval");
std::string serialized_graph = ctx.Attr<std::string>("graph"); std::string serialized_graph = ctx.Attr<std::string>("graph");
auto input_vars = ctx.Inputs("Xs");
if (!input_vars.empty()) {
feed_vars = input_vars;
var_in_ = input_vars;
}
auto output_vars = ctx.Outputs("Ys");
if (!output_vars.empty()) {
var_out_ = output_vars;
}
framework::proto::BlockDesc block_proto; framework::proto::BlockDesc block_proto;
if (!serialized_graph.empty()) block_proto.ParseFromString(serialized_graph); if (!serialized_graph.empty()) block_proto.ParseFromString(serialized_graph);
framework::BlockDesc block_desc(nullptr, &block_proto); framework::BlockDesc block_desc(nullptr, &block_proto);
if (!serialized_graph.empty()) {
NgraphEngine::p_bdesc = &block_desc;
}
for (auto& var : p_bdesc->AllVars()) { for (auto& var : block_desc.AllVars()) {
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS || if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
var->GetType() == framework::proto::VarType::LOD_TENSOR || var->GetType() == framework::proto::VarType::LOD_TENSOR ||
var->GetType() == framework::proto::VarType::LOD_TENSOR_ARRAY)) { var->GetType() == framework::proto::VarType::LOD_TENSOR_ARRAY)) {
...@@ -284,10 +282,9 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) { ...@@ -284,10 +282,9 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
} }
std::vector<paddle::framework::OpDesc*> ops_desc; std::vector<paddle::framework::OpDesc*> ops_desc;
for (auto op_desc : p_bdesc->AllOps()) { for (auto op_desc : block_desc.AllOps()) {
ops_desc.emplace_back(op_desc); ops_desc.emplace_back(op_desc);
if (op_desc->Type().find("_grad") != std::string::npos) { if (op_desc->Type().find("_grad") != std::string::npos) {
is_training = true;
this->is_test_ = false; this->is_test_ = false;
} }
} }
...@@ -298,8 +295,7 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) { ...@@ -298,8 +295,7 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
framework::OpRegistry::CreateOp(*(ops_desc[idx]))); framework::OpRegistry::CreateOp(*(ops_desc[idx])));
++idx; ++idx;
} }
while (idx < static_cast<int>(ops_desc.size()) && while (idx < static_cast<int>(ops_desc.size())) {
ops_desc.at(idx)->Type() != framework::kFetchOpType) {
auto op_desc = ops_desc.at(idx); auto op_desc = ops_desc.at(idx);
for (auto& var_name_item : op_desc->Inputs()) { for (auto& var_name_item : op_desc->Inputs()) {
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
...@@ -309,9 +305,21 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) { ...@@ -309,9 +305,21 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
++idx; ++idx;
} }
auto input_vars = ctx.Inputs("Xs");
if (!input_vars.empty()) {
feed_vars = input_vars;
var_in_ = input_vars;
}
auto output_vars = ctx.Outputs("Ys");
if (!output_vars.empty()) {
var_out_ = output_vars;
}
if (var_in_.empty() && var_out_.empty()) { if (var_in_.empty() && var_out_.empty()) {
BuildNgIO(ops_desc, interval); BuildNgIO(ops_desc, interval);
} }
for (size_t i = 0; i < var_in_.size(); ++i) { for (size_t i = 0; i < var_in_.size(); ++i) {
auto var_name = var_in_[i]; auto var_name = var_in_[i];
if (persistables_.find(var_name) == persistables_.end()) { if (persistables_.find(var_name) == persistables_.end()) {
...@@ -324,6 +332,7 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc, ...@@ -324,6 +332,7 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
const std::vector<int>& interval) { const std::vector<int>& interval) {
std::unordered_set<std::string> inputs; std::unordered_set<std::string> inputs;
std::unordered_set<std::string> outputs; std::unordered_set<std::string> outputs;
for (int i = interval[0]; i < interval[1]; ++i) { for (int i = interval[0]; i < interval[1]; ++i) {
auto op = ops_desc[i]; auto op = ops_desc[i];
for (auto& var_name_item : op->Inputs()) { for (auto& var_name_item : op->Inputs()) {
...@@ -359,15 +368,11 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc, ...@@ -359,15 +368,11 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
op->Type()); op->Type());
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
if (this->is_test_) { if (this->is_test_) {
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() || if (post_op_inputs_.find(var_name) != post_op_inputs_.end()) {
find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end()) {
this->var_out_.emplace_back(var_name); this->var_out_.emplace_back(var_name);
} }
} else { } else {
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) != if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
fetch_vars.end() ||
post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
persistables_.find(var_name) != persistables_.end()) { persistables_.find(var_name) != persistables_.end()) {
this->var_out_.emplace_back(var_name); this->var_out_.emplace_back(var_name);
} }
...@@ -434,10 +439,14 @@ std::shared_ptr<ngraph::Function> NgraphEngine::BuildNgFunction( ...@@ -434,10 +439,14 @@ std::shared_ptr<ngraph::Function> NgraphEngine::BuildNgFunction(
ngraph::ParameterVector func_inputs; ngraph::ParameterVector func_inputs;
for (auto& vo : var_out_) { for (auto& vo : var_out_) {
PADDLE_ENFORCE_GT(var_node_map_->count(vo), 0,
"Cannot find vo %s in var_node_map_", vo);
func_outputs.emplace_back(var_node_map_->at(vo)); func_outputs.emplace_back(var_node_map_->at(vo));
} }
for (auto& vi : var_in_) { for (auto& vi : var_in_) {
PADDLE_ENFORCE_GT(var_node_map_->count(vi), 0,
"Cannot find vi %s in var_node_map_", vi);
std::shared_ptr<ngraph::op::Parameter> prm = std::shared_ptr<ngraph::op::Parameter> prm =
std::dynamic_pointer_cast<ngraph::op::Parameter>( std::dynamic_pointer_cast<ngraph::op::Parameter>(
var_in_node_map_->at(vi)); var_in_node_map_->at(vi));
...@@ -454,7 +463,8 @@ void NgraphEngine::ClearNgCache() { ...@@ -454,7 +463,8 @@ void NgraphEngine::ClearNgCache() {
auto it = engine_cache.begin(); auto it = engine_cache.begin();
while (it != engine_cache.end()) { while (it != engine_cache.end()) {
auto ng_engine = it->second; auto ng_engine = it->second;
backend_->remove_compiled_function(ng_engine.ngraph_handle); ng_engine.ngraph_backend->remove_compiled_function(ng_engine.ngraph_handle);
ng_engine.ngraph_backend.reset();
++it; ++it;
} }
engine_cache.clear(); engine_cache.clear();
...@@ -497,13 +507,6 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) { ...@@ -497,13 +507,6 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
if (engine_cache.find(func_cache_key_) != engine_cache.end()) { if (engine_cache.find(func_cache_key_) != engine_cache.end()) {
if (engine_cache[func_cache_key_].persistables.size() == 0) { if (engine_cache[func_cache_key_].persistables.size() == 0) {
ClearNgCache(); ClearNgCache();
} else {
auto var_name = engine_cache[func_cache_key_].persistables.begin();
framework::Variable* var = scope_.FindVar(*var_name);
if (var != pre_var_ptr) {
ClearNgCache();
}
pre_var_ptr = var;
} }
} }
...@@ -515,6 +518,7 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) { ...@@ -515,6 +518,7 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
for (auto& r : func->get_results()) { for (auto& r : func->get_results()) {
r->set_needs_default_layout(true); r->set_needs_default_layout(true);
} }
engine_cache[func_cache_key_].ngraph_backend = backend_;
engine_cache[func_cache_key_].ngraph_handle = backend_->compile(func); engine_cache[func_cache_key_].ngraph_handle = backend_->compile(func);
engine_cache[func_cache_key_].persistables = this->persistables_; engine_cache[func_cache_key_].persistables = this->persistables_;
engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_; engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_;
...@@ -526,31 +530,32 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) { ...@@ -526,31 +530,32 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
void NgraphEngine::Run(const framework::Scope& scope, void NgraphEngine::Run(const framework::Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
VLOG(3) << "NgraphEngine Run ...";
std::shared_ptr<ngraph::runtime::Executable> ng_handle; std::shared_ptr<ngraph::runtime::Executable> ng_handle;
std::shared_ptr<ngraph::runtime::Backend> ng_backend;
const std::set<std::string>* p_persistables; const std::set<std::string>* p_persistables;
const std::vector<size_t>* p_var_in_updates; const std::vector<size_t>* p_var_in_updates;
const std::vector<std::string>* p_var_in; const std::vector<std::string>* p_var_in;
const std::vector<std::string>* p_var_out; const std::vector<std::string>* p_var_out;
bool is_test;
auto& engine_cache = main_engine_cache::fetch(); auto& engine_cache = main_engine_cache::fetch();
auto& t_in_cache_ = main_t_in_cache::fetch(); auto& t_in_cache_ = main_t_in_cache::fetch();
PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(), PADDLE_ENFORCE_GT(engine_cache.count(func_cache_key_), 0,
"Cannot find cached data to run ngraph function"); "Cannot find cached data to run ngraph function");
ng_handle = engine_cache[func_cache_key_].ngraph_handle; ng_handle = engine_cache[func_cache_key_].ngraph_handle;
ng_backend = engine_cache[func_cache_key_].ngraph_backend;
p_persistables = &(engine_cache[func_cache_key_].persistables); p_persistables = &(engine_cache[func_cache_key_].persistables);
p_var_in_updates = &(engine_cache[func_cache_key_].var_in_updates); p_var_in_updates = &(engine_cache[func_cache_key_].var_in_updates);
p_var_in = &(engine_cache[func_cache_key_].var_in); p_var_in = &(engine_cache[func_cache_key_].var_in);
p_var_out = &(engine_cache[func_cache_key_].var_out); p_var_out = &(engine_cache[func_cache_key_].var_out);
is_test = engine_cache[func_cache_key_].is_test;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>>* p_t_in; std::vector<std::shared_ptr<ngraph::runtime::Tensor>>* p_t_in;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in = {}; std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in = {};
auto m_parameters = ng_handle->get_parameters(); auto m_parameters = ng_handle->get_parameters();
auto m_results = ng_handle->get_results(); auto m_results = ng_handle->get_results();
if (is_test && t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) { if (is_inference_ && t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) {
p_t_in = &(t_in_cache_[func_cache_key_]); p_t_in = &(t_in_cache_[func_cache_key_]);
for (size_t i = 0; i < p_var_in_updates->size(); ++i) { for (size_t i = 0; i < p_var_in_updates->size(); ++i) {
int index = p_var_in_updates->at(i); int index = p_var_in_updates->at(i);
...@@ -562,14 +567,14 @@ void NgraphEngine::Run(const framework::Scope& scope, ...@@ -562,14 +567,14 @@ void NgraphEngine::Run(const framework::Scope& scope,
if (var && var->IsType<framework::LoDTensor>()) { if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var); auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]); void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
ti = backend_->create_tensor(ng_type, sp, pd_arr); ti = ng_backend->create_tensor(ng_type, sp, pd_arr);
(*p_t_in)[index] = ti; (*p_t_in)[index] = ti;
} else { } else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vi); PADDLE_THROW("Cannot find var or tensor with var name %s", vi);
} }
} }
} else { } else {
if (is_test) { if (is_inference_) {
p_t_in = &(t_in_cache_[func_cache_key_]); p_t_in = &(t_in_cache_[func_cache_key_]);
} else { } else {
p_t_in = &t_in; p_t_in = &t_in;
...@@ -584,15 +589,13 @@ void NgraphEngine::Run(const framework::Scope& scope, ...@@ -584,15 +589,13 @@ void NgraphEngine::Run(const framework::Scope& scope,
if (var && var->IsType<framework::LoDTensor>()) { if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var); auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]); void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()), ti = ng_backend->create_tensor(ng_type, sp, pd_arr);
"Ensure ngraph tensor layout align with paddle tensor");
ti = backend_->create_tensor(ng_type, sp, pd_arr);
} else { } else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vi); PADDLE_THROW("Cannot find var or tensor with var name %s", vi);
} }
bool is_persistable = bool is_persistable =
(p_persistables->find(vi) != p_persistables->end()) ? true : false; (p_persistables->find(vi) != p_persistables->end()) ? true : false;
if (!is_training && is_test && is_persistable) { if (is_inference_ && is_persistable) {
ti->set_stale(false); ti->set_stale(false);
} }
(*p_t_in).emplace_back(ti); (*p_t_in).emplace_back(ti);
...@@ -615,7 +618,7 @@ void NgraphEngine::Run(const framework::Scope& scope, ...@@ -615,7 +618,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
auto ng_type = m_results[i]->get_element_type(); auto ng_type = m_results[i]->get_element_type();
void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]); void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
std::shared_ptr<ngraph::runtime::Tensor> to = std::shared_ptr<ngraph::runtime::Tensor> to =
backend_->create_tensor(ng_type, sp, pd_arr); ng_backend->create_tensor(ng_type, sp, pd_arr);
t_out.emplace_back(to); t_out.emplace_back(to);
} else { } else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vo); PADDLE_THROW("Cannot find var or tensor with var name %s", vo);
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <list> #include <list>
#include <memory> #include <memory>
#include <mutex> //NOLINT
#include <set> #include <set>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -34,7 +35,8 @@ namespace operators { ...@@ -34,7 +35,8 @@ namespace operators {
// cache engine repetitives // cache engine repetitives
struct EngineCache { struct EngineCache {
std::shared_ptr<ngraph::runtime::Executable> ngraph_handle; std::shared_ptr<ngraph::runtime::Executable> ngraph_handle = nullptr;
std::shared_ptr<ngraph::runtime::Backend> ngraph_backend = nullptr;
std::set<std::string> persistables; std::set<std::string> persistables;
std::vector<std::string> var_in; std::vector<std::string> var_in;
std::vector<std::string> var_out; std::vector<std::string> var_out;
...@@ -127,9 +129,7 @@ class NgraphEngine { ...@@ -127,9 +129,7 @@ class NgraphEngine {
void Run(const framework::Scope& scope, const platform::Place& place) const; void Run(const framework::Scope& scope, const platform::Place& place) const;
static bool is_training; static std::vector<std::string> feed_vars;
static const framework::BlockDesc* p_bdesc;
static std::vector<std::string> feed_vars, fetch_vars;
static void FuseNgraphOps( static void FuseNgraphOps(
const framework::BlockDesc& prog, const framework::BlockDesc& prog,
...@@ -149,19 +149,24 @@ class NgraphEngine { ...@@ -149,19 +149,24 @@ class NgraphEngine {
using main_t_in_cache = using main_t_in_cache =
ThCache<std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>; ThCache<std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>;
static framework::Variable* pre_var_ptr;
const framework::Scope& scope_; const framework::Scope& scope_;
const platform::Place& place_; const platform::Place& place_;
std::vector<std::shared_ptr<framework::OperatorBase>> fused_ops_; std::vector<std::shared_ptr<framework::OperatorBase>> fused_ops_;
std::unordered_map<std::string, ngraph::element::Type> var_type_map_; std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
std::set<std::string> persistables_; std::set<std::string> persistables_;
std::unordered_set<std::string> post_op_inputs_; std::unordered_set<std::string> post_op_inputs_;
// it is test for a single run, it can be a validation during training
bool is_test_{true}; bool is_test_{true};
// inference only. eg. CAPI inference
bool is_inference_{false};
std::string func_cache_key_; std::string func_cache_key_;
// use a weak pointer to keep backend_ alive
// to avoid it to be destropyed too earlier
static std::weak_ptr<ngraph::runtime::Backend> wp_backend_;
// use mutex to keep it thread safe
static std::mutex ng_mutex_;
// ngraph backend eg. CPU // ngraph backend eg. CPU
static std::shared_ptr<ngraph::runtime::Backend> backend_; std::shared_ptr<ngraph::runtime::Backend> backend_;
// var_name of inputs // var_name of inputs
std::vector<std::string> var_in_; std::vector<std::string> var_in_;
// var_name of outputs from fetch in order // var_name of outputs from fetch in order
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册