diff --git a/paddle/fluid/framework/ngraph_bridge.cc b/paddle/fluid/framework/ngraph_bridge.cc index 45ef0211ad2e84709884973be223703d8e929d41..e22c29037718a60ff7f24404d7749600e2edb80b 100644 --- a/paddle/fluid/framework/ngraph_bridge.cc +++ b/paddle/fluid/framework/ngraph_bridge.cc @@ -111,7 +111,7 @@ std::map}, {"tanh", BuildUnaryNode}}; -void NgraphBridge::BuildNgGraph(const std::shared_ptr& op) { +void NgraphBridge::BuildNgNode(const std::shared_ptr& op) { auto& op_type = op->Type(); NG_NODE_MAP[op_type](op, ngb_node_map_); } diff --git a/paddle/fluid/framework/ngraph_bridge.h b/paddle/fluid/framework/ngraph_bridge.h index 3cf62b6daab96537435bdb61f8dd3b9c8fc80222..9ed6b9510942136a61faa5755fd8fa74286939a8 100644 --- a/paddle/fluid/framework/ngraph_bridge.h +++ b/paddle/fluid/framework/ngraph_bridge.h @@ -43,7 +43,7 @@ class NgraphBridge { var_node_map) : ngb_node_map_(var_node_map) {} - void BuildNgGraph(const std::shared_ptr& op); + void BuildNgNode(const std::shared_ptr& op); private: std::shared_ptr< diff --git a/paddle/fluid/framework/ngraph_operator.cc b/paddle/fluid/framework/ngraph_operator.cc index 1c770a2370926c6b32236537cbb265e8a9eaa468..3fea753f0659019395c9b214e52a7912058c501c 100644 --- a/paddle/fluid/framework/ngraph_operator.cc +++ b/paddle/fluid/framework/ngraph_operator.cc @@ -122,7 +122,7 @@ class NgraphOperator { // get ngraph input and define ngraph input parameters void GetNgInputShape(std::shared_ptr op); // Call ngraph bridge to map ops - void BuildNgNode(); + void BuildNgNodes(); // get the ngraph input and output var list void BuildNgIO(); // build ngraph function call @@ -301,7 +301,7 @@ void NgraphOperator::GetNgInputShape(std::shared_ptr op) { } } -void NgraphOperator::BuildNgNode() { +void NgraphOperator::BuildNgNodes() { for (auto& var_name : var_out_) { if (var_node_map_->find(var_name) == var_node_map_->end()) { auto* var = scope_.FindVar(var_name); @@ -319,7 +319,7 @@ void NgraphOperator::BuildNgNode() { paddle::framework::NgraphBridge ngb(var_node_map_); for (auto& op : fused_ops_) { - ngb.BuildNgGraph(op); + ngb.BuildNgNode(op); } } @@ -396,7 +396,7 @@ void NgraphOperator::BuildNgIO() { } void NgraphOperator::BuildNgFunction() { - BuildNgNode(); + BuildNgNodes(); ngraph_function_ = nullptr; ngraph::NodeVector func_outputs; ngraph::op::ParameterVector func_inputs;