提交 6421c61a 编写于 作者: B baojun 提交者: Tao Luo

Update ngraph engine for multiple threading (#19155)

* update for multiple threading
test=develop

* remove PADDLE_ENFORCE test=develop
上级 e26411ce
......@@ -72,18 +72,14 @@ static std::map<ngraph::element::Type, framework::proto::VarType::Type>
{ngraph::element::boolean, framework::proto::VarType::BOOL}};
std::vector<std::string> NgraphEngine::feed_vars = {};
std::vector<std::string> NgraphEngine::fetch_vars = {};
framework::Variable* NgraphEngine::pre_var_ptr = nullptr;
const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr;
bool NgraphEngine::is_training = false;
std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
ngraph::runtime::Backend::create("CPU");
std::weak_ptr<ngraph::runtime::Backend> NgraphEngine::wp_backend_;
std::mutex NgraphEngine::ng_mutex_;
static std::vector<std::vector<int>> NgraphOpIntervals(
std::vector<std::unique_ptr<framework::OperatorBase>>* ops) {
NgraphEngine::feed_vars.clear();
NgraphEngine::fetch_vars.clear();
std::vector<std::vector<int>> intervals;
int size = ops->size();
......@@ -118,11 +114,6 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
int index = right;
while (index < size && ops->at(index)->Type() == framework::kFetchOpType) {
for (auto& var_name_item : ops->at(index)->Inputs()) {
for (auto& var_name : var_name_item.second) {
NgraphEngine::fetch_vars.emplace_back(var_name);
}
}
++index;
}
......@@ -167,16 +158,22 @@ static void SubstituteNgraphOp(
framework::OpRegistry::CreateOp(ng_op_desc));
}
std::string SerializedBlock(const std::vector<framework::OpDesc*>& op_descs) {
std::string SerializedBlock(const framework::BlockDesc& bdesc) {
framework::proto::BlockDesc block_proto;
framework::BlockDesc block_desc(nullptr, &block_proto);
block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0);
for (auto* op_desc : op_descs) {
for (auto& op_desc : bdesc.AllOps()) {
auto* op = block_desc.AppendOp();
*op->Proto() = *op_desc->Proto();
}
auto* vars = block_desc.Proto()->mutable_vars();
for (auto& var_desc : bdesc.AllVars()) {
*vars->Add() = *var_desc->Proto();
}
return block_desc.Proto()->SerializeAsString();
}
......@@ -213,12 +210,12 @@ std::string GenerateEngineKey(const std::vector<std::string>& engine_inputs,
void NgraphEngine::FuseNgraphOps(
const framework::BlockDesc& block_desc,
std::vector<std::unique_ptr<framework::OperatorBase>>* ops) {
NgraphEngine::p_bdesc = &block_desc;
auto intervals = NgraphOpIntervals(ops);
std::string serialized_block = SerializedBlock(block_desc);
std::string engine_key =
GenerateEngineKey(feed_vars, fetch_vars, ops->size());
std::to_string(std::hash<std::string>()(serialized_block));
for (auto it = intervals.rbegin(); it != intervals.rend(); ++it) {
SubstituteNgraphOp(ops, engine_key, "", *it);
SubstituteNgraphOp(ops, engine_key, serialized_block, *it);
}
}
......@@ -232,6 +229,20 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope,
var_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
std::lock_guard<std::mutex> lock(ng_mutex_);
if (!wp_backend_.lock()) {
try {
VLOG(3) << "ngraph creating CPU backend.";
backend_ = ngraph::runtime::Backend::create("CPU");
} catch (...) {
PADDLE_THROW("Unsupported nGraph backend");
}
wp_backend_ = backend_;
} else {
backend_ = wp_backend_.lock();
}
GetNgFunction(ctx);
}
......@@ -239,24 +250,11 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
auto interval = ctx.Attr<std::vector<int>>("interval");
std::string serialized_graph = ctx.Attr<std::string>("graph");
auto input_vars = ctx.Inputs("Xs");
if (!input_vars.empty()) {
feed_vars = input_vars;
var_in_ = input_vars;
}
auto output_vars = ctx.Outputs("Ys");
if (!output_vars.empty()) {
var_out_ = output_vars;
}
framework::proto::BlockDesc block_proto;
if (!serialized_graph.empty()) block_proto.ParseFromString(serialized_graph);
framework::BlockDesc block_desc(nullptr, &block_proto);
if (!serialized_graph.empty()) {
NgraphEngine::p_bdesc = &block_desc;
}
for (auto& var : p_bdesc->AllVars()) {
for (auto& var : block_desc.AllVars()) {
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
var->GetType() == framework::proto::VarType::LOD_TENSOR ||
var->GetType() == framework::proto::VarType::LOD_TENSOR_ARRAY)) {
......@@ -284,10 +282,9 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
}
std::vector<paddle::framework::OpDesc*> ops_desc;
for (auto op_desc : p_bdesc->AllOps()) {
for (auto op_desc : block_desc.AllOps()) {
ops_desc.emplace_back(op_desc);
if (op_desc->Type().find("_grad") != std::string::npos) {
is_training = true;
this->is_test_ = false;
}
}
......@@ -298,8 +295,7 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
framework::OpRegistry::CreateOp(*(ops_desc[idx])));
++idx;
}
while (idx < static_cast<int>(ops_desc.size()) &&
ops_desc.at(idx)->Type() != framework::kFetchOpType) {
while (idx < static_cast<int>(ops_desc.size())) {
auto op_desc = ops_desc.at(idx);
for (auto& var_name_item : op_desc->Inputs()) {
for (auto& var_name : var_name_item.second) {
......@@ -309,9 +305,21 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
++idx;
}
auto input_vars = ctx.Inputs("Xs");
if (!input_vars.empty()) {
feed_vars = input_vars;
var_in_ = input_vars;
}
auto output_vars = ctx.Outputs("Ys");
if (!output_vars.empty()) {
var_out_ = output_vars;
}
if (var_in_.empty() && var_out_.empty()) {
BuildNgIO(ops_desc, interval);
}
for (size_t i = 0; i < var_in_.size(); ++i) {
auto var_name = var_in_[i];
if (persistables_.find(var_name) == persistables_.end()) {
......@@ -324,6 +332,7 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
const std::vector<int>& interval) {
std::unordered_set<std::string> inputs;
std::unordered_set<std::string> outputs;
for (int i = interval[0]; i < interval[1]; ++i) {
auto op = ops_desc[i];
for (auto& var_name_item : op->Inputs()) {
......@@ -359,15 +368,11 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
op->Type());
for (auto& var_name : var_name_item.second) {
if (this->is_test_) {
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end()) {
if (post_op_inputs_.find(var_name) != post_op_inputs_.end()) {
this->var_out_.emplace_back(var_name);
}
} else {
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
fetch_vars.end() ||
post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
persistables_.find(var_name) != persistables_.end()) {
this->var_out_.emplace_back(var_name);
}
......@@ -434,10 +439,14 @@ std::shared_ptr<ngraph::Function> NgraphEngine::BuildNgFunction(
ngraph::ParameterVector func_inputs;
for (auto& vo : var_out_) {
PADDLE_ENFORCE_GT(var_node_map_->count(vo), 0,
"Cannot find vo %s in var_node_map_", vo);
func_outputs.emplace_back(var_node_map_->at(vo));
}
for (auto& vi : var_in_) {
PADDLE_ENFORCE_GT(var_node_map_->count(vi), 0,
"Cannot find vi %s in var_node_map_", vi);
std::shared_ptr<ngraph::op::Parameter> prm =
std::dynamic_pointer_cast<ngraph::op::Parameter>(
var_in_node_map_->at(vi));
......@@ -454,7 +463,8 @@ void NgraphEngine::ClearNgCache() {
auto it = engine_cache.begin();
while (it != engine_cache.end()) {
auto ng_engine = it->second;
backend_->remove_compiled_function(ng_engine.ngraph_handle);
ng_engine.ngraph_backend->remove_compiled_function(ng_engine.ngraph_handle);
ng_engine.ngraph_backend.reset();
++it;
}
engine_cache.clear();
......@@ -497,13 +507,6 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
if (engine_cache.find(func_cache_key_) != engine_cache.end()) {
if (engine_cache[func_cache_key_].persistables.size() == 0) {
ClearNgCache();
} else {
auto var_name = engine_cache[func_cache_key_].persistables.begin();
framework::Variable* var = scope_.FindVar(*var_name);
if (var != pre_var_ptr) {
ClearNgCache();
}
pre_var_ptr = var;
}
}
......@@ -515,6 +518,7 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
for (auto& r : func->get_results()) {
r->set_needs_default_layout(true);
}
engine_cache[func_cache_key_].ngraph_backend = backend_;
engine_cache[func_cache_key_].ngraph_handle = backend_->compile(func);
engine_cache[func_cache_key_].persistables = this->persistables_;
engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_;
......@@ -526,31 +530,32 @@ void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
void NgraphEngine::Run(const framework::Scope& scope,
const platform::Place& place) const {
VLOG(3) << "NgraphEngine Run ...";
std::shared_ptr<ngraph::runtime::Executable> ng_handle;
std::shared_ptr<ngraph::runtime::Backend> ng_backend;
const std::set<std::string>* p_persistables;
const std::vector<size_t>* p_var_in_updates;
const std::vector<std::string>* p_var_in;
const std::vector<std::string>* p_var_out;
bool is_test;
auto& engine_cache = main_engine_cache::fetch();
auto& t_in_cache_ = main_t_in_cache::fetch();
PADDLE_ENFORCE(engine_cache.find(func_cache_key_) != engine_cache.end(),
PADDLE_ENFORCE_GT(engine_cache.count(func_cache_key_), 0,
"Cannot find cached data to run ngraph function");
ng_handle = engine_cache[func_cache_key_].ngraph_handle;
ng_backend = engine_cache[func_cache_key_].ngraph_backend;
p_persistables = &(engine_cache[func_cache_key_].persistables);
p_var_in_updates = &(engine_cache[func_cache_key_].var_in_updates);
p_var_in = &(engine_cache[func_cache_key_].var_in);
p_var_out = &(engine_cache[func_cache_key_].var_out);
is_test = engine_cache[func_cache_key_].is_test;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>>* p_t_in;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in = {};
auto m_parameters = ng_handle->get_parameters();
auto m_results = ng_handle->get_results();
if (is_test && t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) {
if (is_inference_ && t_in_cache_.find(func_cache_key_) != t_in_cache_.end()) {
p_t_in = &(t_in_cache_[func_cache_key_]);
for (size_t i = 0; i < p_var_in_updates->size(); ++i) {
int index = p_var_in_updates->at(i);
......@@ -562,14 +567,14 @@ void NgraphEngine::Run(const framework::Scope& scope,
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
ti = backend_->create_tensor(ng_type, sp, pd_arr);
ti = ng_backend->create_tensor(ng_type, sp, pd_arr);
(*p_t_in)[index] = ti;
} else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vi);
}
}
} else {
if (is_test) {
if (is_inference_) {
p_t_in = &(t_in_cache_[func_cache_key_]);
} else {
p_t_in = &t_in;
......@@ -584,15 +589,13 @@ void NgraphEngine::Run(const framework::Scope& scope,
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetMutableLoDTensorOrSelectedRowsValueFromVar(var);
void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
PADDLE_ENFORCE(sp == Ddim2Shape(tensor_pd->dims()),
"Ensure ngraph tensor layout align with paddle tensor");
ti = backend_->create_tensor(ng_type, sp, pd_arr);
ti = ng_backend->create_tensor(ng_type, sp, pd_arr);
} else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vi);
}
bool is_persistable =
(p_persistables->find(vi) != p_persistables->end()) ? true : false;
if (!is_training && is_test && is_persistable) {
if (is_inference_ && is_persistable) {
ti->set_stale(false);
}
(*p_t_in).emplace_back(ti);
......@@ -615,7 +618,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
auto ng_type = m_results[i]->get_element_type();
void* pd_arr = tensor_pd->mutable_data(place, ng2pd_type_map[ng_type]);
std::shared_ptr<ngraph::runtime::Tensor> to =
backend_->create_tensor(ng_type, sp, pd_arr);
ng_backend->create_tensor(ng_type, sp, pd_arr);
t_out.emplace_back(to);
} else {
PADDLE_THROW("Cannot find var or tensor with var name %s", vo);
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <list>
#include <memory>
#include <mutex> //NOLINT
#include <set>
#include <string>
#include <unordered_map>
......@@ -34,7 +35,8 @@ namespace operators {
// cache of the engine's repetitive data
struct EngineCache {
std::shared_ptr<ngraph::runtime::Executable> ngraph_handle;
std::shared_ptr<ngraph::runtime::Executable> ngraph_handle = nullptr;
std::shared_ptr<ngraph::runtime::Backend> ngraph_backend = nullptr;
std::set<std::string> persistables;
std::vector<std::string> var_in;
std::vector<std::string> var_out;
......@@ -127,9 +129,7 @@ class NgraphEngine {
void Run(const framework::Scope& scope, const platform::Place& place) const;
static bool is_training;
static const framework::BlockDesc* p_bdesc;
static std::vector<std::string> feed_vars, fetch_vars;
static std::vector<std::string> feed_vars;
static void FuseNgraphOps(
const framework::BlockDesc& prog,
......@@ -149,19 +149,24 @@ class NgraphEngine {
using main_t_in_cache =
ThCache<std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>;
static framework::Variable* pre_var_ptr;
const framework::Scope& scope_;
const platform::Place& place_;
std::vector<std::shared_ptr<framework::OperatorBase>> fused_ops_;
std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
std::set<std::string> persistables_;
std::unordered_set<std::string> post_op_inputs_;
// whether a single run is a test run; it can also be a validation run during training
bool is_test_{true};
// inference only, e.g. CAPI inference
bool is_inference_{false};
std::string func_cache_key_;
// use a weak pointer to keep backend_ alive
// to avoid it being destroyed too early
static std::weak_ptr<ngraph::runtime::Backend> wp_backend_;
// use mutex to keep it thread safe
static std::mutex ng_mutex_;
// ngraph backend, e.g. CPU
static std::shared_ptr<ngraph::runtime::Backend> backend_;
std::shared_ptr<ngraph::runtime::Backend> backend_;
// var_name of inputs
std::vector<std::string> var_in_;
// var_name of outputs from fetch in order
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册