diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.cc b/paddle/fluid/operators/ngraph/ngraph_engine.cc index 720273a6de02c3f7955c919c5227d3ce46ff6b70..3c53c87c6ff4795c28be9eedc2f3e870e0a20916 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine.cc +++ b/paddle/fluid/operators/ngraph/ngraph_engine.cc @@ -72,18 +72,14 @@ static std::map {ngraph::element::boolean, framework::proto::VarType::BOOL}}; std::vector NgraphEngine::feed_vars = {}; -std::vector NgraphEngine::fetch_vars = {}; -framework::Variable* NgraphEngine::pre_var_ptr = nullptr; -const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr; -bool NgraphEngine::is_training = false; -std::shared_ptr NgraphEngine::backend_ = - ngraph::runtime::Backend::create("CPU"); +std::weak_ptr NgraphEngine::wp_backend_; + +std::mutex NgraphEngine::ng_mutex_; static std::vector> NgraphOpIntervals( std::vector>* ops) { NgraphEngine::feed_vars.clear(); - NgraphEngine::fetch_vars.clear(); std::vector> intervals; int size = ops->size(); @@ -118,11 +114,6 @@ static std::vector> NgraphOpIntervals( int index = right; while (index < size && ops->at(index)->Type() == framework::kFetchOpType) { - for (auto& var_name_item : ops->at(index)->Inputs()) { - for (auto& var_name : var_name_item.second) { - NgraphEngine::fetch_vars.emplace_back(var_name); - } - } ++index; } @@ -167,16 +158,22 @@ static void SubstituteNgraphOp( framework::OpRegistry::CreateOp(ng_op_desc)); } -std::string SerializedBlock(const std::vector& op_descs) { +std::string SerializedBlock(const framework::BlockDesc& bdesc) { framework::proto::BlockDesc block_proto; framework::BlockDesc block_desc(nullptr, &block_proto); block_desc.Proto()->set_parent_idx(-1); block_desc.Proto()->set_idx(0); - for (auto* op_desc : op_descs) { + for (auto& op_desc : bdesc.AllOps()) { auto* op = block_desc.AppendOp(); *op->Proto() = *op_desc->Proto(); } + + auto* vars = block_desc.Proto()->mutable_vars(); + for (auto& var_desc : bdesc.AllVars()) { + *vars->Add() = *var_desc->Proto(); + } + return block_desc.Proto()->SerializeAsString(); } @@ -213,12 +210,12 @@ std::string GenerateEngineKey(const std::vector& engine_inputs, void NgraphEngine::FuseNgraphOps( const framework::BlockDesc& block_desc, std::vector>* ops) { - NgraphEngine::p_bdesc = &block_desc; auto intervals = NgraphOpIntervals(ops); + std::string serialized_block = SerializedBlock(block_desc); std::string engine_key = - GenerateEngineKey(feed_vars, fetch_vars, ops->size()); + std::to_string(std::hash()(serialized_block)); for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) { - SubstituteNgraphOp(ops, engine_key, "", *it); + SubstituteNgraphOp(ops, engine_key, serialized_block, *it); } } @@ -232,6 +229,20 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope, var_node_map_ = std::make_shared< std::unordered_map>>(); + std::lock_guard lock(ng_mutex_); + + if (!wp_backend_.lock()) { + try { + VLOG(3) << "ngraph creating CPU backend."; + backend_ = ngraph::runtime::Backend::create("CPU"); + } catch (...) { + PADDLE_THROW("Unsupported nGraph backend"); + } + wp_backend_ = backend_; + } else { + backend_ = wp_backend_.lock(); + } + GetNgFunction(ctx); } @@ -239,24 +250,11 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) { auto interval = ctx.Attr>("interval"); std::string serialized_graph = ctx.Attr("graph"); - auto input_vars = ctx.Inputs("Xs"); - if (!input_vars.empty()) { - feed_vars = input_vars; - var_in_ = input_vars; - } - auto output_vars = ctx.Outputs("Ys"); - if (!output_vars.empty()) { - var_out_ = output_vars; - } - framework::proto::BlockDesc block_proto; if (!serialized_graph.empty()) block_proto.ParseFromString(serialized_graph); framework::BlockDesc block_desc(nullptr, &block_proto); - if (!serialized_graph.empty()) { - NgraphEngine::p_bdesc = &block_desc; - } - for (auto& var : p_bdesc->AllVars()) { + for (auto& var : block_desc.AllVars()) { if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS || var->GetType() == framework::proto::VarType::LOD_TENSOR || var->GetType() == framework::proto::VarType::LOD_TENSOR_ARRAY)) { @@ -284,10 +282,9 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) { } std::vector ops_desc; - for (auto op_desc : p_bdesc->AllOps()) { + for (auto op_desc : block_desc.AllOps()) { ops_desc.emplace_back(op_desc); if (op_desc->Type().find("_grad") != std::string::npos) { - is_training = true; this->is_test_ = false; } } @@ -298,8 +295,7 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) { framework::OpRegistry::CreateOp(*(ops_desc[idx]))); ++idx; } - while (idx < static_cast(ops_desc.size()) && - ops_desc.at(idx)->Type() != framework::kFetchOpType) { + while (idx < static_cast(ops_desc.size())) { auto op_desc = ops_desc.at(idx); for (auto& var_name_item : op_desc->Inputs()) { for (auto& var_name : var_name_item.second) { @@ -309,9 +305,21 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) { ++idx; } + auto input_vars = ctx.Inputs("Xs"); + if (!input_vars.empty()) { + feed_vars = input_vars; + var_in_ = input_vars; + } + + auto output_vars = ctx.Outputs("Ys"); + if (!output_vars.empty()) { + var_out_ = output_vars; + } + if (var_in_.empty() && var_out_.empty()) { BuildNgIO(ops_desc, interval); } + for (size_t i = 0; i < var_in_.size(); ++i) { auto var_name = var_in_[i]; if (persistables_.find(var_name) == persistables_.end()) { @@ -324,6 +332,7 @@ void NgraphEngine::BuildNgIO(const std::vector& ops_desc, const std::vector& interval) { std::unordered_set inputs; std::unordered_set outputs; + for (int i = interval[0]; i < interval[1]; ++i) { auto op = ops_desc[i]; for (auto& var_name_item : op->Inputs()) { @@ -359,15 +368,11 @@ void NgraphEngine::BuildNgIO(const std::vector& ops_desc, op->Type()); for (auto& var_name : var_name_item.second) { if (this->is_test_) { - if (post_op_inputs_.find(var_name) != post_op_inputs_.end() || - find(fetch_vars.begin(), fetch_vars.end(), var_name) != - fetch_vars.end()) { + if (post_op_inputs_.find(var_name) != post_op_inputs_.end()) { this->var_out_.emplace_back(var_name); } } else { - if (find(fetch_vars.begin(), fetch_vars.end(), var_name) != - fetch_vars.end() || - post_op_inputs_.find(var_name) != post_op_inputs_.end() || + if (post_op_inputs_.find(var_name) != post_op_inputs_.end() || persistables_.find(var_name) != persistables_.end()) { this->var_out_.emplace_back(var_name); } @@ -434,10 +439,14 @@ std::shared_ptr NgraphEngine::BuildNgFunction( ngraph::ParameterVector func_inputs; for (auto& vo : var_out_) { + PADDLE_ENFORCE_GT(var_node_map_->count(vo), 0, + "Cannot find vo %s in var_node_map_", vo); func_outputs.emplace_back(var_node_map_->at(vo)); } for (auto& vi : var_in_) { + PADDLE_ENFORCE_GT(var_node_map_->count(vi), 0, + "Cannot find vi %s in var_node_map_", vi); std::shared_ptr prm = std::dynamic_pointer_cast( var_in_node_map_->at(vi)); @@ -454,7 +463,8 @@ void NgraphEngine::ClearNgCache() { auto it = engine_cache.begin(); while (it != engine_cache.end()) { auto ng_engine = it->second; - backend_->remove_compiled_function(ng_engine.ngraph_handle); + ng_engine.ngraph_backend->remove_compiled_function(ng_engine.ngraph_handle); + ng_engine.ngraph_backend.reset(); ++it; } engine_cache.clear(); @@ -497,13 +507,6 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) { if (engine_cache.find(func_cache_key_) != engine_cache.end()) { if (engine_cache[func_cache_key_].persistables.size() == 0) { ClearNgCache(); - } else { - auto var_name = engine_cache[func_cache_key_].persistables.begin(); - framework::Variable* var = scope_.FindVar(*var_name); - if (var != pre_var_ptr) { - ClearNgCache(); - } - pre_var_ptr = var; } } @@ -515,6 +518,7 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) { for (auto& r : func->get_results()) { r->set_needs_default_layout(true); } + engine_cache[func_cache_key_].ngraph_backend = backend_; engine_cache[func_cache_key_].ngraph_handle = backend_->compile(func); engine_cache[func_cache_key_].persistables = this->persistables_; engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_; @@ -526,31 +530,32 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) { void NgraphEngine::Run(const framework::Scope& scope, const platform::Place& place) const { + VLOG(3) << "NgraphEngine Run ..."; std::shared_ptr ng_handle; + std::shared_ptr ng_backend; const std::set* p_persistables; const std::vector* p_var_in_updates; const std::vector* p_var_in; const std::vector* p_var_out; - bool is_test; auto& engine_cache = main_engine_cache::fetch(); auto& t_in_cache_ = main_t_in_cache::fetch(); - PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(), - "Cannot find cached data to run ngraph function"); + PADDLE_ENFORCE_GT(engine_cache.count(func_cache_key_), 0, + "Cannot find cached data to run ngraph function"); ng_handle = engine_cache[func_cache_key_].ngraph_handle; + ng_backend = engine_cache[func_cache_key_].ngraph_backend; p_persistables = &(engine_cache[func_cache_key_].persistables); p_var_in_updates = &(engine_cache[func_cache_key_].var_in_updates); p_var_in = &(engine_cache[func_cache_key_].var_in); p_var_out = &(engine_cache[func_cache_key_].var_out); - is_test = engine_cache[func_cache_key_].is_test; std::vector>* p_t_in; std::vector> t_in = {}; auto m_parameters = ng_handle->get_parameters(); auto m_results = ng_handle->get_results(); - if (is_test && t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) { + if (is_inference_ && t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) { p_t_in = &(t_in_cache_[func_cache_key_]); for (size_t i = 0; i < p_var_in_updates->size(); ++i) { int index = p_var_in_updates->at(i); @@ -562,14 +567,14 @@ void NgraphEngine::Run(const framework::Scope& scope, if (var && var->IsType()) { auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var); void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]); - ti = backend_->create_tensor(ng_type, sp, pd_arr); + ti = ng_backend->create_tensor(ng_type, sp, pd_arr); (*p_t_in)[index] = ti; } else { PADDLE_THROW("Cannot find var or tensor with var name %s", vi); } } } else { - if (is_test) { + if (is_inference_) { p_t_in = &(t_in_cache_[func_cache_key_]); } else { p_t_in = &t_in; @@ -584,15 +589,13 @@ void NgraphEngine::Run(const framework::Scope& scope, if (var && var->IsType()) { auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var); void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]); - PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()), - "Ensure ngraph tensor layout align with paddle tensor"); - ti = backend_->create_tensor(ng_type, sp, pd_arr); + ti = ng_backend->create_tensor(ng_type, sp, pd_arr); } else { PADDLE_THROW("Cannot find var or tensor with var name %s", vi); } bool is_persistable = (p_persistables->find(vi) != p_persistables->end()) ? true : false; - if (!is_training && is_test && is_persistable) { + if (is_inference_ && is_persistable) { ti->set_stale(false); } (*p_t_in).emplace_back(ti); @@ -615,7 +618,7 @@ void NgraphEngine::Run(const framework::Scope& scope, auto ng_type = m_results[i]->get_element_type(); void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]); std::shared_ptr to = - backend_->create_tensor(ng_type, sp, pd_arr); + ng_backend->create_tensor(ng_type, sp, pd_arr); t_out.emplace_back(to); } else { PADDLE_THROW("Cannot find var or tensor with var name %s", vo); diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.h b/paddle/fluid/operators/ngraph/ngraph_engine.h index c60a5ad4eee5f0d886f8f919f97f453032a9a9b3..0fb2d167496b3eabd8e840fe18adb8900d5fb527 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine.h +++ b/paddle/fluid/operators/ngraph/ngraph_engine.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include //NOLINT #include #include #include @@ -34,7 +35,8 @@ namespace operators { // cache engine repetitives struct EngineCache { - std::shared_ptr ngraph_handle; + std::shared_ptr ngraph_handle = nullptr; + std::shared_ptr ngraph_backend = nullptr; std::set persistables; std::vector var_in; std::vector var_out; @@ -127,9 +129,7 @@ class NgraphEngine { void Run(const framework::Scope& scope, const platform::Place& place) const; - static bool is_training; - static const framework::BlockDesc* p_bdesc; - static std::vector feed_vars, fetch_vars; + static std::vector feed_vars; static void FuseNgraphOps( const framework::BlockDesc& prog, @@ -149,19 +149,24 @@ class NgraphEngine { using main_t_in_cache = ThCache>>; - static framework::Variable* pre_var_ptr; - const framework::Scope& scope_; const platform::Place& place_; std::vector> fused_ops_; std::unordered_map var_type_map_; std::set persistables_; std::unordered_set post_op_inputs_; + // it is test for a single run, it can be a validation during training bool is_test_{true}; + // inference only. eg. CAPI inference + bool is_inference_{false}; std::string func_cache_key_; - + // use a weak pointer to keep backend_ alive + // to avoid it to be destropyed too earlier + static std::weak_ptr wp_backend_; + // use mutex to keep it thread safe + static std::mutex ng_mutex_; // ngraph backend eg. CPU - static std::shared_ptr backend_; + std::shared_ptr backend_; // var_name of inputs std::vector var_in_; // var_name of outputs from fetch in order