diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index ba0d4bb4355e9c7d59a672dc300386ffcf655115..24059140ab20e24917b93a5f60936b1087797ff9 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -148,8 +148,8 @@ void BindNode(py::module *m) { }) .def("outputs_append", [](Node &self, Node &node) { self.outputs.push_back(&node); }) - .def_readonly("inputs", &Node::inputs) - .def_readonly("outputs", &Node::outputs); + .def_readwrite("inputs", &Node::inputs) + .def_readwrite("outputs", &Node::outputs); py::enum_(node, "Type") .value("Operation", Node::Type::kOperation) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index c55e8b047590b89edc39a79ca3df1cf1e337f791..c470483756659b55329c022e0c43002182db815b 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -797,18 +797,18 @@ All parameter, weight, gradient are variables in Paddle. py::class_> pass(m, "Pass"); pass.def(py::init()) .def("has", &ir::Pass::Has) - .def("set_program", + .def("set", [](ir::Pass &self, const std::string &attr_name, const ProgramDesc &attr) { return self.Set(attr_name, new ProgramDesc(attr)); }) .def( - "set_str", + "set", [](ir::Pass &self, const std::string &name, const std::string &attr) { self.Set(name, new std::string(attr)); }) - .def("set_int", [](ir::Pass &self, const std::string &name, - int val) { self.Set(name, new int(val)); }) + .def("set", [](ir::Pass &self, const std::string &name, + int val) { self.Set(name, new int(val)); }) .def("get_program", &ir::Pass::Get) .def("type", &ir::Pass::Type) .def("apply", [](ir::Pass &self, std::shared_ptr graph) { diff --git a/python/paddle/fluid/contrib/slim/graph/graph.py b/python/paddle/fluid/contrib/slim/graph/graph.py index 80deeee87934b1ae378281c1d427b1cdf97c749e..f38d9783413a01cd1005a014c0aba5ecf5cc79c2 100644 --- a/python/paddle/fluid/contrib/slim/graph/graph.py +++ b/python/paddle/fluid/contrib/slim/graph/graph.py @@ -18,140 +18,7 @@ from ....framework import Program from ....framework import Block from .... import core -__all__ = ['Graph', 'ImitationGraph', 'IRGraph', 'PyGraph'] - - -class PyGraph(object): - """ - PyGraph uses core.Graph as the delegation to accomplish the manipulation. - """ - - def __init__(self, graph, for_test=False): - """ - Construct the PyGraph using core.Graph. - Args: - graph(core.Graph): C++ Graph. - for_test(bool): True for the test graph and false for the train graph. - """ - assert isinstance( - graph, core.Graph), 'graph must be the instance of core.Graph.' - self.graph = graph - self.for_test = for_test - - def is_test(self): - return self.for_test - - def all_parameters(self): - param_nodes = set() - for node in self.graph.nodes(): - if node.is_var() and node.var() is not None and node.var( - ).persistable(): - param_nodes.add(node) - return param_nodes - - def all_vars(self): - return {node for node in self.graph.nodes() if node.is_var()} - - def all_ops(self): - return {node for node in self.graph.nodes() if node.is_op()} - - def create_param_node(self, name, var_type, shape, var_dtype): - var_desc = core.VarDesc(name) - var_desc.set_type(var_type) - var_desc.set_shape(shape) - var_desc.set_dtype(var_dtype) - var_desc.set_persistable(True) - return self.graph.create_var_node(var_desc) - - def create_var_node(self, name, var_type, shape, var_dtype): - var_desc = core.VarDesc(name) - var_desc.set_type(var_type) - var_desc.set_shape(shape) - var_desc.set_dtype(var_dtype) - return self.graph.create_var_node(var_desc) - - def create_var_node_from_desc(self, var_desc): - return self.graph.create_var_node(var_desc) - - def create_op_node(self, op_type, attrs, inputs, outputs): - op_desc = core.OpDesc() - op_desc.set_type(op_type) - for attr, value in attrs.iteritems(): - self._update_desc_attr(op_desc, attr, value) - for input_name, var_nodes in inputs.iteritems(): - if not isinstance(var_nodes, list): - var_nodes = [var_nodes] - op_desc.set_input(input_name, - [var_node.name() for var_node in var_nodes]) - for output_name, var_nodes in outputs.iteritems(): - if not isinstance(var_nodes, list): - var_nodes = [var_nodes] - op_desc.set_output(output_name, - [var_node.name() for var_node in var_nodes]) - return self.graph.create_op_node(op_desc) - - def create_op_node_from_desc(self, op_desc): - return self.graph.create_op_node(op_desc) - - def _update_desc_attr(self, desc, name, val): - """ - Update the value of desc's attribute by attribute's name. - """ - if isinstance(val, Block): - desc.set_block_attr(name, val.desc) - elif isinstance(val, list) and val and all( - isinstance(v, Block) for v in val): - desc.set_blocks_attr(name, [v.desc for v in val]) - elif isinstance(val, core.BlockDesc) or \ - isinstance(val, core.ProgramDesc): - desc.set_serialized_attr(name, val.serialize_to_string()) - else: - desc._set_attr(name, val) - - def safe_remove_nodes(self, remove_nodes): - if not isinstance(remove_nodes, set): - remove_nodes = set(remove_nodes) - core.graph_safe_remove_nodes(self.graph, remove_nodes) - - def draw(self, save_path, name, marked_nodes=None): - def _convert_to_pdf(dot_file_path): - pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf' - exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \ - + ' -o ' + pdf_save_path, shell=True) - if exited_code != 0: - print('The dot command is needed for creating pdf files.') - print('The {} is saved as the dot filetype.'.format( - dot_file_path)) - - remove_ctr_vars = set() - ops_num = 0 - for node in self.graph.nodes(): - if node.is_ctrl_var(): - remove_ctr_vars.add(node) - elif node.is_op(): - ops_num += 1 - print('Total ops num = {}.'.format(ops_num)) - self.safe_remove_nodes(remove_ctr_vars) - if marked_nodes is not None: - if not isinstance(marked_nodes, set): - marked_nodes = set(marked_nodes) - marked_nodes = marked_nodes - remove_ctr_vars - if self.graph.has('__graphviz__marked_node__'): - self.graph.erase('__graphviz__marked_node__') - self.graph.set('__graphviz__marked_node__', marked_nodes) - viz_dot_path = os.path.join(save_path, name) + '.dot' - viz_pass = core.get_pass('graph_viz_pass') - viz_pass.set_str('graph_viz_path', viz_dot_path) - viz_pass.apply(self.graph) - _convert_to_pdf(viz_dot_path) - - def to_program(self): - convert_pass = core.get_pass('graph_to_program_pass') - convert_pass.set_program('program', Program().desc) - convert_pass.apply(self.graph) - desc = convert_pass.get_program('program') - program = Program.construct_from_desc(desc) - return program +__all__ = ['Graph', 'ImitationGraph', 'IRGraph'] class Graph(object): diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 3c33a513ff385dd246599a12c3a68c72ec08f4cd..ce16a32415e3ff86224fb444cdb41b11ed74b9cd 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -13,13 +13,12 @@ # limitations under the License. import collections -import numpy as np from .... import core +from ....framework import IrGraph from ....framework import Program from ....framework import Variable from ....initializer import Constant from .... import unique_name -from ..graph import PyGraph __all__ = ['QuantizationTransformPass'] @@ -34,7 +33,7 @@ class QuantizationTransformPass(object): weight_quantize_type='abs_max', window_size=10000): """ - Convert and rewrite the PyGraph according to weight and + Convert and rewrite the IrGraph according to weight and activation quantization type. Args: weight_bits (int): quantization bit number for weights, @@ -56,19 +55,19 @@ class QuantizationTransformPass(object): import paddle.fluid as fluid from paddle.fluid.contrib.slim.quantization \ import QuantizationTransformPass - from paddle.fluid.contrib.slim.graph import PyGraph + from paddle.fluid.contrib.slim.graph import IrGraph from paddle.fluid import core - graph = PyGraph(core.Graph(program.desc), for_test=False) + graph = IrGraph(core.Graph(program.desc), for_test=False) exe = fluid.Executor(fluid.CPUPlace()) transform_pass = QuantizationTransformPass(fluid.global_scope(), exe) transform_pass.apply(graph) """ - self.scope = scope - self.program_exe = program_exe - self.weight_bits = weight_bits - self.activation_bits = activation_bits + self._scope = scope + self._program_exe = program_exe + self._weight_bits = weight_bits + self._activation_bits = activation_bits quant_type = ['abs_max', 'range_abs_max'] if activation_quantize_type not in quant_type: @@ -80,27 +79,27 @@ class QuantizationTransformPass(object): "Unknown weight_quantize_type: '%s'. It can only be ", "'abs_max' or 'range_abs_max'.", str(weight_quantize_type)) - self.activation_quantize_type = activation_quantize_type - self.weight_quantize_type = weight_quantize_type - self.window_size = window_size + self._activation_quantize_type = activation_quantize_type + self._weight_quantize_type = weight_quantize_type + self._window_size = window_size - self.need_initialized = collections.OrderedDict() - self.quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] - self.quantizable_grad_ops = [ - '%s_grad' % (op) for op in self.quantizable_ops + self._need_initialized = collections.OrderedDict() + self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] + self._quantizable_grad_ops = [ + '%s_grad' % (op) for op in self._quantizable_ops ] - self.fake_quant_op_types = [ + self._fake_quant_op_types = [ 'fake_quantize_abs_max', 'fake_quantize_range_abs_max' ] - self.fake_dequant_op_types = ['fake_dequantize_max_abs'] - self.is_test = None - self.global_step = None + self._fake_dequant_op_types = ['fake_dequantize_max_abs'] + self._is_test = None + self._global_step = None def apply(self, graph): assert isinstance(graph, - PyGraph), 'graph must be the instance of PyGraph.' - self.need_initialized.clear() - self.is_test = graph.is_test() + IrGraph), 'graph must be the instance of IrGraph.' + self._need_initialized.clear() + self._is_test = graph.is_test() # marked the variable which has been dequantized. dequantized_vars = collections.OrderedDict() params = [p.name() for p in graph.all_parameters()] @@ -110,72 +109,69 @@ class QuantizationTransformPass(object): if var_node.name() in dequantized_vars: dequant_var_node = dequantized_vars[var_node.name()] else: - quant_bits = self.weight_bits if var_node.name() in params \ - else self.activation_bits - quant_type = self.weight_quantize_type if var_node.name() \ - in params else self.activation_quantize_type + quant_bits = self._weight_bits if var_node.name() in params \ + else self._activation_bits + quant_type = self._weight_quantize_type if var_node.name() \ + in params else self._activation_quantize_type quant_var_node, scale_var_node = self._insert_quant_op( graph, var_node, quant_bits, quant_type) dequant_var_node = self._insert_dequant_op( graph, quant_var_node, scale_var_node, quant_bits) dequantized_vars[var_node.name()] = dequant_var_node - self._update_input(var_node, dequant_var_node, op) - op.op()._rename_input(var_node.name(), dequant_var_node.name()) + graph.update_input_link(var_node, dequant_var_node, op) def _transform_backward(graph, op): no_dequanted_input_vars = True for var_node in op.inputs: if var_node.name() in dequantized_vars: dequant_var_node = dequantized_vars[var_node.name()] - self._update_input(var_node, dequant_var_node, op) - op.op()._rename_input(var_node.name(), - dequant_var_node.name()) + graph.update_input_link(var_node, dequant_var_node, op) no_dequanted_input_vars = False if no_dequanted_input_vars: raise ValueError("There is no dequanted inputs for op %s." % (op.name())) - if not self.is_test: + if not self._is_test: self._create_global_step(graph) ops = graph.all_ops() # The process of _transform_forward and _transform_backward is needed in two for loops. # The loop for transforming the forward graph: for op in ops: - if op.name() in self.quantizable_ops: + if op.name() in self._quantizable_ops: _transform_forward(graph, op) # The loop for renaming the inputs of backward op. for op in ops: - if op.name() in self.quantizable_grad_ops: + if op.name() in self._quantizable_grad_ops: _transform_backward(graph, op) - if len(self.need_initialized) > 0: - assert self.scope is not None, \ + if len(self._need_initialized) > 0: + assert self._scope is not None, \ 'The scope cannot be set None when activation_quantize_type equals to range_abs_max.' - assert self.program_exe is not None, \ + assert self._program_exe is not None, \ 'The program_exe cannot be set None when activation_quantize_type equals to range_abs_max.' init_program = Program() - for var_desc, initializer in self.need_initialized.iteritems(): - var = Variable.construct_from_desc(init_program.global_block(), - var_desc) + for var_desc, initializer in self._need_initialized.iteritems(): + var = Variable(init_program.global_block()) + var._set_desc(var_desc) initializer(var, init_program.global_block()) - self.program_exe.run(program=init_program, scope=self.scope) + self._program_exe.run(program=init_program, scope=self._scope) return graph def _create_global_step(self, graph): - if self.weight_quantize_type == 'range_abs_max' or \ - self.activation_quantize_type == 'range_abs_max': + if self._weight_quantize_type == 'range_abs_max' or \ + self._activation_quantize_type == 'range_abs_max': counter_name = '@STEP_COUNTER@' for node in graph.all_vars(): if node.name() == counter_name: - self.global_step = node - if self.global_step is None: + self._global_step = node + if self._global_step is None: global_step_in = graph.create_param_node( name=counter_name, var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[1], var_dtype=core.VarDesc.VarType.INT64) - self.need_initialized[global_step_in.var()] = \ + self._need_initialized[global_step_in.var()] = \ Constant(value=0, force_cpu=True) global_step_out = graph.create_var_node_from_desc( global_step_in.var()) @@ -184,9 +180,9 @@ class QuantizationTransformPass(object): attrs={'step': 1.0}, inputs={'X': global_step_in}, outputs={'Out': global_step_out}) - self._link_to(global_step_in, increment_op) - self._link_to(increment_op, global_step_out) - self.global_step = global_step_out + graph.link_to(global_step_in, increment_op) + graph.link_to(increment_op, global_step_out) + self._global_step = global_step_out def _insert_quant_op(self, graph, var_node, quant_bits, quant_type): """ @@ -220,9 +216,9 @@ class QuantizationTransformPass(object): inputs={'X': var_node}, outputs={'Out': quant_var_node, 'OutScale': scale_var_node}) - self._link_to(var_node, quant_op_node) - self._link_to(quant_op_node, quant_var_node) - self._link_to(quant_op_node, scale_var_node) + graph.link_to(var_node, quant_op_node) + graph.link_to(quant_op_node, quant_var_node) + graph.link_to(quant_op_node, scale_var_node) return quant_var_node, scale_var_node def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits): @@ -242,26 +238,26 @@ class QuantizationTransformPass(object): var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[1], var_dtype=var_node.var().dtype()) - self.need_initialized[scale_in_node.var()] = Constant(value=0.001) + self._need_initialized[scale_in_node.var()] = Constant(value=0.001) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) inputs = {'X': var_node, 'InScale': scale_in_node} outputs = {'Out': quant_var_node, 'OutScale': scale_out_node} - if not self.is_test: + if not self._is_test: # The name of scales_var_node maybe 'scales_0', 'scales_1', etc. scales_node = graph.create_param_node( name=unique_name.generate('scales'), var_type=core.VarDesc.VarType.LOD_TENSOR, - shape=[self.window_size], + shape=[self._window_size], var_dtype=var_node.var().dtype()) - self.need_initialized[scales_node.var()] = Constant(value=0) - inputs['Iter'] = self.global_step + self._need_initialized[scales_node.var()] = Constant(value=0) + inputs['Iter'] = self._global_step outputs['OutScales'] = scales_node attrs = { - 'window_size': self.window_size, + 'window_size': self._window_size, 'bit_length': quant_bits, - 'is_test': self.is_test + 'is_test': self._is_test } quant_op_node = graph.create_op_node( op_type='fake_quantize_range_abs_max', @@ -269,14 +265,14 @@ class QuantizationTransformPass(object): inputs=inputs, outputs=outputs) - self._link_to(var_node, quant_op_node) - self._link_to(scale_in_node, quant_op_node) - self._link_to(quant_op_node, quant_var_node) - self._link_to(quant_op_node, scale_out_node) + graph.link_to(var_node, quant_op_node) + graph.link_to(scale_in_node, quant_op_node) + graph.link_to(quant_op_node, quant_var_node) + graph.link_to(quant_op_node, scale_out_node) - if not self.is_test: - self._link_to(self.global_step, quant_op_node) - self._link_to(quant_op_node, scales_node) + if not self._is_test: + graph.link_to(self._global_step, quant_op_node) + graph.link_to(quant_op_node, scales_node) return quant_var_node, scale_out_node @@ -298,21 +294,11 @@ class QuantizationTransformPass(object): inputs={'X': var_node, 'Scale': scale_var_node}, outputs={'Out': dequant_var_node}) - self._link_to(var_node, dequant_op_node) - self._link_to(scale_var_node, dequant_op_node) - self._link_to(dequant_op_node, dequant_var_node) + graph.link_to(var_node, dequant_op_node) + graph.link_to(scale_var_node, dequant_op_node) + graph.link_to(dequant_op_node, dequant_var_node) return dequant_var_node - def _update_input(self, old_input_node, new_input_node, op_node): - old_input_node.outputs_remove(op_node) - op_node.inputs_remove(old_input_node) - new_input_node.outputs_append(op_node) - op_node.inputs_append(new_input_node) - - def _link_to(self, node_in, node_out): - node_in.outputs_append(node_out) - node_out.inputs_append(node_in) - def _quantized_var_name(self, var_name): """ Return quantized variable name for the input `var_name`. @@ -330,25 +316,3 @@ class QuantizationTransformPass(object): Return quantized variable name for the input `var_name`. """ return "%s.scale" % (var_name) - - def _original_var_name(self, var_name): - """ - Return the original variable name. - """ - if var_name.endswith('.quantized.dequantized'): - return var_name[:-len('.quantized.dequantized')] - if var_name.endswith('.quantized'): - return var_name[:-len('.quantized')] - if var_name.endswith('.dequantized'): - return var_name[:-len('.dequantized')] - if var_name.endswith('.scale'): - return var_name[:-len('.scale')] - else: - return var_name - - def _is_float(self, v): - return isinstance(v, float) or isinstance(v, np.float32) - - def _quant(self, x, scale, num_bits): - y = np.round(x / scale * ((1 << (num_bits - 1)) - 1)) - return y diff --git a/python/paddle/fluid/contrib/slim/unitest/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/unitest/test_quantization_pass.py index 31188bedbbe0f5d7478473e1e8d0b627298b4413..1bd4b95d6b90b7f16d507061190f0b463f6c4cc5 100644 --- a/python/paddle/fluid/contrib/slim/unitest/test_quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/unitest/test_quantization_pass.py @@ -18,8 +18,8 @@ import numpy as np import paddle.fluid as fluid import six from paddle.fluid.framework import Program +from paddle.fluid.framework import IrGraph from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass -from paddle.fluid.contrib.slim.graph import PyGraph from paddle.fluid import core @@ -106,7 +106,7 @@ class TestQuantizationTransformPass(unittest.TestCase): opt = fluid.optimizer.Adam(learning_rate=0.001) opt.minimize(loss) exe = fluid.Executor(fluid.CPUPlace()) - graph = PyGraph(core.Graph(main.desc), for_test=False) + graph = IrGraph(core.Graph(main.desc), for_test=False) transform_pass = QuantizationTransformPass( scope=fluid.global_scope(), program_exe=exe, @@ -119,7 +119,7 @@ class TestQuantizationTransformPass(unittest.TestCase): graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes) program = graph.to_program() self.check_program(transform_pass, program) - val_graph = PyGraph(core.Graph(program.desc), for_test=False) + val_graph = IrGraph(core.Graph(program.desc), for_test=False) val_marked_nodes = set() for op in val_graph.all_ops(): if op.name().find('quantize') > -1: @@ -142,7 +142,7 @@ class TestQuantizationTransformPass(unittest.TestCase): opt = fluid.optimizer.Adam(learning_rate=0.001) opt.minimize(loss) exe = fluid.Executor(fluid.CPUPlace()) - graph = PyGraph(core.Graph(main.desc), for_test=False) + graph = IrGraph(core.Graph(main.desc), for_test=False) transform_pass = QuantizationTransformPass( scope=fluid.global_scope(), program_exe=exe, @@ -155,7 +155,7 @@ class TestQuantizationTransformPass(unittest.TestCase): graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes) program = graph.to_program() self.check_program(transform_pass, program) - val_graph = PyGraph(core.Graph(program.desc), for_test=False) + val_graph = IrGraph(core.Graph(program.desc), for_test=False) val_marked_nodes = set() for op in val_graph.all_ops(): if op.name().find('quantize') > -1: diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 1913f58e67934a3b13d989b4f090156c6f4ce7fa..fc5e471ae3041b0245c873d85efa8b49f3a43678 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -23,6 +23,7 @@ import traceback import six import numpy as np +import subprocess from .. import compat as cpt from .proto import framework_pb2 @@ -381,27 +382,6 @@ class Variable(object): self._ivar.desc = self.desc self._ivar.stop_gradient = stop_gradient - @staticmethod - def construct_from_desc(block, desc): - """ - Construct a Variable from variable desc. - Args: - desc(core.VarDesc): The variable desc for constructing. - - Returns: - Variable: A variable. - """ - v = Variable( - block=block, - type=desc.type(), - name=desc.name(), - shape=desc.shape(), - dtype=desc.dtype(), - lod_level=desc.lod_level(), - persistable=desc.persistable()) - v.desc = desc - return v - def _numpy(self): tensor = self._ivar.value().get_tensor() return np.array(tensor) @@ -1533,6 +1513,154 @@ class Block(object): return ret_var +class IrGraph(object): + """ + IrGraph uses core.Graph as the delegation to accomplish the manipulation. + """ + + def __init__(self, graph, for_test=False): + """ + Construct the IrGraph using core.Graph. + Args: + graph(core.Graph): C++ Graph. + for_test(bool): True for the test graph and false for the train graph. + """ + assert isinstance( + graph, core.Graph), 'graph must be the instance of core.Graph.' + self.graph = graph + self._for_test = for_test + + def is_test(self): + return self._for_test + + def all_parameters(self): + param_nodes = set() + for node in self.graph.nodes(): + if node.is_var() and node.var() is not None and node.var( + ).persistable(): + param_nodes.add(node) + return param_nodes + + def all_vars(self): + return {node for node in self.graph.nodes() if node.is_var()} + + def all_ops(self): + return {node for node in self.graph.nodes() if node.is_op()} + + def create_param_node(self, name, var_type, shape, var_dtype): + var_desc = core.VarDesc(name) + var_desc.set_type(var_type) + var_desc.set_shape(shape) + var_desc.set_dtype(var_dtype) + var_desc.set_persistable(True) + return self.graph.create_var_node(var_desc) + + def create_var_node(self, name, var_type, shape, var_dtype): + var_desc = core.VarDesc(name) + var_desc.set_type(var_type) + var_desc.set_shape(shape) + var_desc.set_dtype(var_dtype) + return self.graph.create_var_node(var_desc) + + def create_var_node_from_desc(self, var_desc): + return self.graph.create_var_node(var_desc) + + def create_op_node(self, op_type, attrs, inputs, outputs): + op_desc = core.OpDesc() + op_desc.set_type(op_type) + for attr, value in attrs.iteritems(): + self._update_desc_attr(op_desc, attr, value) + for input_name, var_nodes in inputs.iteritems(): + if not isinstance(var_nodes, list): + var_nodes = [var_nodes] + op_desc.set_input(input_name, + [var_node.name() for var_node in var_nodes]) + for output_name, var_nodes in outputs.iteritems(): + if not isinstance(var_nodes, list): + var_nodes = [var_nodes] + op_desc.set_output(output_name, + [var_node.name() for var_node in var_nodes]) + return self.graph.create_op_node(op_desc) + + def create_op_node_from_desc(self, op_desc): + return self.graph.create_op_node(op_desc) + + def update_input_link(self, old_input_node, new_input_node, op_node): + assert old_input_node in self.graph.nodes() and new_input_node in self.graph.nodes() and \ + op_node in self.graph.nodes(), 'Th three arguments must be in the graph nodes.' + old_input_node.outputs_remove(op_node) + op_node.inputs_remove(old_input_node) + new_input_node.outputs_append(op_node) + op_node.inputs_append(new_input_node) + op_node.op()._rename_input(old_input_node.name(), new_input_node.name()) + + def link_to(self, node_in, node_out): + assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \ + 'Th two arguments must be in the graph nodes.' + node_in.outputs_append(node_out) + node_out.inputs_append(node_in) + + def safe_remove_nodes(self, remove_nodes): + if not isinstance(remove_nodes, set): + remove_nodes = set(remove_nodes) + core.graph_safe_remove_nodes(self.graph, remove_nodes) + + def draw(self, save_path, name, marked_nodes=None): + def _convert_to_pdf(dot_file_path): + pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf' + exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \ + + ' -o ' + pdf_save_path, shell=True) + if exited_code != 0: + print('The dot command is needed for creating pdf files.') + print('The {} is saved as the dot filetype.'.format( + dot_file_path)) + + remove_ctr_vars = set() + ops_num = 0 + for node in self.graph.nodes(): + if node.is_ctrl_var(): + remove_ctr_vars.add(node) + elif node.is_op(): + ops_num += 1 + print('Total ops num = {}.'.format(ops_num)) + self.safe_remove_nodes(remove_ctr_vars) + if marked_nodes is not None: + if not isinstance(marked_nodes, set): + marked_nodes = set(marked_nodes) + marked_nodes = marked_nodes - remove_ctr_vars + if self.graph.has('__graphviz__marked_node__'): + self.graph.erase('__graphviz__marked_node__') + self.graph.set('__graphviz__marked_node__', marked_nodes) + viz_dot_path = os.path.join(save_path, name) + '.dot' + viz_pass = core.get_pass('graph_viz_pass') + viz_pass.set('graph_viz_path', viz_dot_path) + viz_pass.apply(self.graph) + _convert_to_pdf(viz_dot_path) + + def to_program(self): + convert_pass = core.get_pass('graph_to_program_pass') + convert_pass.set('program', Program().desc) + convert_pass.apply(self.graph) + desc = convert_pass.get_program('program') + program = Program._construct_from_desc(desc) + return program + + def _update_desc_attr(self, desc, name, val): + """ + Update the value of desc's attribute by attribute's name. + """ + if isinstance(val, Block): + desc.set_block_attr(name, val.desc) + elif isinstance(val, list) and val and all( + isinstance(v, Block) for v in val): + desc.set_blocks_attr(name, [v.desc for v in val]) + elif isinstance(val, core.BlockDesc) or \ + isinstance(val, core.ProgramDesc): + desc.set_serialized_attr(name, val.serialize_to_string()) + else: + desc._set_attr(name, val) + + class Program(object): """ Python Program. Beneath it is a ProgramDesc, which is used for @@ -1958,12 +2086,10 @@ class Program(object): return p @staticmethod - def construct_from_desc(desc): + def _construct_from_desc(desc): """ Construct a program from program desc. - Notes: All information about parameters will be lost. - Args: desc(core.ProgramDesc): The program desc for constructing.