提交 54893145 编写于 作者: Z Zhen Wang

update some functions' names according to the suggestion. test=develop

上级 9261cf39
......@@ -14,6 +14,7 @@
#include "paddle/fluid/pybind/ir.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
......@@ -115,7 +116,7 @@ void BindNode(py::module *m) {
.def("is_var", &Node::IsVar)
.def("is_ctrl_var", &Node::IsCtrlVar)
.def("clear_inputs", [](Node &self) { self.inputs.clear(); })
.def("inputs_remove",
.def("remove_input",
[](Node &self, int node_id) {
auto pos = std::find_if(
self.inputs.begin(), self.inputs.end(),
......@@ -124,7 +125,7 @@ void BindNode(py::module *m) {
self.inputs.erase(pos);
}
})
.def("inputs_remove",
.def("remove_input",
[](Node &self, Node &node) {
auto pos =
std::find(self.inputs.begin(), self.inputs.end(), &node);
......@@ -132,10 +133,10 @@ void BindNode(py::module *m) {
self.inputs.erase(pos);
}
})
.def("inputs_append",
.def("append_input",
[](Node &self, Node &node) { self.inputs.push_back(&node); })
.def("clear_outputs", [](Node &self) { self.outputs.clear(); })
.def("outputs_remove",
.def("remove_output",
[](Node &self, int node_id) {
auto pos = std::find_if(
self.outputs.begin(), self.outputs.end(),
......@@ -144,7 +145,7 @@ void BindNode(py::module *m) {
self.outputs.erase(pos);
}
})
.def("outputs_remove",
.def("remove_output",
[](Node &self, Node &node) {
auto pos =
std::find(self.outputs.begin(), self.outputs.end(), &node);
......@@ -152,7 +153,7 @@ void BindNode(py::module *m) {
self.outputs.erase(pos);
}
})
.def("outputs_append",
.def("append_output",
[](Node &self, Node &node) { self.outputs.push_back(&node); })
.def_readwrite("inputs", &Node::inputs)
.def_readwrite("outputs", &Node::outputs);
......
......@@ -1640,25 +1640,25 @@ class IrNode(object):
Args:
node_id(int): the given node id.
"""
self.node.inputs_remove(node_id)
self.node.remove_input(node_id)
def inputs_remove(self, ir_node):
def remove_input(self, node):
"""
Remove a node from inputs.
Args:
ir_node(IrNode): the node being removed.
node(IrNode): the node being removed.
"""
self.node.inputs_remove(ir_node.node)
self.node.remove_input(node.node)
def inputs_append(self, ir_node):
def append_input(self, node):
"""
Append a node in inputs.
Args:
ir_node(IrNode): the node being appended.
node(IrNode): the node being appended.
"""
self.node.inputs_append(ir_node.node)
self.node.append_input(node.node)
def clear_outputs(self):
"""
......@@ -1667,32 +1667,32 @@ class IrNode(object):
"""
self.node.clear_outputs()
def outputs_remove_by_id(self, node_id):
def remove_output_by_id(self, node_id):
"""
Remove a node from outputs by the given node id.
Args:
node_id(int): the given node id.
"""
self.node.outputs_remove(node_id)
self.node.remove_output(node_id)
def outputs_remove(self, ir_node):
def remove_output(self, node):
"""
Remove a node from outputs.
Args:
ir_node(IrNode): the node being removed.
node(IrNode): the node being removed.
"""
self.node.outputs_remove(ir_node.node)
self.node.remove_output(node.node)
def outputs_append(self, ir_node):
def append_output(self, node):
"""
Append a node in outputs.
Args:
ir_node(IrNode): the node being appended.
node(IrNode): the node being appended.
"""
self.node.outputs_append(ir_node.node)
self.node.append_output(node.node)
@property
def inputs(self):
......@@ -2116,10 +2116,10 @@ class IrGraph(object):
assert old_input_node.node in self.graph.nodes() and new_input_node.node in \
self.graph.nodes() and op_node.node in self.graph.nodes(), \
'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.'
old_input_node.outputs_remove(op_node)
op_node.inputs_remove(old_input_node)
new_input_node.outputs_append(op_node)
op_node.inputs_append(new_input_node)
old_input_node.remove_output(op_node)
op_node.remove_input(old_input_node)
new_input_node.append_output(op_node)
op_node.append_input(new_input_node)
op_node.rename_input(old_input_node.name(), new_input_node.name())
def link_to(self, node_in, node_out):
......@@ -2132,8 +2132,8 @@ class IrGraph(object):
"""
assert node_in.node in self.graph.nodes() and node_out.node in self.graph.nodes(), \
'The two arguments(node_in&node_out) must be in the graph nodes.'
node_in.outputs_append(node_out)
node_out.inputs_append(node_in)
node_in.append_output(node_out)
node_out.append_input(node_in)
def safe_remove_nodes(self, remove_nodes):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册