提交 8923612b 编写于 作者: B baojun 提交者: tensor-tang

NGraph enable parse serialized graph test=develop (#17453)

上级 cf5d271c
...@@ -158,6 +158,8 @@ static void SubstituteNgraphOp( ...@@ -158,6 +158,8 @@ static void SubstituteNgraphOp(
ng_op_desc.SetAttr("interval", interval); ng_op_desc.SetAttr("interval", interval);
ng_op_desc.SetAttr("engine_key", engine_key); ng_op_desc.SetAttr("engine_key", engine_key);
ng_op_desc.SetAttr("graph", block_str); ng_op_desc.SetAttr("graph", block_str);
ng_op_desc.SetInput("Xs", std::vector<std::string>(0));
ng_op_desc.SetOutput("Ys", std::vector<std::string>(0));
ops->erase(ops->begin() + interval[0], ops->begin() + interval[1]); ops->erase(ops->begin() + interval[0], ops->begin() + interval[1]);
ops->insert(ops->begin() + interval[0], ops->insert(ops->begin() + interval[0],
...@@ -223,20 +225,36 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope, ...@@ -223,20 +225,36 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope,
const platform::Place& place, const platform::Place& place,
const framework::ExecutionContext& ctx) const framework::ExecutionContext& ctx)
: scope_(scope), place_(place) { : scope_(scope), place_(place) {
std::string serialized_graph = ctx.Attr<std::string>("graph");
auto interval = ctx.Attr<std::vector<int>>("interval");
std::string engine_key = ctx.Attr<std::string>("engine_key");
var_in_node_map_ = std::make_shared< var_in_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>(); std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
var_node_map_ = std::make_shared< var_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>(); std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
GetNgFunction(engine_key, interval); GetNgFunction(ctx);
} }
void NgraphEngine::Prepare(const std::vector<int>& interval) { void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
auto interval = ctx.Attr<std::vector<int>>("interval");
std::string serialized_graph = ctx.Attr<std::string>("graph");
auto input_vars = ctx.Inputs("Xs");
if (!input_vars.empty()) {
feed_vars = input_vars;
var_in_ = input_vars;
}
auto output_vars = ctx.Outputs("Ys");
if (!output_vars.empty()) {
var_out_ = output_vars;
}
framework::proto::BlockDesc block_proto;
if (!serialized_graph.empty()) block_proto.ParseFromString(serialized_graph);
framework::BlockDesc block_desc(nullptr, &block_proto);
if (!serialized_graph.empty()) {
NgraphEngine::p_bdesc = &block_desc;
}
bool has_fetch = false, is_full = false; bool has_fetch = false, is_full = false;
for (auto& var : p_bdesc->AllVars()) { for (auto& var : p_bdesc->AllVars()) {
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS || if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
...@@ -316,7 +334,15 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) { ...@@ -316,7 +334,15 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
op_state_ = OpState::UNKNOWN; op_state_ = OpState::UNKNOWN;
} }
BuildNgIO(ops_desc, interval); if (var_in_.empty() && var_out_.empty()) {
BuildNgIO(ops_desc, interval);
}
for (size_t i = 0; i < var_in_.size(); ++i) {
auto var_name = var_in_[i];
if (persistables_.find(var_name) == persistables_.end()) {
var_in_updates_.emplace_back(i);
}
}
} }
void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc, void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
...@@ -392,13 +418,6 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc, ...@@ -392,13 +418,6 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
} }
} }
} }
for (size_t i = 0; i < var_in_.size(); ++i) {
auto var_name = var_in_[i];
if (persistables_.find(var_name) == persistables_.end()) {
var_in_updates_.emplace_back(i);
}
}
} }
void NgraphEngine::GetNgInputShape() { void NgraphEngine::GetNgInputShape() {
...@@ -434,7 +453,6 @@ void NgraphEngine::BuildNgNodes() { ...@@ -434,7 +453,6 @@ void NgraphEngine::BuildNgNodes() {
} }
} }
} }
NgraphBridge ngb(var_node_map_); NgraphBridge ngb(var_node_map_);
for (auto& op : fused_ops_) { for (auto& op : fused_ops_) {
ngb.BuildNgNode(op); ngb.BuildNgNode(op);
...@@ -448,8 +466,8 @@ void NgraphEngine::RunInferShape() { ...@@ -448,8 +466,8 @@ void NgraphEngine::RunInferShape() {
} }
} }
void NgraphEngine::BuildNgFunction(const std::vector<int>& interval) { void NgraphEngine::BuildNgFunction(const framework::ExecutionContext& ctx) {
Prepare(interval); Prepare(ctx);
RunInferShape(); RunInferShape();
GetNgInputShape(); GetNgInputShape();
BuildNgNodes(); BuildNgNodes();
...@@ -472,12 +490,13 @@ void NgraphEngine::BuildNgFunction(const std::vector<int>& interval) { ...@@ -472,12 +490,13 @@ void NgraphEngine::BuildNgFunction(const std::vector<int>& interval) {
std::make_shared<ngraph::Function>(func_outputs, func_inputs); std::make_shared<ngraph::Function>(func_outputs, func_inputs);
} }
void NgraphEngine::GetNgFunction(std::string engine_key, void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
const std::vector<int>& interval) { auto interval = ctx.Attr<std::vector<int>>("interval");
std::string engine_key = ctx.Attr<std::string>("engine_key");
bool use_cache = true; bool use_cache = true;
if (use_cache) { if (use_cache) {
this->func_cache_key_ = ""; this->func_cache_key_ = "";
for (int i = 0; i < std::min(static_cast<int>(feed_vars.size()), 10); ++i) { for (int i = 0; i < static_cast<int>(feed_vars.size()); ++i) {
auto* var = scope_.FindVar(feed_vars[i]); auto* var = scope_.FindVar(feed_vars[i]);
if (var && var->IsType<framework::LoDTensor>()) { if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var); auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
...@@ -507,7 +526,7 @@ void NgraphEngine::GetNgFunction(std::string engine_key, ...@@ -507,7 +526,7 @@ void NgraphEngine::GetNgFunction(std::string engine_key,
} }
if (engine_cache.find(func_cache_key_) == engine_cache.end()) { if (engine_cache.find(func_cache_key_) == engine_cache.end()) {
BuildNgFunction(interval); BuildNgFunction(ctx);
engine_cache[func_cache_key_].ngraph_function = this->ngraph_function_; engine_cache[func_cache_key_].ngraph_function = this->ngraph_function_;
engine_cache[func_cache_key_].persistables = this->persistables_; engine_cache[func_cache_key_].persistables = this->persistables_;
engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_; engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_;
...@@ -516,7 +535,7 @@ void NgraphEngine::GetNgFunction(std::string engine_key, ...@@ -516,7 +535,7 @@ void NgraphEngine::GetNgFunction(std::string engine_key,
engine_cache[func_cache_key_].is_test = this->is_test_; engine_cache[func_cache_key_].is_test = this->is_test_;
} }
} else { } else {
BuildNgFunction(interval); BuildNgFunction(ctx);
} }
} }
......
...@@ -101,7 +101,7 @@ class NgraphEngine { ...@@ -101,7 +101,7 @@ class NgraphEngine {
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_node_map_; var_node_map_;
// prepare info for ngraph engine need // prepare info for ngraph engine need
void Prepare(const std::vector<int>& interval); void Prepare(const framework::ExecutionContext& ctx);
// get ngraph engine input and output list // get ngraph engine input and output list
void BuildNgIO(const std::vector<framework::OpDesc*>& op_descs, void BuildNgIO(const std::vector<framework::OpDesc*>& op_descs,
const std::vector<int>& interval); const std::vector<int>& interval);
...@@ -112,9 +112,9 @@ class NgraphEngine { ...@@ -112,9 +112,9 @@ class NgraphEngine {
// run paddle RuntimeInferShape to get the tensor shape // run paddle RuntimeInferShape to get the tensor shape
void RunInferShape(); void RunInferShape();
// build ngraph function call // build ngraph function call
void BuildNgFunction(const std::vector<int>& interval); void BuildNgFunction(const framework::ExecutionContext& ctx);
// Check cache for ngraph function or otherwise build the function // Check cache for ngraph function or otherwise build the function
void GetNgFunction(std::string engine_key, const std::vector<int>& interval); void GetNgFunction(const framework::ExecutionContext& ctx);
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册