提交 33f99d61 编写于 作者: Z Zhen Wang

add IrNode&IrVarNode&IrOpNode. test=develop

上级 d8128930
......@@ -17,7 +17,9 @@ import numpy as np
import six
from ..... import compat as cpt
from .... import core
from .... import Executor
from ....framework import IrGraph
from ....framework import IrNode
from ....framework import Program
from ....initializer import Constant
from .... import unique_name
......@@ -31,7 +33,7 @@ __all__ = [
class QuantizationTransformPass(object):
def __init__(self,
scope=None,
program_exe=None,
place=None,
weight_bits=8,
activation_bits=8,
activation_quantize_type='abs_max',
......@@ -45,7 +47,7 @@ class QuantizationTransformPass(object):
scope(fluid.Scope): When activation use 'range_abs_max' as the quantize
type, this pass will create some new parameters. The scope is used to
initialize these new parameters.
program_exe(fluid.Executor): program_exe is used to initialize new
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new
parameters described above.
weight_bits (int): quantization bit number for weights,
the bias is not quantized.
......@@ -71,13 +73,13 @@ class QuantizationTransformPass(object):
from paddle.fluid import core
graph = IrGraph(core.Graph(program.desc), for_test=False)
exe = fluid.Executor(fluid.CPUPlace())
place = fluid.CPUPlace()
transform_pass = QuantizationTransformPass(fluid.global_scope(),
exe)
place)
transform_pass.apply(graph)
"""
self._scope = scope
self._program_exe = program_exe
self._place = place
self._weight_bits = weight_bits
self._activation_bits = activation_bits
......@@ -118,7 +120,7 @@ class QuantizationTransformPass(object):
self._is_test = graph.is_test()
# marked the variable which has been dequantized.
dequantized_vars = collections.OrderedDict()
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
def _transform_forward(graph, op):
for var_node in op.inputs:
......@@ -149,7 +151,7 @@ class QuantizationTransformPass(object):
if not self._is_test:
self._create_global_step(graph)
ops = graph.all_ops()
ops = graph.all_op_nodes()
# The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph:
for op in ops:
......@@ -163,8 +165,8 @@ class QuantizationTransformPass(object):
if len(self._need_initialized) > 0:
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._program_exe is not None, \
'The program_exe cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._place is not None, \
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
init_program = Program()
for var_desc, initializer in six.iteritems(self._need_initialized):
var = init_program.global_block().create_var(
......@@ -175,7 +177,8 @@ class QuantizationTransformPass(object):
lod_level=var_desc.lod_level(),
persistable=var_desc.persistable())
initializer(var, init_program.global_block())
self._program_exe.run(program=init_program, scope=self._scope)
exe = Executor(self._place)
exe.run(program=init_program, scope=self._scope)
return graph
......@@ -183,11 +186,11 @@ class QuantizationTransformPass(object):
if self._weight_quantize_type == 'range_abs_max' or \
self._activation_quantize_type == 'range_abs_max':
counter_name = cpt.to_text('@STEP_COUNTER@')
for node in graph.all_vars():
for node in graph.all_var_nodes():
if node.name() == counter_name:
self._global_step = node
if self._global_step is None:
global_step_in = graph.create_param_node(
global_step_in = graph.create_persistable_node(
name=counter_name,
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
......@@ -262,7 +265,7 @@ class QuantizationTransformPass(object):
shape=var_node.var().shape(),
var_dtype=var_node.var().dtype())
scale_in_node = graph.create_param_node(
scale_in_node = graph.create_persistable_node(
name=self._quantized_scale_name(var_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
......@@ -275,7 +278,7 @@ class QuantizationTransformPass(object):
if not self._is_test:
# The name of scales_var_node maybe 'scales_0', 'scales_1', etc.
scales_node = graph.create_param_node(
scales_node = graph.create_persistable_node(
name=unique_name.generate('scales'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[self._window_size],
......@@ -400,8 +403,8 @@ class QuantizationFreezePass(object):
Args:
graph(IrGraph): the applied graph.
"""
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
ops = graph.all_ops()
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
ops = graph.all_op_nodes()
for op_node in ops:
op_name = op_node.name()
if op_name in self._fake_quant_op_names:
......@@ -425,13 +428,13 @@ class QuantizationFreezePass(object):
self._weight_bits)
self._restore_var(input_arg_name, quantized_param_v)
ops = graph.all_ops()
ops = graph.all_op_nodes()
for op_node in ops:
op_name = op_node.name()
if op_name in self._fake_dequant_op_names:
self._remove_fake_quant_and_dequant_op(graph, op_node)
ops = graph.all_ops()
ops = graph.all_op_nodes()
for op_node in ops:
op_name = op_node.name()
if op_name in self._quantizable_ops:
......@@ -462,7 +465,7 @@ class QuantizationFreezePass(object):
def _insert_post_dequant_op(self, graph, op_node):
max_range = None
scale_var_node = None
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
for var_node in op_node.inputs:
name = var_node.name()
if name in self._op_input_rename_map:
......@@ -480,7 +483,7 @@ class QuantizationFreezePass(object):
original_var_name)
max_range = param_range * act_range / scale_v
else:
assert isinstance(scale_v, core.Node)
assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name]
if len(op_node.outputs) != 1:
......@@ -517,14 +520,19 @@ class QuantizationFreezePass(object):
def _remove_unused_var_nodes(self, graph):
all_used_vars = set()
ops = graph.all_ops()
ops = graph.all_op_nodes()
for op_node in ops:
for input_node in op_node.inputs:
all_used_vars.add(input_node)
for output_node in op_node.outputs:
all_used_vars.add(output_node)
all_unused_vars = graph.all_vars() - all_used_vars
all_used_vars = {n.node for n in all_used_vars}
all_unused_vars = {
n
for n in filter(lambda node: node.node not in all_used_vars,
graph.all_var_nodes())
}
graph.safe_remove_nodes(all_unused_vars)
def _original_var_name(self, var_name):
......@@ -583,8 +591,8 @@ class ConvertToInt8Pass(object):
Args:
graph(IrGraph): the applied graph.
"""
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
ops = graph.all_ops()
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
ops = graph.all_op_nodes()
input_map = {}
for op_node in ops:
op_name = op_node.name()
......@@ -605,7 +613,7 @@ class ConvertToInt8Pass(object):
def _convert_to_int8(self, graph, var_node):
int8_var_node_name = var_node.name() + ".int8"
int8_var_node = graph.create_param_node(
int8_var_node = graph.create_persistable_node(
name=cpt.to_text(int8_var_node_name),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
......@@ -624,14 +632,19 @@ class ConvertToInt8Pass(object):
def _remove_unused_var_nodes(self, graph):
all_used_vars = set()
ops = graph.all_ops()
ops = graph.all_op_nodes()
for op_node in ops:
for input_node in op_node.inputs:
all_used_vars.add(input_node)
for output_node in op_node.outputs:
all_used_vars.add(output_node)
all_unused_vars = graph.all_vars() - all_used_vars
all_used_vars = {n.node for n in all_used_vars}
all_unused_vars = {
n
for n in filter(lambda node: node.node not in all_used_vars,
graph.all_var_nodes())
}
graph.safe_remove_nodes(all_unused_vars)
......@@ -655,7 +668,7 @@ class TransformForMobilePass(object):
Args:
graph(IrGraph): the graph will be transformed.
"""
ops = graph.all_ops()
ops = graph.all_op_nodes()
for op_node in ops:
name = op_node.name()
if name in self._fake_quant_op_names:
......
......@@ -61,16 +61,16 @@ class TestGraph(unittest.TestCase):
opt.minimize(loss)
graph = IrGraph(core.Graph(main.desc), for_test=False)
marked_nodes = set()
for op in graph.all_ops():
for op in graph.all_op_nodes():
if op.name().find('conv2d') > -1:
marked_nodes.add(op)
graph.draw('.', 'residual', marked_nodes)
self.assertFalse(graph.has_circle())
self.assertEqual(graph.graph_num(), 1)
nodes = graph.topology_sort()
self.assertEqual(len(nodes), len(graph.all_ops()))
self.assertEqual(len(nodes), len(graph.all_op_nodes()))
nodes_map = graph.build_adjacency_list()
self.assertEqual(len(nodes_map), len(graph.all_ops()))
self.assertEqual(len(nodes_map), len(graph.all_op_nodes()))
nodes_num = len(graph.all_nodes())
graph.safe_remove_nodes(marked_nodes)
self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes))
......
......@@ -130,15 +130,16 @@ class TestQuantizationTransformPass(unittest.TestCase):
loss = linear_fc(3)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
place = fluid.CPUPlace()
exe = fluid.Executor(place)
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
program_exe=exe,
place=place,
activation_quantize_type=quant_type)
transform_pass.apply(graph)
marked_nodes = set()
for op in graph.all_ops():
for op in graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
......@@ -146,7 +147,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
val_marked_nodes = set()
for op in val_graph.all_ops():
for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1:
val_marked_nodes.add(op)
val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
......@@ -166,15 +167,16 @@ class TestQuantizationTransformPass(unittest.TestCase):
loss = residual_block(2)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
place = fluid.CPUPlace()
exe = fluid.Executor(place)
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
program_exe=exe,
place=place,
activation_quantize_type=quant_type)
transform_pass.apply(graph)
marked_nodes = set()
for op in graph.all_ops():
for op in graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)
......@@ -182,7 +184,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
self.check_program(transform_pass, program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
val_marked_nodes = set()
for op in val_graph.all_ops():
for op in val_graph.all_op_nodes():
if op.name().find('quantize') > -1:
val_marked_nodes.add(op)
val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
......@@ -231,17 +233,17 @@ class TestQuantizationFreezePass(unittest.TestCase):
with fluid.scope_guard(scope):
exe.run(startup)
transform_pass = QuantizationTransformPass(
scope=scope, program_exe=exe, activation_quantize_type=quant_type)
scope=scope, place=place, activation_quantize_type=quant_type)
transform_pass.apply(main_graph)
transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_'
marked_nodes = set()
for op in main_graph.all_ops():
for op in main_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
main_graph.draw('.', 'main' + dev_name + quant_type, marked_nodes)
marked_nodes = set()
for op in test_graph.all_ops():
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test' + dev_name + quant_type, marked_nodes)
......@@ -251,11 +253,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
iters = 5
batch_size = 8
#train_exe = fluid.ParallelExecutor(
# main_program=quantized_main_program,
# use_cuda=bool(use_cuda),
# loss_name=loss.name,
# scope=scope)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
......@@ -269,9 +266,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
loss_v = exe.run(program=quantized_main_program,
feed=feeder.feed(data),
fetch_list=[loss])
#loss_v = train_exe.run(feed=feeder.feed(data),
# fetch_list=[loss.name])
#print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
test_data = next(test_reader())
with fluid.program_guard(quantized_test_program):
......@@ -287,7 +282,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph)
marked_nodes = set()
for op in test_graph.all_ops():
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_freeze' + dev_name + quant_type,
......@@ -299,21 +294,21 @@ class TestQuantizationFreezePass(unittest.TestCase):
feed=feeder.feed(test_data),
fetch_list=[loss])
self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
#print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1))
#print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2))
print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1))
print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2))
w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
# Maybe failed, this is due to the calculation precision
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
#print('{}: {}'.format('w_freeze' + dev_name + quant_type,
# np.sum(w_freeze)))
#print('{}: {}'.format('w_quant' + dev_name + quant_type,
# np.sum(w_quant)))
print('{}: {}'.format('w_freeze' + dev_name + quant_type,
np.sum(w_freeze)))
print('{}: {}'.format('w_quant' + dev_name + quant_type,
np.sum(w_quant)))
# Convert parameter to 8-bit.
convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
convert_int8_pass.apply(test_graph)
marked_nodes = set()
for op in test_graph.all_ops():
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_int8' + dev_name + quant_type, marked_nodes)
......@@ -330,14 +325,14 @@ class TestQuantizationFreezePass(unittest.TestCase):
w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
self.assertEqual(w_8bit.dtype, np.int8)
self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
#print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit)))
#print('{}: {}'.format('w_freeze' + dev_name + quant_type,
# np.sum(w_freeze)))
print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit)))
print('{}: {}'.format('w_freeze' + dev_name + quant_type,
np.sum(w_freeze)))
mobile_pass = TransformForMobilePass()
mobile_pass.apply(test_graph)
marked_nodes = set()
for op in test_graph.all_ops():
for op in test_graph.all_op_nodes():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_mobile' + dev_name + quant_type,
......
......@@ -1538,10 +1538,297 @@ class Block(object):
return ret_var
class IrNode(object):
"""
Python IrNode. Beneath it is a core.Node, which is used for Ir Pass.
"""
def __init__(self, node):
"""
Construct an IrNode using core.Node.
Args:
node(core.Node): C++ Node.
"""
assert isinstance(node,
core.Node), 'node must be the instance of core.Node.'
self.node = node
def name(self):
"""
Return the node name.
Returns:
str: node name.
"""
return self.node.name()
def node_type(self):
"""
Return the node type.
Returns:
core.Node.Type: node type(core.Node.Type.Operation or core.Node.Type.Variable).
"""
return self.node.node_type()
def var(self):
"""
Return the node variable description.
Returns:
core.VarDesc: node variable description.
"""
return self.node.var()
def op(self):
"""
Return the node operator description.
Returns:
core.OpDesc: node operator description.
"""
return self.node.op()
def id(self):
"""
Return the node id.
Returns:
int: node id.
"""
return self.node.id()
def is_op(self):
"""
If the node is an operator, then return true.
Returns:
bool: indicate whether the node is an operator.
"""
return self.node.is_op()
def is_var(self):
"""
If the node is a variable, then return true.
Returns:
bool: indicate whether the node is a variable.
"""
return self.node.is_var()
def is_ctrl_var(self):
"""
If the node is a control dependence variable, then return true.
Returns:
bool: indicate whether the node is a control dependence variable.
"""
return self.node.is_ctrl_var()
def clear_inputs(self):
"""
Clear the node inputs. After executing the `clear_inputs` function,
the node inputs will be empty.
"""
self.node.clear_inputs()
def inputs_remove_by_id(self, node_id):
"""
Remove a node from inputs by the given node id.
Args:
node_id(int): the given node id.
"""
self.node.inputs_remove(node_id)
def inputs_remove(self, ir_node):
"""
Remove a node from inputs.
Args:
ir_node(IrNode): the node being removed.
"""
self.node.inputs_remove(ir_node.node)
def inputs_append(self, ir_node):
"""
Append a node in inputs.
Args:
ir_node(IrNode): the node being appended.
"""
self.node.inputs_append(ir_node.node)
def clear_outputs(self):
"""
Clear the node outputs. After executing the `clear_outputs` function,
the node outputs will be empty.
"""
self.node.clear_outputs()
def outputs_remove_by_id(self, node_id):
"""
Remove a node from outputs by the given node id.
Args:
node_id(int): the given node id.
"""
self.node.outputs_remove(node_id)
def outputs_remove(self, ir_node):
"""
Remove a node from outputs.
Args:
ir_node(IrNode): the node being removed.
"""
self.node.outputs_remove(ir_node.node)
def outputs_append(self, ir_node):
"""
Append a node in outputs.
Args:
ir_node(IrNode): the node being appended.
"""
self.node.outputs_append(ir_node.node)
@property
def inputs(self):
"""
Return the node inputs.
Returns:
list(IrNode): node inputs wrapped by IrNode.
"""
return [IrNode(n) for n in self.node.inputs]
@property
def outputs(self):
"""
Return the node outputs.
Returns:
list(IrNode): node outputs wrapped by IrNode.
"""
return [IrNode(n) for n in self.node.outputs]
class IrVarNode(IrNode):
"""
Python IrVarNode. Beneath it is a core.Node, it inherits from IrNode.
"""
def __init__(self, node):
"""
Construct an IrVarNode using core.Node.
Args:
node(core.Node): C++ Node.
"""
assert isinstance(node, core.Node) and node.is_var(), \
'node must be the instance of core.Node and it must be a variable node.'
super(IrVarNode, self).__init__(node)
self.node = node
def set_shape(self, shape):
"""
Set the node variable shape.
Args:
shape(list): shape to be set.
"""
assert self.node.var() is not None, \
"The node variable description cannot be None."
self.node.var().set_shape(shape)
def persistable(self):
"""
If the variable node is a persistable variable, then return true.
Returns:
bool: indicate whether the variable is persistable.
"""
assert self.node.var() is not None, \
"The node variable description cannot be None."
return self.node.var().persistable()
@property
def inputs(self):
"""
Return the node inputs.
Returns:
list(IrOpNode): node inputs wrapped by IrOpNode.
"""
return [IrOpNode(n) for n in self.node.inputs]
@property
def outputs(self):
"""
Return the node outputs.
Returns:
list(IrOpNode): node outputs wrapped by IrOpNode.
"""
return [IrOpNode(n) for n in self.node.outputs]
class IrOpNode(IrNode):
"""
Python IrOpNode. Beneath it is a core.Node, it inherits from IrNode.
"""
def __init__(self, node):
"""
Construct an IrOpNode using core.Node.
Args:
node(core.Node): C++ Node.
"""
assert isinstance(node, core.Node) and node.is_op(), \
'node must be the instance of core.Node and it must be a operator node.'
super(IrOpNode, self).__init__(node)
self.node = node
def rename_input(self, old_input_name, new_input_name):
"""
Rename the input of this node.
Args:
old_input_name(str): the old input name.
new_input_name(str): the new input name.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
self.node.op()._rename_input(old_input_name, new_input_name)
@property
def inputs(self):
"""
Return the node inputs.
Returns:
list(IrVarNode): node inputs wrapped by IrVarNode.
"""
return [IrVarNode(n) for n in self.node.inputs]
@property
def outputs(self):
"""
Return the node outputs.
Returns:
list(IrVarNode): node outputs wrapped by IrVarNode.
"""
return [IrVarNode(n) for n in self.node.outputs]
class IrGraph(object):
"""
Python IrGraph. Beneath it is a core.Graph, which is used for
create a c++ Ir Pass Graph. An IrGraph is just a graph view of
creating a c++ Ir Pass Graph. An IrGraph is just a graph view of
a Program. In an IrGraph, both Variables and Operators are graph
nodes.
"""
......@@ -1569,15 +1856,15 @@ class IrGraph(object):
"""
Return all nodes included in the graph as a set.
"""
return {node for node in self.graph.nodes()}
return {IrNode(node) for node in self.graph.nodes()}
def all_vars(self):
def all_var_nodes(self):
"""
Return all variable nodes included in the graph as a set.
"""
return {node for node in self.graph.nodes() if node.is_var()}
return {IrVarNode(node) for node in self.graph.nodes() if node.is_var()}
def all_persistable_vars(self):
def all_persistable_nodes(self):
"""
Return all persistable variable nodes included in the graph as a set.
"""
......@@ -1586,13 +1873,13 @@ class IrGraph(object):
if node.is_var() and node.var() is not None and node.var(
).persistable():
persistable_nodes.add(node)
return persistable_nodes
return {IrVarNode(p) for p in persistable_nodes}
def all_ops(self):
def all_op_nodes(self):
"""
Return all operator nodes included in the graph as a set.
"""
return {node for node in self.graph.nodes() if node.is_op()}
return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}
def var_node(self, name):
"""
......@@ -1606,14 +1893,14 @@ class IrGraph(object):
doesn't have a variable with the giving name.
Returns:
core.Node: the variable node with the giving name.
IrVarNode: the variable node with the giving name.
"""
if not isinstance(name, six.string_types):
raise TypeError(
"var require string as parameter, but get %s instead." %
(type(name)))
target_var_node = None
var_nodes = self.all_vars()
var_nodes = self.all_var_nodes()
for var_node in var_nodes:
if var_node.name() == name:
target_var_node = var_node
......@@ -1621,7 +1908,7 @@ class IrGraph(object):
raise ValueError("var_node %s not in this graph" % name)
return target_var_node
def create_param_node(self, name, var_type, shape, var_dtype):
def create_persistable_node(self, name, var_type, shape, var_dtype):
"""
Create a persistable variable node in the graph. In IrGraph,
it can not distinguish between persistable variables and parameters.
......@@ -1633,14 +1920,14 @@ class IrGraph(object):
var_dtype(core.VarDesc.VarType): the data type of the persistable variable node.
Returns:
core.Node: the created persistable variable node.
IrVarNode: the created persistable variable node.
"""
var_desc = core.VarDesc(name)
var_desc.set_type(var_type)
var_desc.set_shape(shape)
var_desc.set_dtype(var_dtype)
var_desc.set_persistable(True)
return self.graph.create_var_node(var_desc)
return IrVarNode(self.graph.create_var_node(var_desc))
def create_var_node(self, name, var_type, shape, var_dtype):
"""
......@@ -1654,14 +1941,14 @@ class IrGraph(object):
var_dtype(core.VarDesc.VarType): the data type of the variable node.
Returns:
core.Node: the created variable node.
IrVarNode: the created variable node.
"""
var_desc = core.VarDesc(name)
var_desc.set_type(var_type)
var_desc.set_shape(shape)
var_desc.set_dtype(var_dtype)
return self.graph.create_var_node(var_desc)
return IrVarNode(self.graph.create_var_node(var_desc))
def create_var_node_from_desc(self, var_desc):
"""
......@@ -1672,9 +1959,9 @@ class IrGraph(object):
var_desc(core.VarDesc): the giving variable description.
Returns:
core.Node: the created variable node.
IrVarNode: the created variable node.
"""
return self.graph.create_var_node(var_desc)
return IrVarNode(self.graph.create_var_node(var_desc))
def create_op_node(self, op_type, attrs, inputs, outputs):
"""
......@@ -1687,7 +1974,7 @@ class IrGraph(object):
outputs(dict): the outpus of the operator node.
Returns:
core.Node: the created operator node.
IrOpNode: the created operator node.
"""
op_desc = core.OpDesc()
op_desc.set_type(op_type)
......@@ -1703,7 +1990,7 @@ class IrGraph(object):
var_nodes = [var_nodes]
op_desc.set_output(output_name,
[var_node.name() for var_node in var_nodes])
return self.graph.create_op_node(op_desc)
return IrOpNode(self.graph.create_op_node(op_desc))
def create_op_node_from_desc(self, op_desc):
"""
......@@ -1713,37 +2000,37 @@ class IrGraph(object):
op_desc(core.VarDesc): the giving operator description.
Returns:
core.Node: the created operator node.
IrOpNode: the created operator node.
"""
return self.graph.create_op_node(op_desc)
return IrOpNode(self.graph.create_op_node(op_desc))
def update_input_link(self, old_input_node, new_input_node, op_node):
"""
Update the input's link of a operator node.
Args:
old_input_node(core.Node): the old input node of the giving op_node.
new_input_node(core.Node): the new input node of the giving op_node.
op_node(core.Node): the operator node that is needed to update input's link.
old_input_node(IrNode): the old input node of the giving op_node.
new_input_node(IrNode): the new input node of the giving op_node.
op_node(IrOpNode): the operator node that is needed to update input's link.
"""
assert old_input_node in self.graph.nodes() and new_input_node in \
self.graph.nodes() and op_node in self.graph.nodes(), \
assert old_input_node.node in self.graph.nodes() and new_input_node.node in \
self.graph.nodes() and op_node.node in self.graph.nodes(), \
'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.'
old_input_node.outputs_remove(op_node)
op_node.inputs_remove(old_input_node)
new_input_node.outputs_append(op_node)
op_node.inputs_append(new_input_node)
op_node.op()._rename_input(old_input_node.name(), new_input_node.name())
op_node.rename_input(old_input_node.name(), new_input_node.name())
def link_to(self, node_in, node_out):
"""
Connect two nodes.
Args:
node_in(core.Node): the input node.
node_out(core.Node): the output node.
node_in(IrNode): the input node.
node_out(IrNode): the output node.
"""
assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \
assert node_in.node in self.graph.nodes() and node_out.node in self.graph.nodes(), \
'The two arguments(node_in&node_out) must be in the graph nodes.'
node_in.outputs_append(node_out)
node_out.inputs_append(node_in)
......@@ -1761,7 +2048,8 @@ class IrGraph(object):
remove_nodes = set(remove_nodes)
else:
remove_nodes = {remove_nodes}
core.graph_safe_remove_nodes(self.graph, remove_nodes)
original_nodes = {n.node for n in remove_nodes}
core.graph_safe_remove_nodes(self.graph, original_nodes)
def has_circle(self):
"""
......@@ -1788,18 +2076,23 @@ class IrGraph(object):
Notes: the `graph` cannot contain a circle.
Returns:
set(core.Node): nodes in topology order.
set(IrNode): nodes in topology order.
"""
return core.topology_sort(self.graph)
ordered_nodes = core.topology_sort(self.graph)
return {IrNode(n) for n in ordered_nodes}
def build_adjacency_list(self):
"""
Build an adjacency list of operations for the `graph`.
Returns:
dict{core.Node: set(core.Node)}: the adjacency list.
dict{IrNode: set(IrNode)}: the adjacency list.
"""
return core.build_adjacency_list(self.graph)
adj_list = core.build_adjacency_list(self.graph)
wrapped_adj_list = dict()
for k, v in six.iteritems(adj_list):
wrapped_adj_list[IrNode(k)] = {IrNode(n) for n in v}
return wrapped_adj_list
def draw(self, save_path, name, marked_nodes=None, remove_ctr_var=True):
"""
......@@ -1809,7 +2102,7 @@ class IrGraph(object):
Args:
save_path(str): the save path of drawn graph.
name(str): the name of drawn graph.
marked_nodes(set(core.Node)): nodes that are needed to be marked.
marked_nodes(set(IrNode)): nodes that are needed to be marked.
Default value is None.
remove_ctr_var(bool): If it is set True, all control variable nodes
in the graph will be removed. Default value is True.
......@@ -1824,20 +2117,22 @@ class IrGraph(object):
print('The {} is saved as the dot filetype.'.format(
dot_file_path))
if remove_ctr_var:
remove_ctr_vars = set()
for node in self.graph.nodes():
if remove_ctr_var:
for node in self.all_var_nodes():
if node.is_ctrl_var():
remove_ctr_vars.add(node)
self.safe_remove_nodes(remove_ctr_vars)
ops_num = 0
for node in self.graph.nodes():
if node.is_op():
ops_num += 1
print('Total ops num = {}.'.format(ops_num))
print('Total ops num = {}.'.format(len(self.all_op_nodes())))
if marked_nodes is not None:
if not isinstance(marked_nodes, set):
if isinstance(marked_nodes, Iterable):
marked_nodes = set(marked_nodes)
else:
marked_nodes = {marked_nodes}
marked_nodes = {n.node for n in marked_nodes}
remove_ctr_vars = {n.node for n in remove_ctr_vars}
marked_nodes = marked_nodes - remove_ctr_vars
if self.graph.has('__graphviz__marked_node__'):
self.graph.erase('__graphviz__marked_node__')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册