提交 e2c1b7c3 编写于 作者: B baojun 提交者: tensor-tang

[NGraph] cache compiled function instead test=develop (#17845)

上级 d008260f
...@@ -471,11 +471,11 @@ void NgraphEngine::BuildNgNodes() { ...@@ -471,11 +471,11 @@ void NgraphEngine::BuildNgNodes() {
} }
} }
void NgraphEngine::BuildNgFunction(const framework::ExecutionContext& ctx) { std::shared_ptr<ngraph::Function> NgraphEngine::BuildNgFunction(
const framework::ExecutionContext& ctx) {
Prepare(ctx); Prepare(ctx);
GetNgInputShape(); GetNgInputShape();
BuildNgNodes(); BuildNgNodes();
ngraph_function_ = nullptr;
ngraph::NodeVector func_outputs; ngraph::NodeVector func_outputs;
ngraph::ParameterVector func_inputs; ngraph::ParameterVector func_inputs;
...@@ -490,15 +490,36 @@ void NgraphEngine::BuildNgFunction(const framework::ExecutionContext& ctx) { ...@@ -490,15 +490,36 @@ void NgraphEngine::BuildNgFunction(const framework::ExecutionContext& ctx) {
func_inputs.emplace_back(prm); func_inputs.emplace_back(prm);
} }
ngraph_function_ = return std::make_shared<ngraph::Function>(func_outputs, func_inputs);
std::make_shared<ngraph::Function>(func_outputs, func_inputs); }
void NgraphEngine::ClearNgCache() {
auto it = engine_cache.begin();
while (it != engine_cache.end()) {
auto ng_engine = it->second;
backend_->remove_compiled_function(ng_engine.ngraph_handle);
++it;
}
engine_cache.clear();
auto it_tensor = t_in_cache_.begin();
while (it_tensor != t_in_cache_.end()) {
auto t_vec = it_tensor->second;
for (auto t_in : t_vec) {
t_in.reset();
}
++it_tensor;
}
t_in_cache_.clear();
} }
void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) { void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
auto interval = ctx.Attr<std::vector<int>>("interval"); auto interval = ctx.Attr<std::vector<int>>("interval");
std::string engine_key = ctx.Attr<std::string>("engine_key"); std::string engine_key = ctx.Attr<std::string>("engine_key");
// set to flase, to debug cache or recompile everytime.
bool use_cache = true; bool use_cache = true;
if (use_cache) { if (!use_cache) ClearNgCache();
this->func_cache_key_ = ""; this->func_cache_key_ = "";
for (int i = 0; i < static_cast<int>(feed_vars.size()); ++i) { for (int i = 0; i < static_cast<int>(feed_vars.size()); ++i) {
auto* var = scope_.FindVar(feed_vars[i]); auto* var = scope_.FindVar(feed_vars[i]);
...@@ -516,73 +537,58 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) { ...@@ -516,73 +537,58 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
if (engine_cache.find(func_cache_key_) != engine_cache.end()) { if (engine_cache.find(func_cache_key_) != engine_cache.end()) {
if (engine_cache[func_cache_key_].persistables.size() == 0) { if (engine_cache[func_cache_key_].persistables.size() == 0) {
engine_cache.clear(); ClearNgCache();
t_in_cache_.clear();
} else { } else {
auto var_name = engine_cache[func_cache_key_].persistables.begin(); auto var_name = engine_cache[func_cache_key_].persistables.begin();
framework::Variable* var = scope_.FindVar(*var_name); framework::Variable* var = scope_.FindVar(*var_name);
if (var != pre_var_ptr) { if (var != pre_var_ptr) {
engine_cache.clear(); ClearNgCache();
t_in_cache_.clear();
} }
pre_var_ptr = var; pre_var_ptr = var;
} }
} }
if (engine_cache.find(func_cache_key_) == engine_cache.end()) { if (engine_cache.find(func_cache_key_) == engine_cache.end()) {
BuildNgFunction(ctx); if (engine_cache.size() > 5) ClearNgCache();
engine_cache[func_cache_key_].ngraph_function = this->ngraph_function_; auto func = BuildNgFunction(ctx);
// Due to optimization backend may produce results in other layouts,
// make sure we get default layout for results.
for (auto& r : func->get_results()) {
r->set_needs_default_layout(true);
}
engine_cache[func_cache_key_].ngraph_handle = backend_->compile(func);
engine_cache[func_cache_key_].persistables = this->persistables_; engine_cache[func_cache_key_].persistables = this->persistables_;
engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_; engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_;
engine_cache[func_cache_key_].var_in = this->var_in_; engine_cache[func_cache_key_].var_in = this->var_in_;
engine_cache[func_cache_key_].var_out = this->var_out_; engine_cache[func_cache_key_].var_out = this->var_out_;
engine_cache[func_cache_key_].is_test = this->is_test_; engine_cache[func_cache_key_].is_test = this->is_test_;
} }
} else {
BuildNgFunction(ctx);
}
} }
void NgraphEngine::Run(const framework::Scope& scope, void NgraphEngine::Run(const framework::Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
std::shared_ptr<ngraph::Function> ng_func; std::shared_ptr<ngraph::runtime::Executable> ng_handle;
const std::set<std::string>* p_persistables; const std::set<std::string>* p_persistables;
const std::vector<size_t>* p_var_in_updates; const std::vector<size_t>* p_var_in_updates;
const std::vector<std::string>* p_var_in; const std::vector<std::string>* p_var_in;
const std::vector<std::string>* p_var_out; const std::vector<std::string>* p_var_out;
bool is_test; bool is_test;
bool use_cache = true;
if (use_cache) {
PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(), PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(),
"Cannot find cached data to run ngraph function"); "Cannot find cached data to run ngraph function");
ng_func = engine_cache[func_cache_key_].ngraph_function; ng_handle = engine_cache[func_cache_key_].ngraph_handle;
p_persistables = &(engine_cache[func_cache_key_].persistables); p_persistables = &(engine_cache[func_cache_key_].persistables);
p_var_in_updates = &(engine_cache[func_cache_key_].var_in_updates); p_var_in_updates = &(engine_cache[func_cache_key_].var_in_updates);
p_var_in = &(engine_cache[func_cache_key_].var_in); p_var_in = &(engine_cache[func_cache_key_].var_in);
p_var_out = &(engine_cache[func_cache_key_].var_out); p_var_out = &(engine_cache[func_cache_key_].var_out);
is_test = engine_cache[func_cache_key_].is_test; is_test = engine_cache[func_cache_key_].is_test;
} else {
ng_func = ngraph_function_;
p_persistables = &this->persistables_;
p_var_in_updates = &this->var_in_updates_;
p_var_in = &this->var_in_;
p_var_out = &this->var_out_;
is_test = this->is_test_;
}
std::vector<std::shared_ptr<ngraph::runtime::Tensor>>* p_t_in; std::vector<std::shared_ptr<ngraph::runtime::Tensor>>* p_t_in;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in = {}; std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in = {};
auto m_parameters = ng_func->get_parameters(); auto m_parameters = ng_handle->get_parameters();
auto m_results = ng_func->get_results(); auto m_results = ng_handle->get_results();
// Due to optimization backend may produce results in other layouts, if (is_test && t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) {
// make sure we get default layout for results.
for (auto& r : m_results) {
r->set_needs_default_layout(true);
}
if (is_test && use_cache &&
t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) {
p_t_in = &(t_in_cache_[func_cache_key_]); p_t_in = &(t_in_cache_[func_cache_key_]);
for (size_t i = 0; i < p_var_in_updates->size(); ++i) { for (size_t i = 0; i < p_var_in_updates->size(); ++i) {
int index = p_var_in_updates->at(i); int index = p_var_in_updates->at(i);
...@@ -601,7 +607,7 @@ void NgraphEngine::Run(const framework::Scope& scope, ...@@ -601,7 +607,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
} }
} }
} else { } else {
if (is_test && use_cache) { if (is_test) {
p_t_in = &(t_in_cache_[func_cache_key_]); p_t_in = &(t_in_cache_[func_cache_key_]);
} else { } else {
p_t_in = &t_in; p_t_in = &t_in;
...@@ -664,8 +670,7 @@ void NgraphEngine::Run(const framework::Scope& scope, ...@@ -664,8 +670,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
} }
} }
auto handle = backend_->compile(ng_func); ng_handle->call(t_out, *p_t_in);
handle->call_with_validate(t_out, *p_t_in);
} // NgraphEngine::Run } // NgraphEngine::Run
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -40,7 +40,7 @@ enum class OpState { /* nGraph support state on ops */ ...@@ -40,7 +40,7 @@ enum class OpState { /* nGraph support state on ops */
// cache engine repetitives // cache engine repetitives
struct EngineCache { struct EngineCache {
std::shared_ptr<ngraph::Function> ngraph_function; std::shared_ptr<ngraph::runtime::Executable> ngraph_handle;
std::set<std::string> persistables; std::set<std::string> persistables;
std::vector<std::string> var_in; std::vector<std::string> var_in;
std::vector<std::string> var_out; std::vector<std::string> var_out;
...@@ -84,8 +84,6 @@ class NgraphEngine { ...@@ -84,8 +84,6 @@ class NgraphEngine {
// ngraph backend eg. CPU // ngraph backend eg. CPU
static std::shared_ptr<ngraph::runtime::Backend> backend_; static std::shared_ptr<ngraph::runtime::Backend> backend_;
// ngraph function to call and execute
std::shared_ptr<ngraph::Function> ngraph_function_;
// var_name of inputs // var_name of inputs
std::vector<std::string> var_in_; std::vector<std::string> var_in_;
// var_name of outputs from fetch in order // var_name of outputs from fetch in order
...@@ -110,7 +108,10 @@ class NgraphEngine { ...@@ -110,7 +108,10 @@ class NgraphEngine {
// Call ngraph bridge to map ops // Call ngraph bridge to map ops
void BuildNgNodes(); void BuildNgNodes();
// build ngraph function call // build ngraph function call
void BuildNgFunction(const framework::ExecutionContext& ctx); std::shared_ptr<ngraph::Function> BuildNgFunction(
const framework::ExecutionContext& ctx);
// clear ngraph engine cache and t_in cache
void ClearNgCache();
// Check cache for ngraph function or otherwise build the function // Check cache for ngraph function or otherwise build the function
void GetNgFunction(const framework::ExecutionContext& ctx); void GetNgFunction(const framework::ExecutionContext& ctx);
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册