提交 8923612b 编写于 作者: B baojun 提交者: tensor-tang

NGraph enable parse serialized graph test=develop (#17453)

上级 cf5d271c
......@@ -158,6 +158,8 @@ static void SubstituteNgraphOp(
ng_op_desc.SetAttr("interval", interval);
ng_op_desc.SetAttr("engine_key", engine_key);
ng_op_desc.SetAttr("graph", block_str);
ng_op_desc.SetInput("Xs", std::vector<std::string>(0));
ng_op_desc.SetOutput("Ys", std::vector<std::string>(0));
ops->erase(ops->begin() + interval[0], ops->begin() + interval[1]);
ops->insert(ops->begin() + interval[0],
......@@ -223,20 +225,36 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope,
const platform::Place& place,
const framework::ExecutionContext& ctx)
: scope_(scope), place_(place) {
std::string serialized_graph = ctx.Attr<std::string>("graph");
auto interval = ctx.Attr<std::vector<int>>("interval");
std::string engine_key = ctx.Attr<std::string>("engine_key");
var_in_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
var_node_map_ = std::make_shared<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
GetNgFunction(engine_key, interval);
GetNgFunction(ctx);
}
void NgraphEngine::Prepare(const std::vector<int>& interval) {
void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
auto interval = ctx.Attr<std::vector<int>>("interval");
std::string serialized_graph = ctx.Attr<std::string>("graph");
auto input_vars = ctx.Inputs("Xs");
if (!input_vars.empty()) {
feed_vars = input_vars;
var_in_ = input_vars;
}
auto output_vars = ctx.Outputs("Ys");
if (!output_vars.empty()) {
var_out_ = output_vars;
}
framework::proto::BlockDesc block_proto;
if (!serialized_graph.empty()) block_proto.ParseFromString(serialized_graph);
framework::BlockDesc block_desc(nullptr, &block_proto);
if (!serialized_graph.empty()) {
NgraphEngine::p_bdesc = &block_desc;
}
bool has_fetch = false, is_full = false;
for (auto& var : p_bdesc->AllVars()) {
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
......@@ -316,7 +334,15 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
op_state_ = OpState::UNKNOWN;
}
BuildNgIO(ops_desc, interval);
if (var_in_.empty() && var_out_.empty()) {
BuildNgIO(ops_desc, interval);
}
for (size_t i = 0; i < var_in_.size(); ++i) {
auto var_name = var_in_[i];
if (persistables_.find(var_name) == persistables_.end()) {
var_in_updates_.emplace_back(i);
}
}
}
void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
......@@ -392,13 +418,6 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
}
}
}
for (size_t i = 0; i < var_in_.size(); ++i) {
auto var_name = var_in_[i];
if (persistables_.find(var_name) == persistables_.end()) {
var_in_updates_.emplace_back(i);
}
}
}
void NgraphEngine::GetNgInputShape() {
......@@ -434,7 +453,6 @@ void NgraphEngine::BuildNgNodes() {
}
}
}
NgraphBridge ngb(var_node_map_);
for (auto& op : fused_ops_) {
ngb.BuildNgNode(op);
......@@ -448,8 +466,8 @@ void NgraphEngine::RunInferShape() {
}
}
void NgraphEngine::BuildNgFunction(const std::vector<int>& interval) {
Prepare(interval);
void NgraphEngine::BuildNgFunction(const framework::ExecutionContext& ctx) {
Prepare(ctx);
RunInferShape();
GetNgInputShape();
BuildNgNodes();
......@@ -472,12 +490,13 @@ void NgraphEngine::BuildNgFunction(const std::vector<int>& interval) {
std::make_shared<ngraph::Function>(func_outputs, func_inputs);
}
void NgraphEngine::GetNgFunction(std::string engine_key,
const std::vector<int>& interval) {
void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
auto interval = ctx.Attr<std::vector<int>>("interval");
std::string engine_key = ctx.Attr<std::string>("engine_key");
bool use_cache = true;
if (use_cache) {
this->func_cache_key_ = "";
for (int i = 0; i < std::min(static_cast<int>(feed_vars.size()), 10); ++i) {
for (int i = 0; i < static_cast<int>(feed_vars.size()); ++i) {
auto* var = scope_.FindVar(feed_vars[i]);
if (var && var->IsType<framework::LoDTensor>()) {
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
......@@ -507,7 +526,7 @@ void NgraphEngine::GetNgFunction(std::string engine_key,
}
if (engine_cache.find(func_cache_key_) == engine_cache.end()) {
BuildNgFunction(interval);
BuildNgFunction(ctx);
engine_cache[func_cache_key_].ngraph_function = this->ngraph_function_;
engine_cache[func_cache_key_].persistables = this->persistables_;
engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_;
......@@ -516,7 +535,7 @@ void NgraphEngine::GetNgFunction(std::string engine_key,
engine_cache[func_cache_key_].is_test = this->is_test_;
}
} else {
BuildNgFunction(interval);
BuildNgFunction(ctx);
}
}
......
......@@ -101,7 +101,7 @@ class NgraphEngine {
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_node_map_;
// prepare info for ngraph engine need
void Prepare(const std::vector<int>& interval);
void Prepare(const framework::ExecutionContext& ctx);
// get ngraph engine input and output list
void BuildNgIO(const std::vector<framework::OpDesc*>& op_descs,
const std::vector<int>& interval);
......@@ -112,9 +112,9 @@ class NgraphEngine {
// run paddle RuntimeInferShape to get the tensor shape
void RunInferShape();
// build ngraph function call
void BuildNgFunction(const std::vector<int>& interval);
void BuildNgFunction(const framework::ExecutionContext& ctx);
// Check cache for ngraph function or otherwise build the function
void GetNgFunction(std::string engine_key, const std::vector<int>& interval);
void GetNgFunction(const framework::ExecutionContext& ctx);
};
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册