From e2c1b7c354c9ca6720228f1b55bfc4c6a6f82292 Mon Sep 17 00:00:00 2001 From: baojun <32073718+baojun-nervana@users.noreply.github.com> Date: Thu, 6 Jun 2019 00:47:12 -0700 Subject: [PATCH] [NGraph] cache compiled function instead test=develop (#17845) --- .../fluid/operators/ngraph/ngraph_engine.cc | 147 +++++++++--------- paddle/fluid/operators/ngraph/ngraph_engine.h | 9 +- 2 files changed, 81 insertions(+), 75 deletions(-) diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.cc b/paddle/fluid/operators/ngraph/ngraph_engine.cc index 19d30a6f83..ae87687e34 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine.cc +++ b/paddle/fluid/operators/ngraph/ngraph_engine.cc @@ -471,11 +471,11 @@ void NgraphEngine::BuildNgNodes() { } } -void NgraphEngine::BuildNgFunction(const framework::ExecutionContext& ctx) { +std::shared_ptr NgraphEngine::BuildNgFunction( + const framework::ExecutionContext& ctx) { Prepare(ctx); GetNgInputShape(); BuildNgNodes(); - ngraph_function_ = nullptr; ngraph::NodeVector func_outputs; ngraph::ParameterVector func_inputs; @@ -490,99 +490,105 @@ void NgraphEngine::BuildNgFunction(const framework::ExecutionContext& ctx) { func_inputs.emplace_back(prm); } - ngraph_function_ = - std::make_shared(func_outputs, func_inputs); + return std::make_shared(func_outputs, func_inputs); +} + +void NgraphEngine::ClearNgCache() { + auto it = engine_cache.begin(); + while (it != engine_cache.end()) { + auto ng_engine = it->second; + backend_->remove_compiled_function(ng_engine.ngraph_handle); + ++it; + } + engine_cache.clear(); + auto it_tensor = t_in_cache_.begin(); + while (it_tensor != t_in_cache_.end()) { + auto t_vec = it_tensor->second; + for (auto t_in : t_vec) { + t_in.reset(); + } + ++it_tensor; + } + t_in_cache_.clear(); } void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) { auto interval = ctx.Attr>("interval"); std::string engine_key = ctx.Attr("engine_key"); + + // set to flase, to debug cache or recompile everytime. bool use_cache = true; - if (use_cache) { - this->func_cache_key_ = ""; - for (int i = 0; i < static_cast(feed_vars.size()); ++i) { - auto* var = scope_.FindVar(feed_vars[i]); - if (var && var->IsType()) { - auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var); - auto dims = tensor_pd->dims(); - for (int j = 0; j < dims.size(); ++j) { - func_cache_key_ += std::to_string(dims[j]); - } + if (!use_cache) ClearNgCache(); + + this->func_cache_key_ = ""; + for (int i = 0; i < static_cast(feed_vars.size()); ++i) { + auto* var = scope_.FindVar(feed_vars[i]); + if (var && var->IsType()) { + auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var); + auto dims = tensor_pd->dims(); + for (int j = 0; j < dims.size(); ++j) { + func_cache_key_ += std::to_string(dims[j]); } } - func_cache_key_ += std::to_string(interval[0]) + "_" + - std::to_string(interval[1]) + engine_key; - func_cache_key_ = std::to_string(std::hash()(func_cache_key_)); - - if (engine_cache.find(func_cache_key_) != engine_cache.end()) { - if (engine_cache[func_cache_key_].persistables.size() == 0) { - engine_cache.clear(); - t_in_cache_.clear(); - } else { - auto var_name = engine_cache[func_cache_key_].persistables.begin(); - framework::Variable* var = scope_.FindVar(*var_name); - if (var != pre_var_ptr) { - engine_cache.clear(); - t_in_cache_.clear(); - } - pre_var_ptr = var; + } + func_cache_key_ += std::to_string(interval[0]) + "_" + + std::to_string(interval[1]) + engine_key; + func_cache_key_ = std::to_string(std::hash()(func_cache_key_)); + + if (engine_cache.find(func_cache_key_) != engine_cache.end()) { + if (engine_cache[func_cache_key_].persistables.size() == 0) { + ClearNgCache(); + } else { + auto var_name = engine_cache[func_cache_key_].persistables.begin(); + framework::Variable* var = scope_.FindVar(*var_name); + if (var != pre_var_ptr) { + ClearNgCache(); } + pre_var_ptr = var; } + } - if (engine_cache.find(func_cache_key_) == engine_cache.end()) { - BuildNgFunction(ctx); - engine_cache[func_cache_key_].ngraph_function = this->ngraph_function_; - engine_cache[func_cache_key_].persistables = this->persistables_; - engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_; - engine_cache[func_cache_key_].var_in = this->var_in_; - engine_cache[func_cache_key_].var_out = this->var_out_; - engine_cache[func_cache_key_].is_test = this->is_test_; + if (engine_cache.find(func_cache_key_) == engine_cache.end()) { + if (engine_cache.size() > 5) ClearNgCache(); + auto func = BuildNgFunction(ctx); + // Due to optimization backend may produce results in other layouts, + // make sure we get default layout for results. + for (auto& r : func->get_results()) { + r->set_needs_default_layout(true); } - } else { - BuildNgFunction(ctx); + engine_cache[func_cache_key_].ngraph_handle = backend_->compile(func); + engine_cache[func_cache_key_].persistables = this->persistables_; + engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_; + engine_cache[func_cache_key_].var_in = this->var_in_; + engine_cache[func_cache_key_].var_out = this->var_out_; + engine_cache[func_cache_key_].is_test = this->is_test_; } } void NgraphEngine::Run(const framework::Scope& scope, const platform::Place& place) const { - std::shared_ptr ng_func; + std::shared_ptr ng_handle; const std::set* p_persistables; const std::vector* p_var_in_updates; const std::vector* p_var_in; const std::vector* p_var_out; bool is_test; - bool use_cache = true; - if (use_cache) { - PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(), - "Cannot find cached data to run ngraph function"); - ng_func = engine_cache[func_cache_key_].ngraph_function; - p_persistables = &(engine_cache[func_cache_key_].persistables); - p_var_in_updates = &(engine_cache[func_cache_key_].var_in_updates); - p_var_in = &(engine_cache[func_cache_key_].var_in); - p_var_out = &(engine_cache[func_cache_key_].var_out); - is_test = engine_cache[func_cache_key_].is_test; - } else { - ng_func = ngraph_function_; - p_persistables = &this->persistables_; - p_var_in_updates = &this->var_in_updates_; - p_var_in = &this->var_in_; - p_var_out = &this->var_out_; - is_test = this->is_test_; - } + PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(), + "Cannot find cached data to run ngraph function"); + ng_handle = engine_cache[func_cache_key_].ngraph_handle; + p_persistables = &(engine_cache[func_cache_key_].persistables); + p_var_in_updates = &(engine_cache[func_cache_key_].var_in_updates); + p_var_in = &(engine_cache[func_cache_key_].var_in); + p_var_out = &(engine_cache[func_cache_key_].var_out); + is_test = engine_cache[func_cache_key_].is_test; std::vector>* p_t_in; std::vector> t_in = {}; - auto m_parameters = ng_func->get_parameters(); - auto m_results = ng_func->get_results(); - // Due to optimization backend may produce results in other layouts, - // make sure we get default layout for results. - for (auto& r : m_results) { - r->set_needs_default_layout(true); - } - if (is_test && use_cache && - t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) { + auto m_parameters = ng_handle->get_parameters(); + auto m_results = ng_handle->get_results(); + if (is_test && t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) { p_t_in = &(t_in_cache_[func_cache_key_]); for (size_t i = 0; i < p_var_in_updates->size(); ++i) { int index = p_var_in_updates->at(i); @@ -601,7 +607,7 @@ void NgraphEngine::Run(const framework::Scope& scope, } } } else { - if (is_test && use_cache) { + if (is_test) { p_t_in = &(t_in_cache_[func_cache_key_]); } else { p_t_in = &t_in; @@ -664,8 +670,7 @@ void NgraphEngine::Run(const framework::Scope& scope, } } - auto handle = backend_->compile(ng_func); - handle->call_with_validate(t_out, *p_t_in); + ng_handle->call(t_out, *p_t_in); } // NgraphEngine::Run } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.h b/paddle/fluid/operators/ngraph/ngraph_engine.h index 885b738b95..4cb1465371 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine.h +++ b/paddle/fluid/operators/ngraph/ngraph_engine.h @@ -40,7 +40,7 @@ enum class OpState { /* nGraph support state on ops */ // cache engine repetitives struct EngineCache { - std::shared_ptr ngraph_function; + std::shared_ptr ngraph_handle; std::set persistables; std::vector var_in; std::vector var_out; @@ -84,8 +84,6 @@ class NgraphEngine { // ngraph backend eg. CPU static std::shared_ptr backend_; - // ngraph function to call and execute - std::shared_ptr ngraph_function_; // var_name of inputs std::vector var_in_; // var_name of outputs from fetch in order @@ -110,7 +108,10 @@ class NgraphEngine { // Call ngraph bridge to map ops void BuildNgNodes(); // build ngraph function call - void BuildNgFunction(const framework::ExecutionContext& ctx); + std::shared_ptr BuildNgFunction( + const framework::ExecutionContext& ctx); + // clear ngraph engine cache and t_in cache + void ClearNgCache(); // Check cache for ngraph function or otherwise build the function void GetNgFunction(const framework::ExecutionContext& ctx); }; -- GitLab