提交 54893145 编写于 作者: Z Zhen Wang

update some functions' names according to the suggestion. test=develop

上级 9261cf39
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/pybind/ir.h" #include "paddle/fluid/pybind/ir.h"
#include <algorithm> #include <algorithm>
#include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
...@@ -115,7 +116,7 @@ void BindNode(py::module *m) { ...@@ -115,7 +116,7 @@ void BindNode(py::module *m) {
.def("is_var", &Node::IsVar) .def("is_var", &Node::IsVar)
.def("is_ctrl_var", &Node::IsCtrlVar) .def("is_ctrl_var", &Node::IsCtrlVar)
.def("clear_inputs", [](Node &self) { self.inputs.clear(); }) .def("clear_inputs", [](Node &self) { self.inputs.clear(); })
.def("inputs_remove", .def("remove_input",
[](Node &self, int node_id) { [](Node &self, int node_id) {
auto pos = std::find_if( auto pos = std::find_if(
self.inputs.begin(), self.inputs.end(), self.inputs.begin(), self.inputs.end(),
...@@ -124,7 +125,7 @@ void BindNode(py::module *m) { ...@@ -124,7 +125,7 @@ void BindNode(py::module *m) {
self.inputs.erase(pos); self.inputs.erase(pos);
} }
}) })
.def("inputs_remove", .def("remove_input",
[](Node &self, Node &node) { [](Node &self, Node &node) {
auto pos = auto pos =
std::find(self.inputs.begin(), self.inputs.end(), &node); std::find(self.inputs.begin(), self.inputs.end(), &node);
...@@ -132,10 +133,10 @@ void BindNode(py::module *m) { ...@@ -132,10 +133,10 @@ void BindNode(py::module *m) {
self.inputs.erase(pos); self.inputs.erase(pos);
} }
}) })
.def("inputs_append", .def("append_input",
[](Node &self, Node &node) { self.inputs.push_back(&node); }) [](Node &self, Node &node) { self.inputs.push_back(&node); })
.def("clear_outputs", [](Node &self) { self.outputs.clear(); }) .def("clear_outputs", [](Node &self) { self.outputs.clear(); })
.def("outputs_remove", .def("remove_output",
[](Node &self, int node_id) { [](Node &self, int node_id) {
auto pos = std::find_if( auto pos = std::find_if(
self.outputs.begin(), self.outputs.end(), self.outputs.begin(), self.outputs.end(),
...@@ -144,7 +145,7 @@ void BindNode(py::module *m) { ...@@ -144,7 +145,7 @@ void BindNode(py::module *m) {
self.outputs.erase(pos); self.outputs.erase(pos);
} }
}) })
.def("outputs_remove", .def("remove_output",
[](Node &self, Node &node) { [](Node &self, Node &node) {
auto pos = auto pos =
std::find(self.outputs.begin(), self.outputs.end(), &node); std::find(self.outputs.begin(), self.outputs.end(), &node);
...@@ -152,7 +153,7 @@ void BindNode(py::module *m) { ...@@ -152,7 +153,7 @@ void BindNode(py::module *m) {
self.outputs.erase(pos); self.outputs.erase(pos);
} }
}) })
.def("outputs_append", .def("append_output",
[](Node &self, Node &node) { self.outputs.push_back(&node); }) [](Node &self, Node &node) { self.outputs.push_back(&node); })
.def_readwrite("inputs", &Node::inputs) .def_readwrite("inputs", &Node::inputs)
.def_readwrite("outputs", &Node::outputs); .def_readwrite("outputs", &Node::outputs);
......
...@@ -1640,25 +1640,25 @@ class IrNode(object): ...@@ -1640,25 +1640,25 @@ class IrNode(object):
Args: Args:
node_id(int): the given node id. node_id(int): the given node id.
""" """
self.node.inputs_remove(node_id) self.node.remove_input(node_id)
def inputs_remove(self, ir_node): def remove_input(self, node):
""" """
Remove a node from inputs. Remove a node from inputs.
Args: Args:
ir_node(IrNode): the node being removed. node(IrNode): the node being removed.
""" """
self.node.inputs_remove(ir_node.node) self.node.remove_input(node.node)
def inputs_append(self, ir_node): def append_input(self, node):
""" """
Append a node in inputs. Append a node in inputs.
Args: Args:
ir_node(IrNode): the node being appended. node(IrNode): the node being appended.
""" """
self.node.inputs_append(ir_node.node) self.node.append_input(node.node)
def clear_outputs(self): def clear_outputs(self):
""" """
...@@ -1667,32 +1667,32 @@ class IrNode(object): ...@@ -1667,32 +1667,32 @@ class IrNode(object):
""" """
self.node.clear_outputs() self.node.clear_outputs()
def outputs_remove_by_id(self, node_id): def remove_output_by_id(self, node_id):
""" """
Remove a node from outputs by the given node id. Remove a node from outputs by the given node id.
Args: Args:
node_id(int): the given node id. node_id(int): the given node id.
""" """
self.node.outputs_remove(node_id) self.node.remove_output(node_id)
def outputs_remove(self, ir_node): def remove_output(self, node):
""" """
Remove a node from outputs. Remove a node from outputs.
Args: Args:
ir_node(IrNode): the node being removed. node(IrNode): the node being removed.
""" """
self.node.outputs_remove(ir_node.node) self.node.remove_output(node.node)
def outputs_append(self, ir_node): def append_output(self, node):
""" """
Append a node in outputs. Append a node in outputs.
Args: Args:
ir_node(IrNode): the node being appended. node(IrNode): the node being appended.
""" """
self.node.outputs_append(ir_node.node) self.node.append_output(node.node)
@property @property
def inputs(self): def inputs(self):
...@@ -2116,10 +2116,10 @@ class IrGraph(object): ...@@ -2116,10 +2116,10 @@ class IrGraph(object):
assert old_input_node.node in self.graph.nodes() and new_input_node.node in \ assert old_input_node.node in self.graph.nodes() and new_input_node.node in \
self.graph.nodes() and op_node.node in self.graph.nodes(), \ self.graph.nodes() and op_node.node in self.graph.nodes(), \
'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.' 'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.'
old_input_node.outputs_remove(op_node) old_input_node.remove_output(op_node)
op_node.inputs_remove(old_input_node) op_node.remove_input(old_input_node)
new_input_node.outputs_append(op_node) new_input_node.append_output(op_node)
op_node.inputs_append(new_input_node) op_node.append_input(new_input_node)
op_node.rename_input(old_input_node.name(), new_input_node.name()) op_node.rename_input(old_input_node.name(), new_input_node.name())
def link_to(self, node_in, node_out): def link_to(self, node_in, node_out):
...@@ -2132,8 +2132,8 @@ class IrGraph(object): ...@@ -2132,8 +2132,8 @@ class IrGraph(object):
""" """
assert node_in.node in self.graph.nodes() and node_out.node in self.graph.nodes(), \ assert node_in.node in self.graph.nodes() and node_out.node in self.graph.nodes(), \
'The two arguments(node_in&node_out) must be in the graph nodes.' 'The two arguments(node_in&node_out) must be in the graph nodes.'
node_in.outputs_append(node_out) node_in.append_output(node_out)
node_out.inputs_append(node_in) node_out.append_input(node_in)
def safe_remove_nodes(self, remove_nodes): def safe_remove_nodes(self, remove_nodes):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册