From 33f99d61976276c6f8f0fda99fc0fc9aa5995138 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Wed, 20 Feb 2019 22:43:25 +0800 Subject: [PATCH] add IrNode&IrVarNode&IrOpNode. test=develop --- .../slim/quantization/quantization_pass.py | 69 ++-- .../fluid/contrib/slim/tests/test_graph.py | 6 +- .../slim/tests/test_quantization_pass.py | 57 ++- python/paddle/fluid/framework.py | 383 ++++++++++++++++-- 4 files changed, 409 insertions(+), 106 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 18b58e6f388..5764d9d94f4 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -17,7 +17,9 @@ import numpy as np import six from ..... import compat as cpt from .... import core +from .... import Executor from ....framework import IrGraph +from ....framework import IrNode from ....framework import Program from ....initializer import Constant from .... import unique_name @@ -31,7 +33,7 @@ __all__ = [ class QuantizationTransformPass(object): def __init__(self, scope=None, - program_exe=None, + place=None, weight_bits=8, activation_bits=8, activation_quantize_type='abs_max', @@ -45,7 +47,7 @@ class QuantizationTransformPass(object): scope(fluid.Scope): When activation use 'range_abs_max' as the quantize type, this pass will create some new parameters. The scope is used to initialize these new parameters. - program_exe(fluid.Executor): program_exe is used to initialize new + place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new parameters described above. weight_bits (int): quantization bit number for weights, the bias is not quantized. @@ -71,13 +73,13 @@ class QuantizationTransformPass(object): from paddle.fluid import core graph = IrGraph(core.Graph(program.desc), for_test=False) - exe = fluid.Executor(fluid.CPUPlace()) + place = fluid.CPUPlace() transform_pass = QuantizationTransformPass(fluid.global_scope(), - exe) + place) transform_pass.apply(graph) """ self._scope = scope - self._program_exe = program_exe + self._place = place self._weight_bits = weight_bits self._activation_bits = activation_bits @@ -118,7 +120,7 @@ class QuantizationTransformPass(object): self._is_test = graph.is_test() # marked the variable which has been dequantized. dequantized_vars = collections.OrderedDict() - persistable_vars = [p.name() for p in graph.all_persistable_vars()] + persistable_vars = [p.name() for p in graph.all_persistable_nodes()] def _transform_forward(graph, op): for var_node in op.inputs: @@ -149,7 +151,7 @@ class QuantizationTransformPass(object): if not self._is_test: self._create_global_step(graph) - ops = graph.all_ops() + ops = graph.all_op_nodes() # The process of _transform_forward and _transform_backward is needed in two for loops. # The loop for transforming the forward graph: for op in ops: @@ -163,8 +165,8 @@ class QuantizationTransformPass(object): if len(self._need_initialized) > 0: assert self._scope is not None, \ 'The scope cannot be set None when activation_quantize_type equals to range_abs_max.' - assert self._program_exe is not None, \ - 'The program_exe cannot be set None when activation_quantize_type equals to range_abs_max.' + assert self._place is not None, \ + 'The place cannot be set None when activation_quantize_type equals to range_abs_max.' 
init_program = Program() for var_desc, initializer in six.iteritems(self._need_initialized): var = init_program.global_block().create_var( @@ -175,7 +177,8 @@ class QuantizationTransformPass(object): lod_level=var_desc.lod_level(), persistable=var_desc.persistable()) initializer(var, init_program.global_block()) - self._program_exe.run(program=init_program, scope=self._scope) + exe = Executor(self._place) + exe.run(program=init_program, scope=self._scope) return graph @@ -183,11 +186,11 @@ class QuantizationTransformPass(object): if self._weight_quantize_type == 'range_abs_max' or \ self._activation_quantize_type == 'range_abs_max': counter_name = cpt.to_text('@STEP_COUNTER@') - for node in graph.all_vars(): + for node in graph.all_var_nodes(): if node.name() == counter_name: self._global_step = node if self._global_step is None: - global_step_in = graph.create_param_node( + global_step_in = graph.create_persistable_node( name=counter_name, var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[1], @@ -262,7 +265,7 @@ class QuantizationTransformPass(object): shape=var_node.var().shape(), var_dtype=var_node.var().dtype()) - scale_in_node = graph.create_param_node( + scale_in_node = graph.create_persistable_node( name=self._quantized_scale_name(var_node.name()), var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[1], @@ -275,7 +278,7 @@ class QuantizationTransformPass(object): if not self._is_test: # The name of scales_var_node maybe 'scales_0', 'scales_1', etc. - scales_node = graph.create_param_node( + scales_node = graph.create_persistable_node( name=unique_name.generate('scales'), var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[self._window_size], @@ -400,8 +403,8 @@ class QuantizationFreezePass(object): Args: graph(IrGraph): the applied graph. """ - persistable_vars = [p.name() for p in graph.all_persistable_vars()] - ops = graph.all_ops() + persistable_vars = [p.name() for p in graph.all_persistable_nodes()] + ops = graph.all_op_nodes() for op_node in ops: op_name = op_node.name() if op_name in self._fake_quant_op_names: @@ -425,13 +428,13 @@ class QuantizationFreezePass(object): self._weight_bits) self._restore_var(input_arg_name, quantized_param_v) - ops = graph.all_ops() + ops = graph.all_op_nodes() for op_node in ops: op_name = op_node.name() if op_name in self._fake_dequant_op_names: self._remove_fake_quant_and_dequant_op(graph, op_node) - ops = graph.all_ops() + ops = graph.all_op_nodes() for op_node in ops: op_name = op_node.name() if op_name in self._quantizable_ops: @@ -462,7 +465,7 @@ class QuantizationFreezePass(object): def _insert_post_dequant_op(self, graph, op_node): max_range = None scale_var_node = None - persistable_vars = [p.name() for p in graph.all_persistable_vars()] + persistable_vars = [p.name() for p in graph.all_persistable_nodes()] for var_node in op_node.inputs: name = var_node.name() if name in self._op_input_rename_map: @@ -480,7 +483,7 @@ class QuantizationFreezePass(object): original_var_name) max_range = param_range * act_range / scale_v else: - assert isinstance(scale_v, core.Node) + assert isinstance(scale_v, IrNode) scale_var_node = self._var_scale_map[original_var_name] if len(op_node.outputs) != 1: @@ -517,14 +520,19 @@ class QuantizationFreezePass(object): def _remove_unused_var_nodes(self, graph): all_used_vars = set() - ops = graph.all_ops() + ops = graph.all_op_nodes() for op_node in ops: for input_node in op_node.inputs: all_used_vars.add(input_node) for output_node in op_node.outputs: all_used_vars.add(output_node) - all_unused_vars = 
graph.all_vars() - all_used_vars + all_used_vars = {n.node for n in all_used_vars} + all_unused_vars = { + n + for n in filter(lambda node: node.node not in all_used_vars, + graph.all_var_nodes()) + } graph.safe_remove_nodes(all_unused_vars) def _original_var_name(self, var_name): @@ -583,8 +591,8 @@ class ConvertToInt8Pass(object): Args: graph(IrGraph): the applied graph. """ - persistable_vars = [p.name() for p in graph.all_persistable_vars()] - ops = graph.all_ops() + persistable_vars = [p.name() for p in graph.all_persistable_nodes()] + ops = graph.all_op_nodes() input_map = {} for op_node in ops: op_name = op_node.name() @@ -605,7 +613,7 @@ class ConvertToInt8Pass(object): def _convert_to_int8(self, graph, var_node): int8_var_node_name = var_node.name() + ".int8" - int8_var_node = graph.create_param_node( + int8_var_node = graph.create_persistable_node( name=cpt.to_text(int8_var_node_name), var_type=var_node.var().type(), shape=var_node.var().shape(), @@ -624,14 +632,19 @@ class ConvertToInt8Pass(object): def _remove_unused_var_nodes(self, graph): all_used_vars = set() - ops = graph.all_ops() + ops = graph.all_op_nodes() for op_node in ops: for input_node in op_node.inputs: all_used_vars.add(input_node) for output_node in op_node.outputs: all_used_vars.add(output_node) - all_unused_vars = graph.all_vars() - all_used_vars + all_used_vars = {n.node for n in all_used_vars} + all_unused_vars = { + n + for n in filter(lambda node: node.node not in all_used_vars, + graph.all_var_nodes()) + } graph.safe_remove_nodes(all_unused_vars) @@ -655,7 +668,7 @@ class TransformForMobilePass(object): Args: graph(IrGraph): the graph will be transformed. """ - ops = graph.all_ops() + ops = graph.all_op_nodes() for op_node in ops: name = op_node.name() if name in self._fake_quant_op_names: diff --git a/python/paddle/fluid/contrib/slim/tests/test_graph.py b/python/paddle/fluid/contrib/slim/tests/test_graph.py index 75e0c95b5c3..2d2f1384dec 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_graph.py +++ b/python/paddle/fluid/contrib/slim/tests/test_graph.py @@ -61,16 +61,16 @@ class TestGraph(unittest.TestCase): opt.minimize(loss) graph = IrGraph(core.Graph(main.desc), for_test=False) marked_nodes = set() - for op in graph.all_ops(): + for op in graph.all_op_nodes(): if op.name().find('conv2d') > -1: marked_nodes.add(op) graph.draw('.', 'residual', marked_nodes) self.assertFalse(graph.has_circle()) self.assertEqual(graph.graph_num(), 1) nodes = graph.topology_sort() - self.assertEqual(len(nodes), len(graph.all_ops())) + self.assertEqual(len(nodes), len(graph.all_op_nodes())) nodes_map = graph.build_adjacency_list() - self.assertEqual(len(nodes_map), len(graph.all_ops())) + self.assertEqual(len(nodes_map), len(graph.all_op_nodes())) nodes_num = len(graph.all_nodes()) graph.safe_remove_nodes(marked_nodes) self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes)) diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py index 2f291132f30..254b73a1247 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py @@ -130,15 +130,16 @@ class TestQuantizationTransformPass(unittest.TestCase): loss = linear_fc(3) opt = fluid.optimizer.Adam(learning_rate=0.001) opt.minimize(loss) - exe = fluid.Executor(fluid.CPUPlace()) + place = fluid.CPUPlace() + exe = fluid.Executor(place) graph = IrGraph(core.Graph(main.desc), 
for_test=False) transform_pass = QuantizationTransformPass( scope=fluid.global_scope(), - program_exe=exe, + place=place, activation_quantize_type=quant_type) transform_pass.apply(graph) marked_nodes = set() - for op in graph.all_ops(): + for op in graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes) @@ -146,7 +147,7 @@ class TestQuantizationTransformPass(unittest.TestCase): self.check_program(transform_pass, program) val_graph = IrGraph(core.Graph(program.desc), for_test=False) val_marked_nodes = set() - for op in val_graph.all_ops(): + for op in val_graph.all_op_nodes(): if op.name().find('quantize') > -1: val_marked_nodes.add(op) val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes) @@ -166,15 +167,16 @@ class TestQuantizationTransformPass(unittest.TestCase): loss = residual_block(2) opt = fluid.optimizer.Adam(learning_rate=0.001) opt.minimize(loss) - exe = fluid.Executor(fluid.CPUPlace()) + place = fluid.CPUPlace() + exe = fluid.Executor(place) graph = IrGraph(core.Graph(main.desc), for_test=False) transform_pass = QuantizationTransformPass( scope=fluid.global_scope(), - program_exe=exe, + place=place, activation_quantize_type=quant_type) transform_pass.apply(graph) marked_nodes = set() - for op in graph.all_ops(): + for op in graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes) @@ -182,7 +184,7 @@ class TestQuantizationTransformPass(unittest.TestCase): self.check_program(transform_pass, program) val_graph = IrGraph(core.Graph(program.desc), for_test=False) val_marked_nodes = set() - for op in val_graph.all_ops(): + for op in val_graph.all_op_nodes(): if op.name().find('quantize') > -1: val_marked_nodes.add(op) val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes) @@ -231,17 +233,17 @@ class TestQuantizationFreezePass(unittest.TestCase): with fluid.scope_guard(scope): exe.run(startup) transform_pass = QuantizationTransformPass( - scope=scope, program_exe=exe, activation_quantize_type=quant_type) + scope=scope, place=place, activation_quantize_type=quant_type) transform_pass.apply(main_graph) transform_pass.apply(test_graph) dev_name = '_gpu_' if use_cuda else '_cpu_' marked_nodes = set() - for op in main_graph.all_ops(): + for op in main_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) main_graph.draw('.', 'main' + dev_name + quant_type, marked_nodes) marked_nodes = set() - for op in test_graph.all_ops(): + for op in test_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) test_graph.draw('.', 'test' + dev_name + quant_type, marked_nodes) @@ -251,11 +253,6 @@ class TestQuantizationFreezePass(unittest.TestCase): iters = 5 batch_size = 8 - #train_exe = fluid.ParallelExecutor( - # main_program=quantized_main_program, - # use_cuda=bool(use_cuda), - # loss_name=loss.name, - # scope=scope) train_reader = paddle.batch( paddle.reader.shuffle( paddle.dataset.mnist.train(), buf_size=500), @@ -269,9 +266,7 @@ class TestQuantizationFreezePass(unittest.TestCase): loss_v = exe.run(program=quantized_main_program, feed=feeder.feed(data), fetch_list=[loss]) - #loss_v = train_exe.run(feed=feeder.feed(data), - # fetch_list=[loss.name]) - #print('{}: {}'.format('loss' + dev_name + quant_type, loss_v)) + print('{}: {}'.format('loss' + dev_name + quant_type, loss_v)) test_data = next(test_reader()) with fluid.program_guard(quantized_test_program): @@ -287,7 
+282,7 @@ class TestQuantizationFreezePass(unittest.TestCase): freeze_pass = QuantizationFreezePass(scope=scope, place=place) freeze_pass.apply(test_graph) marked_nodes = set() - for op in test_graph.all_ops(): + for op in test_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) test_graph.draw('.', 'test_freeze' + dev_name + quant_type, @@ -299,21 +294,21 @@ class TestQuantizationFreezePass(unittest.TestCase): feed=feeder.feed(test_data), fetch_list=[loss]) self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3) - #print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1)) - #print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2)) + print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1)) + print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2)) w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor()) # Maybe failed, this is due to the calculation precision # self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant)) - #print('{}: {}'.format('w_freeze' + dev_name + quant_type, - # np.sum(w_freeze))) - #print('{}: {}'.format('w_quant' + dev_name + quant_type, - # np.sum(w_quant))) + print('{}: {}'.format('w_freeze' + dev_name + quant_type, + np.sum(w_freeze))) + print('{}: {}'.format('w_quant' + dev_name + quant_type, + np.sum(w_quant))) # Convert parameter to 8-bit. convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place) convert_int8_pass.apply(test_graph) marked_nodes = set() - for op in test_graph.all_ops(): + for op in test_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) test_graph.draw('.', 'test_int8' + dev_name + quant_type, marked_nodes) @@ -330,14 +325,14 @@ class TestQuantizationFreezePass(unittest.TestCase): w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor()) self.assertEqual(w_8bit.dtype, np.int8) self.assertEqual(np.sum(w_8bit), np.sum(w_freeze)) - #print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit))) - #print('{}: {}'.format('w_freeze' + dev_name + quant_type, - # np.sum(w_freeze))) + print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit))) + print('{}: {}'.format('w_freeze' + dev_name + quant_type, + np.sum(w_freeze))) mobile_pass = TransformForMobilePass() mobile_pass.apply(test_graph) marked_nodes = set() - for op in test_graph.all_ops(): + for op in test_graph.all_op_nodes(): if op.name().find('quantize') > -1: marked_nodes.add(op) test_graph.draw('.', 'test_mobile' + dev_name + quant_type, diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 489f8d6b3a9..70c100d9ec7 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1538,10 +1538,297 @@ class Block(object): return ret_var +class IrNode(object): + """ + Python IrNode. Beneath it is a core.Node, which is used for Ir Pass. + """ + + def __init__(self, node): + """ + Construct an IrNode using core.Node. + + Args: + node(core.Node): C++ Node. + """ + assert isinstance(node, + core.Node), 'node must be the instance of core.Node.' + self.node = node + + def name(self): + """ + Return the node name. + + Returns: + str: node name. + """ + return self.node.name() + + def node_type(self): + """ + Return the node type. + + Returns: + core.Node.Type: node type(core.Node.Type.Operation or core.Node.Type.Variable). + """ + return self.node.node_type() + + def var(self): + """ + Return the node variable description. + + Returns: + core.VarDesc: node variable description. 
+ """ + return self.node.var() + + def op(self): + """ + Return the node operator description. + + Returns: + core.OpDesc: node operator description. + """ + return self.node.op() + + def id(self): + """ + Return the node id. + + Returns: + int: node id. + """ + return self.node.id() + + def is_op(self): + """ + If the node is an operator, then return true. + + Returns: + bool: indicate whether the node is an operator. + """ + return self.node.is_op() + + def is_var(self): + """ + If the node is a variable, then return true. + + Returns: + bool: indicate whether the node is a variable. + """ + return self.node.is_var() + + def is_ctrl_var(self): + """ + If the node is a control dependence variable, then return true. + + Returns: + bool: indicate whether the node is a control dependence variable. + """ + return self.node.is_ctrl_var() + + def clear_inputs(self): + """ + Clear the node inputs. After executing the `clear_inputs` function, + the node inputs will be empty. + """ + self.node.clear_inputs() + + def inputs_remove_by_id(self, node_id): + """ + Remove a node from inputs by the given node id. + + Args: + node_id(int): the given node id. + """ + self.node.inputs_remove(node_id) + + def inputs_remove(self, ir_node): + """ + Remove a node from inputs. + + Args: + ir_node(IrNode): the node being removed. + """ + self.node.inputs_remove(ir_node.node) + + def inputs_append(self, ir_node): + """ + Append a node in inputs. + + Args: + ir_node(IrNode): the node being appended. + """ + self.node.inputs_append(ir_node.node) + + def clear_outputs(self): + """ + Clear the node outputs. After executing the `clear_outputs` function, + the node outputs will be empty. + """ + self.node.clear_outputs() + + def outputs_remove_by_id(self, node_id): + """ + Remove a node from outputs by the given node id. + + Args: + node_id(int): the given node id. + """ + self.node.outputs_remove(node_id) + + def outputs_remove(self, ir_node): + """ + Remove a node from outputs. + + Args: + ir_node(IrNode): the node being removed. + """ + self.node.outputs_remove(ir_node.node) + + def outputs_append(self, ir_node): + """ + Append a node in outputs. + + Args: + ir_node(IrNode): the node being appended. + """ + self.node.outputs_append(ir_node.node) + + @property + def inputs(self): + """ + Return the node inputs. + + Returns: + list(IrNode): node inputs wrapped by IrNode. + """ + return [IrNode(n) for n in self.node.inputs] + + @property + def outputs(self): + """ + Return the node outputs. + + Returns: + list(IrNode): node outputs wrapped by IrNode. + """ + return [IrNode(n) for n in self.node.outputs] + + +class IrVarNode(IrNode): + """ + Python IrVarNode. Beneath it is a core.Node, it inherits from IrNode. + """ + + def __init__(self, node): + """ + Construct an IrVarNode using core.Node. + + Args: + node(core.Node): C++ Node. + """ + assert isinstance(node, core.Node) and node.is_var(), \ + 'node must be the instance of core.Node and it must be a variable node.' + super(IrVarNode, self).__init__(node) + self.node = node + + def set_shape(self, shape): + """ + Set the node variable shape. + + Args: + shape(list): shape to be set. + """ + assert self.node.var() is not None, \ + "The node variable description cannot be None." + self.node.var().set_shape(shape) + + def persistable(self): + """ + If the variable node is a persistable variable, then return true. + + Returns: + bool: indicate whether the variable is persistable. + """ + assert self.node.var() is not None, \ + "The node variable description cannot be None." 
+        return self.node.var().persistable()
+
+    @property
+    def inputs(self):
+        """
+        Return the node inputs.
+
+        Returns:
+            list(IrOpNode): node inputs wrapped by IrOpNode.
+        """
+        return [IrOpNode(n) for n in self.node.inputs]
+
+    @property
+    def outputs(self):
+        """
+        Return the node outputs.
+
+        Returns:
+            list(IrOpNode): node outputs wrapped by IrOpNode.
+        """
+        return [IrOpNode(n) for n in self.node.outputs]
+
+
+class IrOpNode(IrNode):
+    """
+    Python IrOpNode. Beneath it is a core.Node, and it inherits from IrNode.
+    """
+
+    def __init__(self, node):
+        """
+        Construct an IrOpNode using core.Node.
+
+        Args:
+            node(core.Node): C++ Node.
+        """
+        assert isinstance(node, core.Node) and node.is_op(), \
+            'node must be the instance of core.Node and it must be an operator node.'
+        super(IrOpNode, self).__init__(node)
+        self.node = node
+
+    def rename_input(self, old_input_name, new_input_name):
+        """
+        Rename the input of this node.
+
+        Args:
+            old_input_name(str): the old input name.
+            new_input_name(str): the new input name.
+        """
+        assert self.node.op() is not None, \
+            "The node operator description cannot be None."
+        self.node.op()._rename_input(old_input_name, new_input_name)
+
+    @property
+    def inputs(self):
+        """
+        Return the node inputs.
+
+        Returns:
+            list(IrVarNode): node inputs wrapped by IrVarNode.
+        """
+        return [IrVarNode(n) for n in self.node.inputs]
+
+    @property
+    def outputs(self):
+        """
+        Return the node outputs.
+
+        Returns:
+            list(IrVarNode): node outputs wrapped by IrVarNode.
+        """
+        return [IrVarNode(n) for n in self.node.outputs]
+
+
 class IrGraph(object):
     """
     Python IrGraph. Beneath it is a core.Graph, which is used for
-    create a c++ Ir Pass Graph. An IrGraph is just a graph view of
+    creating a c++ Ir Pass Graph. An IrGraph is just a graph view of
     a Program. In an IrGraph, both Variables and Operators are graph
     nodes.
     """
@@ -1569,15 +1856,15 @@ class IrGraph(object):
         """
         Return all nodes included in the graph as a set.
         """
-        return {node for node in self.graph.nodes()}
+        return {IrNode(node) for node in self.graph.nodes()}
 
-    def all_vars(self):
+    def all_var_nodes(self):
         """
         Return all variable nodes included in the graph as a set.
         """
-        return {node for node in self.graph.nodes() if node.is_var()}
+        return {IrVarNode(node) for node in self.graph.nodes() if node.is_var()}
 
-    def all_persistable_vars(self):
+    def all_persistable_nodes(self):
         """
         Return all persistable variable nodes included in the graph as a set.
         """
@@ -1586,13 +1873,13 @@ class IrGraph(object):
             if node.is_var() and node.var() is not None and node.var(
             ).persistable():
                 persistable_nodes.add(node)
-        return persistable_nodes
+        return {IrVarNode(p) for p in persistable_nodes}
 
-    def all_ops(self):
+    def all_op_nodes(self):
         """
         Return all operator nodes included in the graph as a set.
         """
-        return {node for node in self.graph.nodes() if node.is_op()}
+        return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}
 
     def var_node(self, name):
         """
         Get a variable node by name from the graph.
 
         Args:
             name(str): the name of the variable node.
 
         Raises:
             ValueError: If the graph doesn't have a variable with the given name.
 
         Returns:
-            core.Node: the variable node with the giving name.
+            IrVarNode: the variable node with the given name.
         """
         if not isinstance(name, six.string_types):
             raise TypeError(
                 "var require string as parameter, but get %s instead."
                % (type(name)))
         target_var_node = None
-        var_nodes = self.all_vars()
+        var_nodes = self.all_var_nodes()
         for var_node in var_nodes:
             if var_node.name() == name:
                 target_var_node = var_node
@@ -1621,7 +1908,7 @@ class IrGraph(object):
             raise ValueError("var_node %s not in this graph" % name)
         return target_var_node
 
-    def create_param_node(self, name, var_type, shape, var_dtype):
+    def create_persistable_node(self, name, var_type, shape, var_dtype):
         """
         Create a persistable variable node in the graph. In IrGraph,
         it cannot distinguish between persistable variables and parameters.
 
         Args:
             name(str): the name of the persistable variable node.
             var_type(core.VarDesc.VarType): the type of the persistable variable node.
             shape(list): the shape of the persistable variable node.
             var_dtype(core.VarDesc.VarType): the data type of the persistable variable node.
 
         Returns:
-            core.Node: the created persistable variable node.
+            IrVarNode: the created persistable variable node.
         """
         var_desc = core.VarDesc(name)
         var_desc.set_type(var_type)
         var_desc.set_shape(shape)
         var_desc.set_dtype(var_dtype)
         var_desc.set_persistable(True)
-        return self.graph.create_var_node(var_desc)
+        return IrVarNode(self.graph.create_var_node(var_desc))
 
     def create_var_node(self, name, var_type, shape, var_dtype):
         """
         Create a variable node in the graph. The created variable node is
         not persistable.
 
         Args:
             name(str): the name of the variable node.
             var_type(core.VarDesc.VarType): the type of the variable node.
             shape(list): the shape of the variable node.
             var_dtype(core.VarDesc.VarType): the data type of the variable node.
 
         Returns:
-            core.Node: the created variable node.
+            IrVarNode: the created variable node.
         """
         var_desc = core.VarDesc(name)
         var_desc.set_type(var_type)
         var_desc.set_shape(shape)
         var_desc.set_dtype(var_dtype)
-        return self.graph.create_var_node(var_desc)
+        return IrVarNode(self.graph.create_var_node(var_desc))
 
     def create_var_node_from_desc(self, var_desc):
         """
         Create a variable node by using an existing VarDesc in the graph.
 
         Args:
             var_desc(core.VarDesc): the given variable description.
 
         Returns:
-            core.Node: the created variable node.
+            IrVarNode: the created variable node.
         """
-        return self.graph.create_var_node(var_desc)
+        return IrVarNode(self.graph.create_var_node(var_desc))
 
     def create_op_node(self, op_type, attrs, inputs, outputs):
         """
         Create an operator node in the graph.
 
         Args:
             op_type(str): the type of the operator node.
             attrs(dict): the attributes of the operator node.
             inputs(dict): the inputs of the operator node.
             outputs(dict): the outputs of the operator node.
 
         Returns:
-            core.Node: the created operator node.
+            IrOpNode: the created operator node.
         """
         op_desc = core.OpDesc()
         op_desc.set_type(op_type)
@@ -1703,7 +1990,7 @@ class IrGraph(object):
                 var_nodes = [var_nodes]
             op_desc.set_output(output_name,
                                [var_node.name() for var_node in var_nodes])
-        return self.graph.create_op_node(op_desc)
+        return IrOpNode(self.graph.create_op_node(op_desc))
 
@@ -1713,37 +2000,37 @@ class IrGraph(object):
     def create_op_node_from_desc(self, op_desc):
         """
         Create an operator node by using an existing OpDesc in the graph.
 
         Args:
             op_desc(core.OpDesc): the given operator description.
 
         Returns:
-            core.Node: the created operator node.
+            IrOpNode: the created operator node.
         """
-        return self.graph.create_op_node(op_desc)
+        return IrOpNode(self.graph.create_op_node(op_desc))
 
     def update_input_link(self, old_input_node, new_input_node, op_node):
         """
         Update the input link of an operator node.
 
         Args:
-            old_input_node(core.Node): the old input node of the giving op_node.
-            new_input_node(core.Node): the new input node of the giving op_node.
-            op_node(core.Node): the operator node that is needed to update input's link.
+            old_input_node(IrNode): the old input node of the given op_node.
+            new_input_node(IrNode): the new input node of the given op_node.
+            op_node(IrOpNode): the operator node whose input link needs to be updated.
""" - assert old_input_node in self.graph.nodes() and new_input_node in \ - self.graph.nodes() and op_node in self.graph.nodes(), \ + assert old_input_node.node in self.graph.nodes() and new_input_node.node in \ + self.graph.nodes() and op_node.node in self.graph.nodes(), \ 'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.' old_input_node.outputs_remove(op_node) op_node.inputs_remove(old_input_node) new_input_node.outputs_append(op_node) op_node.inputs_append(new_input_node) - op_node.op()._rename_input(old_input_node.name(), new_input_node.name()) + op_node.rename_input(old_input_node.name(), new_input_node.name()) def link_to(self, node_in, node_out): """ Connect two nodes. Args: - node_in(core.Node): the input node. - node_out(core.Node): the output node. + node_in(IrNode): the input node. + node_out(IrNode): the output node. """ - assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \ + assert node_in.node in self.graph.nodes() and node_out.node in self.graph.nodes(), \ 'The two arguments(node_in&node_out) must be in the graph nodes.' node_in.outputs_append(node_out) node_out.inputs_append(node_in) @@ -1761,7 +2048,8 @@ class IrGraph(object): remove_nodes = set(remove_nodes) else: remove_nodes = {remove_nodes} - core.graph_safe_remove_nodes(self.graph, remove_nodes) + original_nodes = {n.node for n in remove_nodes} + core.graph_safe_remove_nodes(self.graph, original_nodes) def has_circle(self): """ @@ -1788,18 +2076,23 @@ class IrGraph(object): Notes: the `graph` cannot contain a circle. Returns: - set(core.Node): nodes in topology order. + set(IrNode): nodes in topology order. """ - return core.topology_sort(self.graph) + ordered_nodes = core.topology_sort(self.graph) + return {IrNode(n) for n in ordered_nodes} def build_adjacency_list(self): """ Build an adjacency list of operations for the `graph`. Returns: - dict{core.Node: set(core.Node)}: the adjacency list. + dict{IrNode: set(IrNode)}: the adjacency list. """ - return core.build_adjacency_list(self.graph) + adj_list = core.build_adjacency_list(self.graph) + wrapped_adj_list = dict() + for k, v in six.iteritems(adj_list): + wrapped_adj_list[IrNode(k)] = {IrNode(n) for n in v} + return wrapped_adj_list def draw(self, save_path, name, marked_nodes=None, remove_ctr_var=True): """ @@ -1809,7 +2102,7 @@ class IrGraph(object): Args: save_path(str): the save path of drawn graph. name(str): the name of drawn graph. - marked_nodes(set(core.Node)): nodes that are needed to be marked. + marked_nodes(set(IrNode)): nodes that are needed to be marked. Default value is None. remove_ctr_var(bool): If it is set True, all control variable nodes in the graph will be removed. Default value is True. 
@@ -1824,20 +2117,22 @@ class IrGraph(object): print('The {} is saved as the dot filetype.'.format( dot_file_path)) + remove_ctr_vars = set() if remove_ctr_var: - remove_ctr_vars = set() - for node in self.graph.nodes(): + for node in self.all_var_nodes(): if node.is_ctrl_var(): remove_ctr_vars.add(node) self.safe_remove_nodes(remove_ctr_vars) - ops_num = 0 - for node in self.graph.nodes(): - if node.is_op(): - ops_num += 1 - print('Total ops num = {}.'.format(ops_num)) + print('Total ops num = {}.'.format(len(self.all_op_nodes()))) + if marked_nodes is not None: if not isinstance(marked_nodes, set): - marked_nodes = set(marked_nodes) + if isinstance(marked_nodes, Iterable): + marked_nodes = set(marked_nodes) + else: + marked_nodes = {marked_nodes} + marked_nodes = {n.node for n in marked_nodes} + remove_ctr_vars = {n.node for n in remove_ctr_vars} marked_nodes = marked_nodes - remove_ctr_vars if self.graph.has('__graphviz__marked_node__'): self.graph.erase('__graphviz__marked_node__') -- GitLab
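Note on the wrapper convention introduced by this patch: IrGraph now hands out IrNode, IrVarNode, and IrOpNode objects instead of raw core.Node, and identity against the underlying C++ graph is always taken through the wrapper's `.node` attribute (see `safe_remove_nodes` and the `_remove_unused_var_nodes` rewrites above). Since the wrappers are freshly constructed on every accessor call and define no `__eq__`/`__hash__`, set membership must compare the wrapped core.Node, which is exactly what the passes do. A minimal sketch of the convention against the fluid 1.x API; the toy network below is illustrative only, not part of the commit:

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph

    # Build a small program so the sketch is self-contained.
    main = fluid.Program()
    startup = fluid.Program()
    with fluid.program_guard(main, startup):
        img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
        conv = fluid.layers.conv2d(img, num_filters=4, filter_size=3)

    graph = IrGraph(core.Graph(main.desc), for_test=False)

    # Accessors return wrappers: IrOpNode for operators, IrVarNode for variables,
    # so variable-specific helpers such as persistable() are directly available.
    for op in graph.all_op_nodes():
        for in_var in op.inputs:  # IrVarNode, not core.Node
            print(op.name(), '<-', in_var.name(), in_var.persistable())

    # Wrapper objects are not stable across calls; compare the wrapped
    # core.Node when testing membership, as the quantization passes do.
    used = {n.node for n in graph.all_var_nodes()}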
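For callers of the slim quantization passes, the visible changes are the constructor argument (`program_exe` becomes `place`; the pass now builds its own Executor to initialize the new parameters) and the renamed IrGraph accessors (`all_ops` to `all_op_nodes`, `all_vars` to `all_var_nodes`, `all_persistable_vars` to `all_persistable_nodes`, `create_param_node` to `create_persistable_node`). A migration sketch mirroring the updated docstring example; the fc network here is an assumed stand-in for real driver code:

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph
    from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

    main = fluid.Program()
    startup = fluid.Program()
    with fluid.program_guard(main, startup):
        img = fluid.layers.data(name='image', shape=[784], dtype='float32')
        fc = fluid.layers.fc(input=img, size=10)

    graph = IrGraph(core.Graph(main.desc), for_test=False)

    # Old: QuantizationTransformPass(scope, program_exe=exe, ...)
    # New: pass a place; the Executor is created inside apply() when needed.
    transform_pass = QuantizationTransformPass(
        scope=fluid.global_scope(),
        place=fluid.CPUPlace(),
        activation_quantize_type='abs_max')
    transform_pass.apply(graph)

    # Renamed accessor; elements are IrOpNode wrappers.
    quant_ops = {op for op in graph.all_op_nodes()
                 if op.name().find('quantize') > -1}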