From 8923612b102177bcc58b1933c23afdce0cf3eb3f Mon Sep 17 00:00:00 2001 From: baojun <32073718+baojun-nervana@users.noreply.github.com> Date: Tue, 21 May 2019 21:52:30 -0700 Subject: [PATCH] NGraph enable parse serialized graph test=develop (#17453) --- .../fluid/operators/ngraph/ngraph_engine.cc | 63 ++++++++++++------- paddle/fluid/operators/ngraph/ngraph_engine.h | 6 +- 2 files changed, 44 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.cc b/paddle/fluid/operators/ngraph/ngraph_engine.cc index c2dce51fe54..2486ae6bb5c 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine.cc +++ b/paddle/fluid/operators/ngraph/ngraph_engine.cc @@ -158,6 +158,8 @@ static void SubstituteNgraphOp( ng_op_desc.SetAttr("interval", interval); ng_op_desc.SetAttr("engine_key", engine_key); ng_op_desc.SetAttr("graph", block_str); + ng_op_desc.SetInput("Xs", std::vector(0)); + ng_op_desc.SetOutput("Ys", std::vector(0)); ops->erase(ops->begin() + interval[0], ops->begin() + interval[1]); ops->insert(ops->begin() + interval[0], @@ -223,20 +225,36 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope, const platform::Place& place, const framework::ExecutionContext& ctx) : scope_(scope), place_(place) { - std::string serialized_graph = ctx.Attr("graph"); - auto interval = ctx.Attr>("interval"); - std::string engine_key = ctx.Attr("engine_key"); - var_in_node_map_ = std::make_shared< std::unordered_map>>(); var_node_map_ = std::make_shared< std::unordered_map>>(); - GetNgFunction(engine_key, interval); + GetNgFunction(ctx); } -void NgraphEngine::Prepare(const std::vector& interval) { +void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) { + auto interval = ctx.Attr>("interval"); + std::string serialized_graph = ctx.Attr("graph"); + + auto input_vars = ctx.Inputs("Xs"); + if (!input_vars.empty()) { + feed_vars = input_vars; + var_in_ = input_vars; + } + auto output_vars = ctx.Outputs("Ys"); + if (!output_vars.empty()) { + var_out_ = output_vars; + } + + framework::proto::BlockDesc block_proto; + if (!serialized_graph.empty()) block_proto.ParseFromString(serialized_graph); + framework::BlockDesc block_desc(nullptr, &block_proto); + if (!serialized_graph.empty()) { + NgraphEngine::p_bdesc = &block_desc; + } + bool has_fetch = false, is_full = false; for (auto& var : p_bdesc->AllVars()) { if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS || @@ -316,7 +334,15 @@ void NgraphEngine::Prepare(const std::vector& interval) { op_state_ = OpState::UNKNOWN; } - BuildNgIO(ops_desc, interval); + if (var_in_.empty() && var_out_.empty()) { + BuildNgIO(ops_desc, interval); + } + for (size_t i = 0; i < var_in_.size(); ++i) { + auto var_name = var_in_[i]; + if (persistables_.find(var_name) == persistables_.end()) { + var_in_updates_.emplace_back(i); + } + } } void NgraphEngine::BuildNgIO(const std::vector& ops_desc, @@ -392,13 +418,6 @@ void NgraphEngine::BuildNgIO(const std::vector& ops_desc, } } } - - for (size_t i = 0; i < var_in_.size(); ++i) { - auto var_name = var_in_[i]; - if (persistables_.find(var_name) == persistables_.end()) { - var_in_updates_.emplace_back(i); - } - } } void NgraphEngine::GetNgInputShape() { @@ -434,7 +453,6 @@ void NgraphEngine::BuildNgNodes() { } } } - NgraphBridge ngb(var_node_map_); for (auto& op : fused_ops_) { ngb.BuildNgNode(op); @@ -448,8 +466,8 @@ void NgraphEngine::RunInferShape() { } } -void NgraphEngine::BuildNgFunction(const std::vector& interval) { - Prepare(interval); +void NgraphEngine::BuildNgFunction(const framework::ExecutionContext& ctx) { + Prepare(ctx); RunInferShape(); GetNgInputShape(); BuildNgNodes(); @@ -472,12 +490,13 @@ void NgraphEngine::BuildNgFunction(const std::vector& interval) { std::make_shared(func_outputs, func_inputs); } -void NgraphEngine::GetNgFunction(std::string engine_key, - const std::vector& interval) { +void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) { + auto interval = ctx.Attr>("interval"); + std::string engine_key = ctx.Attr("engine_key"); bool use_cache = true; if (use_cache) { this->func_cache_key_ = ""; - for (int i = 0; i < std::min(static_cast(feed_vars.size()), 10); ++i) { + for (int i = 0; i < static_cast(feed_vars.size()); ++i) { auto* var = scope_.FindVar(feed_vars[i]); if (var && var->IsType()) { auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var); @@ -507,7 +526,7 @@ void NgraphEngine::GetNgFunction(std::string engine_key, } if (engine_cache.find(func_cache_key_) == engine_cache.end()) { - BuildNgFunction(interval); + BuildNgFunction(ctx); engine_cache[func_cache_key_].ngraph_function = this->ngraph_function_; engine_cache[func_cache_key_].persistables = this->persistables_; engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_; @@ -516,7 +535,7 @@ void NgraphEngine::GetNgFunction(std::string engine_key, engine_cache[func_cache_key_].is_test = this->is_test_; } } else { - BuildNgFunction(interval); + BuildNgFunction(ctx); } } diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.h b/paddle/fluid/operators/ngraph/ngraph_engine.h index 19400ac5b0e..0e36204a447 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine.h +++ b/paddle/fluid/operators/ngraph/ngraph_engine.h @@ -101,7 +101,7 @@ class NgraphEngine { std::unordered_map>> var_node_map_; // prepare info for ngraph engine need - void Prepare(const std::vector& interval); + void Prepare(const framework::ExecutionContext& ctx); // get ngraph engine input and output list void BuildNgIO(const std::vector& op_descs, const std::vector& interval); @@ -112,9 +112,9 @@ class NgraphEngine { // run paddle RuntimeInferShape to get the tensor shape void RunInferShape(); // build ngraph function call - void BuildNgFunction(const std::vector& interval); + void BuildNgFunction(const framework::ExecutionContext& ctx); // Check cache for ngraph function or otherwise build the function - void GetNgFunction(std::string engine_key, const std::vector& interval); + void GetNgFunction(const framework::ExecutionContext& ctx); }; } // namespace operators -- GitLab