提交 e2c1b7c3 编写于 作者: B baojun 提交者: tensor-tang

[NGraph] cache compiled function instead test=develop (#17845)

上级 d008260f
......@@ -471,11 +471,11 @@ void NgraphEngine::BuildNgNodes() {
}
}
void NgraphEngine::BuildNgFunction(const framework::ExecutionContext& ctx) {
std::shared_ptr<ngraph::Function> NgraphEngine::BuildNgFunction(
const framework::ExecutionContext& ctx) {
Prepare(ctx);
GetNgInputShape();
BuildNgNodes();
ngraph_function_ = nullptr;
ngraph::NodeVector func_outputs;
ngraph::ParameterVector func_inputs;
......@@ -490,99 +490,105 @@ void NgraphEngine::BuildNgFunction(const framework::ExecutionContext& ctx) {
func_inputs.emplace_back(prm);
}
ngraph_function_ =
std::make_shared<ngraph::Function>(func_outputs, func_inputs);
return std::make_shared<ngraph::Function>(func_outputs, func_inputs);
}
// Release all cached compiled functions and cached input tensors.
//
// Compiled functions are explicitly unregistered from the backend before
// their cache entries are dropped, so the backend does not keep stale
// executables alive after the engine cache is emptied.
void NgraphEngine::ClearNgCache() {
  // Iterate by reference: the original copied each EngineCache entry
  // (a struct holding several containers) on every iteration.
  for (auto& entry : engine_cache) {
    backend_->remove_compiled_function(entry.second.ngraph_handle);
  }
  engine_cache.clear();
  // Explicitly release the cached input tensors before clearing the map.
  // NOTE(review): the original code iterated over a *copy* of each tensor
  // vector, so its reset() calls only dropped the copies and had no effect
  // on the cached shared_ptrs; resetting through a reference performs the
  // intended early release (clear() below would release them regardless).
  for (auto& entry : t_in_cache_) {
    for (auto& t_in : entry.second) {
      t_in.reset();
    }
  }
  t_in_cache_.clear();
}
void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
auto interval = ctx.Attr<std::vector<int>>("interval");
std::string engine_key = ctx.Attr<std::string>("engine_key");
// set to false to debug the cache, or to recompile every time.
bool use_cache = true;
if (use_cache) {
this->func_cache_key_ = "";
for (int i = 0; i < static_cast<int>(feed_vars.size()); ++i) {
auto* var = scope_.FindVar(feed_vars[i]);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto dims = tensor_pd->dims();
for (int j = 0; j < dims.size(); ++j) {
func_cache_key_ += std::to_string(dims[j]);
}
if (!use_cache) ClearNgCache();
this->func_cache_key_ = "";
for (int i = 0; i < static_cast<int>(feed_vars.size()); ++i) {
auto* var = scope_.FindVar(feed_vars[i]);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
auto dims = tensor_pd->dims();
for (int j = 0; j < dims.size(); ++j) {
func_cache_key_ += std::to_string(dims[j]);
}
}
func_cache_key_ += std::to_string(interval[0]) + "_" +
std::to_string(interval[1]) + engine_key;
func_cache_key_ = std::to_string(std::hash<std::string>()(func_cache_key_));
if (engine_cache.find(func_cache_key_) != engine_cache.end()) {
if (engine_cache[func_cache_key_].persistables.size() == 0) {
engine_cache.clear();
t_in_cache_.clear();
} else {
auto var_name = engine_cache[func_cache_key_].persistables.begin();
framework::Variable* var = scope_.FindVar(*var_name);
if (var != pre_var_ptr) {
engine_cache.clear();
t_in_cache_.clear();
}
pre_var_ptr = var;
}
func_cache_key_ += std::to_string(interval[0]) + "_" +
std::to_string(interval[1]) + engine_key;
func_cache_key_ = std::to_string(std::hash<std::string>()(func_cache_key_));
if (engine_cache.find(func_cache_key_) != engine_cache.end()) {
if (engine_cache[func_cache_key_].persistables.size() == 0) {
ClearNgCache();
} else {
auto var_name = engine_cache[func_cache_key_].persistables.begin();
framework::Variable* var = scope_.FindVar(*var_name);
if (var != pre_var_ptr) {
ClearNgCache();
}
pre_var_ptr = var;
}
}
if (engine_cache.find(func_cache_key_) == engine_cache.end()) {
BuildNgFunction(ctx);
engine_cache[func_cache_key_].ngraph_function = this->ngraph_function_;
engine_cache[func_cache_key_].persistables = this->persistables_;
engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_;
engine_cache[func_cache_key_].var_in = this->var_in_;
engine_cache[func_cache_key_].var_out = this->var_out_;
engine_cache[func_cache_key_].is_test = this->is_test_;
if (engine_cache.find(func_cache_key_) == engine_cache.end()) {
if (engine_cache.size() > 5) ClearNgCache();
auto func = BuildNgFunction(ctx);
// Due to optimization, the backend may produce results in other layouts;
// make sure we get the default layout for results.
for (auto& r : func->get_results()) {
r->set_needs_default_layout(true);
}
} else {
BuildNgFunction(ctx);
engine_cache[func_cache_key_].ngraph_handle = backend_->compile(func);
engine_cache[func_cache_key_].persistables = this->persistables_;
engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_;
engine_cache[func_cache_key_].var_in = this->var_in_;
engine_cache[func_cache_key_].var_out = this->var_out_;
engine_cache[func_cache_key_].is_test = this->is_test_;
}
}
void NgraphEngine::Run(const framework::Scope& scope,
const platform::Place& place) const {
std::shared_ptr<ngraph::Function> ng_func;
std::shared_ptr<ngraph::runtime::Executable> ng_handle;
const std::set<std::string>* p_persistables;
const std::vector<size_t>* p_var_in_updates;
const std::vector<std::string>* p_var_in;
const std::vector<std::string>* p_var_out;
bool is_test;
bool use_cache = true;
if (use_cache) {
PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(),
"Cannot find cached data to run ngraph function");
ng_func = engine_cache[func_cache_key_].ngraph_function;
p_persistables = &(engine_cache[func_cache_key_].persistables);
p_var_in_updates = &(engine_cache[func_cache_key_].var_in_updates);
p_var_in = &(engine_cache[func_cache_key_].var_in);
p_var_out = &(engine_cache[func_cache_key_].var_out);
is_test = engine_cache[func_cache_key_].is_test;
} else {
ng_func = ngraph_function_;
p_persistables = &this->persistables_;
p_var_in_updates = &this->var_in_updates_;
p_var_in = &this->var_in_;
p_var_out = &this->var_out_;
is_test = this->is_test_;
}
PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(),
"Cannot find cached data to run ngraph function");
ng_handle = engine_cache[func_cache_key_].ngraph_handle;
p_persistables = &(engine_cache[func_cache_key_].persistables);
p_var_in_updates = &(engine_cache[func_cache_key_].var_in_updates);
p_var_in = &(engine_cache[func_cache_key_].var_in);
p_var_out = &(engine_cache[func_cache_key_].var_out);
is_test = engine_cache[func_cache_key_].is_test;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>>* p_t_in;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in = {};
auto m_parameters = ng_func->get_parameters();
auto m_results = ng_func->get_results();
// Due to optimization, the backend may produce results in other layouts;
// make sure we get the default layout for results.
for (auto& r : m_results) {
r->set_needs_default_layout(true);
}
if (is_test && use_cache &&
t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) {
auto m_parameters = ng_handle->get_parameters();
auto m_results = ng_handle->get_results();
if (is_test && t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) {
p_t_in = &(t_in_cache_[func_cache_key_]);
for (size_t i = 0; i < p_var_in_updates->size(); ++i) {
int index = p_var_in_updates->at(i);
......@@ -601,7 +607,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
}
}
} else {
if (is_test && use_cache) {
if (is_test) {
p_t_in = &(t_in_cache_[func_cache_key_]);
} else {
p_t_in = &t_in;
......@@ -664,8 +670,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
}
}
auto handle = backend_->compile(ng_func);
handle->call_with_validate(t_out, *p_t_in);
ng_handle->call(t_out, *p_t_in);
} // NgraphEngine::Run
} // namespace operators
} // namespace paddle
......@@ -40,7 +40,7 @@ enum class OpState { /* nGraph support state on ops */
// cache engine repetitives
struct EngineCache {
std::shared_ptr<ngraph::Function> ngraph_function;
std::shared_ptr<ngraph::runtime::Executable> ngraph_handle;
std::set<std::string> persistables;
std::vector<std::string> var_in;
std::vector<std::string> var_out;
......@@ -84,8 +84,6 @@ class NgraphEngine {
// ngraph backend eg. CPU
static std::shared_ptr<ngraph::runtime::Backend> backend_;
// ngraph function to call and execute
std::shared_ptr<ngraph::Function> ngraph_function_;
// var_name of inputs
std::vector<std::string> var_in_;
// var_name of outputs from fetch in order
......@@ -110,7 +108,10 @@ class NgraphEngine {
// Call ngraph bridge to map ops
void BuildNgNodes();
// build ngraph function call
void BuildNgFunction(const framework::ExecutionContext& ctx);
std::shared_ptr<ngraph::Function> BuildNgFunction(
const framework::ExecutionContext& ctx);
// clear ngraph engine cache and t_in cache
void ClearNgCache();
// Check cache for ngraph function or otherwise build the function
void GetNgFunction(const framework::ExecutionContext& ctx);
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册