提交 b913463e 编写于 作者: W WangZhen 提交者: root

Update according to the reviewers' suggestion. test=develop

上级 3ce61720
......@@ -148,8 +148,8 @@ void BindNode(py::module *m) {
})
.def("outputs_append",
[](Node &self, Node &node) { self.outputs.push_back(&node); })
.def_readonly("inputs", &Node::inputs)
.def_readonly("outputs", &Node::outputs);
.def_readwrite("inputs", &Node::inputs)
.def_readwrite("outputs", &Node::outputs);
py::enum_<Node::Type>(node, "Type")
.value("Operation", Node::Type::kOperation)
......
......@@ -797,17 +797,17 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init())
.def("has", &ir::Pass::Has)
.def("set_program",
.def("set",
[](ir::Pass &self, const std::string &attr_name,
const ProgramDesc &attr) {
return self.Set(attr_name, new ProgramDesc(attr));
})
.def(
"set_str",
"set",
[](ir::Pass &self, const std::string &name, const std::string &attr) {
self.Set<std::string>(name, new std::string(attr));
})
.def("set_int", [](ir::Pass &self, const std::string &name,
.def("set", [](ir::Pass &self, const std::string &name,
int val) { self.Set<const int>(name, new int(val)); })
.def("get_program", &ir::Pass::Get<ProgramDesc>)
.def("type", &ir::Pass::Type)
......
......@@ -18,140 +18,7 @@ from ....framework import Program
from ....framework import Block
from .... import core
__all__ = ['Graph', 'ImitationGraph', 'IRGraph', 'PyGraph']
class PyGraph(object):
"""
PyGraph uses core.Graph as the delegation to accomplish the manipulation.
"""
def __init__(self, graph, for_test=False):
"""
Construct the PyGraph using core.Graph.
Args:
graph(core.Graph): C++ Graph.
for_test(bool): True for the test graph and false for the train graph.
"""
assert isinstance(
graph, core.Graph), 'graph must be the instance of core.Graph.'
self.graph = graph
self.for_test = for_test
def is_test(self):
return self.for_test
def all_parameters(self):
param_nodes = set()
for node in self.graph.nodes():
if node.is_var() and node.var() is not None and node.var(
).persistable():
param_nodes.add(node)
return param_nodes
def all_vars(self):
return {node for node in self.graph.nodes() if node.is_var()}
def all_ops(self):
return {node for node in self.graph.nodes() if node.is_op()}
def create_param_node(self, name, var_type, shape, var_dtype):
var_desc = core.VarDesc(name)
var_desc.set_type(var_type)
var_desc.set_shape(shape)
var_desc.set_dtype(var_dtype)
var_desc.set_persistable(True)
return self.graph.create_var_node(var_desc)
def create_var_node(self, name, var_type, shape, var_dtype):
var_desc = core.VarDesc(name)
var_desc.set_type(var_type)
var_desc.set_shape(shape)
var_desc.set_dtype(var_dtype)
return self.graph.create_var_node(var_desc)
def create_var_node_from_desc(self, var_desc):
return self.graph.create_var_node(var_desc)
def create_op_node(self, op_type, attrs, inputs, outputs):
op_desc = core.OpDesc()
op_desc.set_type(op_type)
for attr, value in attrs.iteritems():
self._update_desc_attr(op_desc, attr, value)
for input_name, var_nodes in inputs.iteritems():
if not isinstance(var_nodes, list):
var_nodes = [var_nodes]
op_desc.set_input(input_name,
[var_node.name() for var_node in var_nodes])
for output_name, var_nodes in outputs.iteritems():
if not isinstance(var_nodes, list):
var_nodes = [var_nodes]
op_desc.set_output(output_name,
[var_node.name() for var_node in var_nodes])
return self.graph.create_op_node(op_desc)
def create_op_node_from_desc(self, op_desc):
return self.graph.create_op_node(op_desc)
def _update_desc_attr(self, desc, name, val):
"""
Update the value of desc's attribute by attribute's name.
"""
if isinstance(val, Block):
desc.set_block_attr(name, val.desc)
elif isinstance(val, list) and val and all(
isinstance(v, Block) for v in val):
desc.set_blocks_attr(name, [v.desc for v in val])
elif isinstance(val, core.BlockDesc) or \
isinstance(val, core.ProgramDesc):
desc.set_serialized_attr(name, val.serialize_to_string())
else:
desc._set_attr(name, val)
def safe_remove_nodes(self, remove_nodes):
if not isinstance(remove_nodes, set):
remove_nodes = set(remove_nodes)
core.graph_safe_remove_nodes(self.graph, remove_nodes)
def draw(self, save_path, name, marked_nodes=None):
def _convert_to_pdf(dot_file_path):
pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
+ ' -o ' + pdf_save_path, shell=True)
if exited_code != 0:
print('The dot command is needed for creating pdf files.')
print('The {} is saved as the dot filetype.'.format(
dot_file_path))
remove_ctr_vars = set()
ops_num = 0
for node in self.graph.nodes():
if node.is_ctrl_var():
remove_ctr_vars.add(node)
elif node.is_op():
ops_num += 1
print('Total ops num = {}.'.format(ops_num))
self.safe_remove_nodes(remove_ctr_vars)
if marked_nodes is not None:
if not isinstance(marked_nodes, set):
marked_nodes = set(marked_nodes)
marked_nodes = marked_nodes - remove_ctr_vars
if self.graph.has('__graphviz__marked_node__'):
self.graph.erase('__graphviz__marked_node__')
self.graph.set('__graphviz__marked_node__', marked_nodes)
viz_dot_path = os.path.join(save_path, name) + '.dot'
viz_pass = core.get_pass('graph_viz_pass')
viz_pass.set_str('graph_viz_path', viz_dot_path)
viz_pass.apply(self.graph)
_convert_to_pdf(viz_dot_path)
def to_program(self):
convert_pass = core.get_pass('graph_to_program_pass')
convert_pass.set_program('program', Program().desc)
convert_pass.apply(self.graph)
desc = convert_pass.get_program('program')
program = Program.construct_from_desc(desc)
return program
__all__ = ['Graph', 'ImitationGraph', 'IRGraph']
class Graph(object):
......
......@@ -13,13 +13,12 @@
# limitations under the License.
import collections
import numpy as np
from .... import core
from ....framework import IrGraph
from ....framework import Program
from ....framework import Variable
from ....initializer import Constant
from .... import unique_name
from ..graph import PyGraph
__all__ = ['QuantizationTransformPass']
......@@ -34,7 +33,7 @@ class QuantizationTransformPass(object):
weight_quantize_type='abs_max',
window_size=10000):
"""
Convert and rewrite the PyGraph according to weight and
Convert and rewrite the IrGraph according to weight and
activation quantization type.
Args:
weight_bits (int): quantization bit number for weights,
......@@ -56,19 +55,19 @@ class QuantizationTransformPass(object):
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization \
import QuantizationTransformPass
from paddle.fluid.contrib.slim.graph import PyGraph
from paddle.fluid.contrib.slim.graph import IrGraph
from paddle.fluid import core
graph = PyGraph(core.Graph(program.desc), for_test=False)
graph = IrGraph(core.Graph(program.desc), for_test=False)
exe = fluid.Executor(fluid.CPUPlace())
transform_pass = QuantizationTransformPass(fluid.global_scope(),
exe)
transform_pass.apply(graph)
"""
self.scope = scope
self.program_exe = program_exe
self.weight_bits = weight_bits
self.activation_bits = activation_bits
self._scope = scope
self._program_exe = program_exe
self._weight_bits = weight_bits
self._activation_bits = activation_bits
quant_type = ['abs_max', 'range_abs_max']
if activation_quantize_type not in quant_type:
......@@ -80,27 +79,27 @@ class QuantizationTransformPass(object):
"Unknown weight_quantize_type: '%s'. It can only be ",
"'abs_max' or 'range_abs_max'.", str(weight_quantize_type))
self.activation_quantize_type = activation_quantize_type
self.weight_quantize_type = weight_quantize_type
self.window_size = window_size
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
self._window_size = window_size
self.need_initialized = collections.OrderedDict()
self.quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self.quantizable_grad_ops = [
'%s_grad' % (op) for op in self.quantizable_ops
self._need_initialized = collections.OrderedDict()
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._quantizable_grad_ops = [
'%s_grad' % (op) for op in self._quantizable_ops
]
self.fake_quant_op_types = [
self._fake_quant_op_types = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
]
self.fake_dequant_op_types = ['fake_dequantize_max_abs']
self.is_test = None
self.global_step = None
self._fake_dequant_op_types = ['fake_dequantize_max_abs']
self._is_test = None
self._global_step = None
def apply(self, graph):
assert isinstance(graph,
PyGraph), 'graph must be the instance of PyGraph.'
self.need_initialized.clear()
self.is_test = graph.is_test()
IrGraph), 'graph must be the instance of IrGraph.'
self._need_initialized.clear()
self._is_test = graph.is_test()
# marked the variable which has been dequantized.
dequantized_vars = collections.OrderedDict()
params = [p.name() for p in graph.all_parameters()]
......@@ -110,72 +109,69 @@ class QuantizationTransformPass(object):
if var_node.name() in dequantized_vars:
dequant_var_node = dequantized_vars[var_node.name()]
else:
quant_bits = self.weight_bits if var_node.name() in params \
else self.activation_bits
quant_type = self.weight_quantize_type if var_node.name() \
in params else self.activation_quantize_type
quant_bits = self._weight_bits if var_node.name() in params \
else self._activation_bits
quant_type = self._weight_quantize_type if var_node.name() \
in params else self._activation_quantize_type
quant_var_node, scale_var_node = self._insert_quant_op(
graph, var_node, quant_bits, quant_type)
dequant_var_node = self._insert_dequant_op(
graph, quant_var_node, scale_var_node, quant_bits)
dequantized_vars[var_node.name()] = dequant_var_node
self._update_input(var_node, dequant_var_node, op)
op.op()._rename_input(var_node.name(), dequant_var_node.name())
graph.update_input_link(var_node, dequant_var_node, op)
def _transform_backward(graph, op):
no_dequanted_input_vars = True
for var_node in op.inputs:
if var_node.name() in dequantized_vars:
dequant_var_node = dequantized_vars[var_node.name()]
self._update_input(var_node, dequant_var_node, op)
op.op()._rename_input(var_node.name(),
dequant_var_node.name())
graph.update_input_link(var_node, dequant_var_node, op)
no_dequanted_input_vars = False
if no_dequanted_input_vars:
raise ValueError("There is no dequanted inputs for op %s." %
(op.name()))
if not self.is_test:
if not self._is_test:
self._create_global_step(graph)
ops = graph.all_ops()
# The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph:
for op in ops:
if op.name() in self.quantizable_ops:
if op.name() in self._quantizable_ops:
_transform_forward(graph, op)
# The loop for renaming the inputs of backward op.
for op in ops:
if op.name() in self.quantizable_grad_ops:
if op.name() in self._quantizable_grad_ops:
_transform_backward(graph, op)
if len(self.need_initialized) > 0:
assert self.scope is not None, \
if len(self._need_initialized) > 0:
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self.program_exe is not None, \
assert self._program_exe is not None, \
'The program_exe cannot be set None when activation_quantize_type equals to range_abs_max.'
init_program = Program()
for var_desc, initializer in self.need_initialized.iteritems():
var = Variable.construct_from_desc(init_program.global_block(),
var_desc)
for var_desc, initializer in self._need_initialized.iteritems():
var = Variable(init_program.global_block())
var._set_desc(var_desc)
initializer(var, init_program.global_block())
self.program_exe.run(program=init_program, scope=self.scope)
self._program_exe.run(program=init_program, scope=self._scope)
return graph
def _create_global_step(self, graph):
if self.weight_quantize_type == 'range_abs_max' or \
self.activation_quantize_type == 'range_abs_max':
if self._weight_quantize_type == 'range_abs_max' or \
self._activation_quantize_type == 'range_abs_max':
counter_name = '@STEP_COUNTER@'
for node in graph.all_vars():
if node.name() == counter_name:
self.global_step = node
if self.global_step is None:
self._global_step = node
if self._global_step is None:
global_step_in = graph.create_param_node(
name=counter_name,
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=core.VarDesc.VarType.INT64)
self.need_initialized[global_step_in.var()] = \
self._need_initialized[global_step_in.var()] = \
Constant(value=0, force_cpu=True)
global_step_out = graph.create_var_node_from_desc(
global_step_in.var())
......@@ -184,9 +180,9 @@ class QuantizationTransformPass(object):
attrs={'step': 1.0},
inputs={'X': global_step_in},
outputs={'Out': global_step_out})
self._link_to(global_step_in, increment_op)
self._link_to(increment_op, global_step_out)
self.global_step = global_step_out
graph.link_to(global_step_in, increment_op)
graph.link_to(increment_op, global_step_out)
self._global_step = global_step_out
def _insert_quant_op(self, graph, var_node, quant_bits, quant_type):
"""
......@@ -220,9 +216,9 @@ class QuantizationTransformPass(object):
inputs={'X': var_node},
outputs={'Out': quant_var_node,
'OutScale': scale_var_node})
self._link_to(var_node, quant_op_node)
self._link_to(quant_op_node, quant_var_node)
self._link_to(quant_op_node, scale_var_node)
graph.link_to(var_node, quant_op_node)
graph.link_to(quant_op_node, quant_var_node)
graph.link_to(quant_op_node, scale_var_node)
return quant_var_node, scale_var_node
def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
......@@ -242,26 +238,26 @@ class QuantizationTransformPass(object):
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=var_node.var().dtype())
self.need_initialized[scale_in_node.var()] = Constant(value=0.001)
self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
inputs = {'X': var_node, 'InScale': scale_in_node}
outputs = {'Out': quant_var_node, 'OutScale': scale_out_node}
if not self.is_test:
if not self._is_test:
# The name of scales_var_node maybe 'scales_0', 'scales_1', etc.
scales_node = graph.create_param_node(
name=unique_name.generate('scales'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[self.window_size],
shape=[self._window_size],
var_dtype=var_node.var().dtype())
self.need_initialized[scales_node.var()] = Constant(value=0)
inputs['Iter'] = self.global_step
self._need_initialized[scales_node.var()] = Constant(value=0)
inputs['Iter'] = self._global_step
outputs['OutScales'] = scales_node
attrs = {
'window_size': self.window_size,
'window_size': self._window_size,
'bit_length': quant_bits,
'is_test': self.is_test
'is_test': self._is_test
}
quant_op_node = graph.create_op_node(
op_type='fake_quantize_range_abs_max',
......@@ -269,14 +265,14 @@ class QuantizationTransformPass(object):
inputs=inputs,
outputs=outputs)
self._link_to(var_node, quant_op_node)
self._link_to(scale_in_node, quant_op_node)
self._link_to(quant_op_node, quant_var_node)
self._link_to(quant_op_node, scale_out_node)
graph.link_to(var_node, quant_op_node)
graph.link_to(scale_in_node, quant_op_node)
graph.link_to(quant_op_node, quant_var_node)
graph.link_to(quant_op_node, scale_out_node)
if not self.is_test:
self._link_to(self.global_step, quant_op_node)
self._link_to(quant_op_node, scales_node)
if not self._is_test:
graph.link_to(self._global_step, quant_op_node)
graph.link_to(quant_op_node, scales_node)
return quant_var_node, scale_out_node
......@@ -298,21 +294,11 @@ class QuantizationTransformPass(object):
inputs={'X': var_node,
'Scale': scale_var_node},
outputs={'Out': dequant_var_node})
self._link_to(var_node, dequant_op_node)
self._link_to(scale_var_node, dequant_op_node)
self._link_to(dequant_op_node, dequant_var_node)
graph.link_to(var_node, dequant_op_node)
graph.link_to(scale_var_node, dequant_op_node)
graph.link_to(dequant_op_node, dequant_var_node)
return dequant_var_node
def _update_input(self, old_input_node, new_input_node, op_node):
old_input_node.outputs_remove(op_node)
op_node.inputs_remove(old_input_node)
new_input_node.outputs_append(op_node)
op_node.inputs_append(new_input_node)
def _link_to(self, node_in, node_out):
node_in.outputs_append(node_out)
node_out.inputs_append(node_in)
def _quantized_var_name(self, var_name):
"""
Return quantized variable name for the input `var_name`.
......@@ -330,25 +316,3 @@ class QuantizationTransformPass(object):
Return quantized variable name for the input `var_name`.
"""
return "%s.scale" % (var_name)
def _original_var_name(self, var_name):
"""
Return the original variable name.
"""
if var_name.endswith('.quantized.dequantized'):
return var_name[:-len('.quantized.dequantized')]
if var_name.endswith('.quantized'):
return var_name[:-len('.quantized')]
if var_name.endswith('.dequantized'):
return var_name[:-len('.dequantized')]
if var_name.endswith('.scale'):
return var_name[:-len('.scale')]
else:
return var_name
def _is_float(self, v):
return isinstance(v, float) or isinstance(v, np.float32)
def _quant(self, x, scale, num_bits):
y = np.round(x / scale * ((1 << (num_bits - 1)) - 1))
return y
......@@ -18,8 +18,8 @@ import numpy as np
import paddle.fluid as fluid
import six
from paddle.fluid.framework import Program
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.graph import PyGraph
from paddle.fluid import core
......@@ -106,7 +106,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
graph = PyGraph(core.Graph(main.desc), for_test=False)
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
program_exe=exe,
......@@ -119,7 +119,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
program = graph.to_program()
self.check_program(transform_pass, program)
val_graph = PyGraph(core.Graph(program.desc), for_test=False)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
val_marked_nodes = set()
for op in val_graph.all_ops():
if op.name().find('quantize') > -1:
......@@ -142,7 +142,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
exe = fluid.Executor(fluid.CPUPlace())
graph = PyGraph(core.Graph(main.desc), for_test=False)
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
program_exe=exe,
......@@ -155,7 +155,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)
program = graph.to_program()
self.check_program(transform_pass, program)
val_graph = PyGraph(core.Graph(program.desc), for_test=False)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
val_marked_nodes = set()
for op in val_graph.all_ops():
if op.name().find('quantize') > -1:
......
......@@ -23,6 +23,7 @@ import traceback
import six
import numpy as np
import subprocess
from .. import compat as cpt
from .proto import framework_pb2
......@@ -381,27 +382,6 @@ class Variable(object):
self._ivar.desc = self.desc
self._ivar.stop_gradient = stop_gradient
@staticmethod
def construct_from_desc(block, desc):
"""
Construct a Variable from variable desc.
Args:
desc(core.VarDesc): The variable desc for constructing.
Returns:
Variable: A variable.
"""
v = Variable(
block=block,
type=desc.type(),
name=desc.name(),
shape=desc.shape(),
dtype=desc.dtype(),
lod_level=desc.lod_level(),
persistable=desc.persistable())
v.desc = desc
return v
def _numpy(self):
tensor = self._ivar.value().get_tensor()
return np.array(tensor)
......@@ -1533,6 +1513,154 @@ class Block(object):
return ret_var
class IrGraph(object):
"""
IrGraph uses core.Graph as the delegation to accomplish the manipulation.
"""
def __init__(self, graph, for_test=False):
"""
Construct the IrGraph using core.Graph.
Args:
graph(core.Graph): C++ Graph.
for_test(bool): True for the test graph and false for the train graph.
"""
assert isinstance(
graph, core.Graph), 'graph must be the instance of core.Graph.'
self.graph = graph
self._for_test = for_test
def is_test(self):
return self._for_test
def all_parameters(self):
param_nodes = set()
for node in self.graph.nodes():
if node.is_var() and node.var() is not None and node.var(
).persistable():
param_nodes.add(node)
return param_nodes
def all_vars(self):
return {node for node in self.graph.nodes() if node.is_var()}
def all_ops(self):
return {node for node in self.graph.nodes() if node.is_op()}
def create_param_node(self, name, var_type, shape, var_dtype):
var_desc = core.VarDesc(name)
var_desc.set_type(var_type)
var_desc.set_shape(shape)
var_desc.set_dtype(var_dtype)
var_desc.set_persistable(True)
return self.graph.create_var_node(var_desc)
def create_var_node(self, name, var_type, shape, var_dtype):
var_desc = core.VarDesc(name)
var_desc.set_type(var_type)
var_desc.set_shape(shape)
var_desc.set_dtype(var_dtype)
return self.graph.create_var_node(var_desc)
def create_var_node_from_desc(self, var_desc):
return self.graph.create_var_node(var_desc)
def create_op_node(self, op_type, attrs, inputs, outputs):
op_desc = core.OpDesc()
op_desc.set_type(op_type)
for attr, value in attrs.iteritems():
self._update_desc_attr(op_desc, attr, value)
for input_name, var_nodes in inputs.iteritems():
if not isinstance(var_nodes, list):
var_nodes = [var_nodes]
op_desc.set_input(input_name,
[var_node.name() for var_node in var_nodes])
for output_name, var_nodes in outputs.iteritems():
if not isinstance(var_nodes, list):
var_nodes = [var_nodes]
op_desc.set_output(output_name,
[var_node.name() for var_node in var_nodes])
return self.graph.create_op_node(op_desc)
def create_op_node_from_desc(self, op_desc):
return self.graph.create_op_node(op_desc)
def update_input_link(self, old_input_node, new_input_node, op_node):
assert old_input_node in self.graph.nodes() and new_input_node in self.graph.nodes() and \
op_node in self.graph.nodes(), 'Th three arguments must be in the graph nodes.'
old_input_node.outputs_remove(op_node)
op_node.inputs_remove(old_input_node)
new_input_node.outputs_append(op_node)
op_node.inputs_append(new_input_node)
op_node.op()._rename_input(old_input_node.name(), new_input_node.name())
def link_to(self, node_in, node_out):
assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \
'Th two arguments must be in the graph nodes.'
node_in.outputs_append(node_out)
node_out.inputs_append(node_in)
def safe_remove_nodes(self, remove_nodes):
if not isinstance(remove_nodes, set):
remove_nodes = set(remove_nodes)
core.graph_safe_remove_nodes(self.graph, remove_nodes)
def draw(self, save_path, name, marked_nodes=None):
def _convert_to_pdf(dot_file_path):
pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
+ ' -o ' + pdf_save_path, shell=True)
if exited_code != 0:
print('The dot command is needed for creating pdf files.')
print('The {} is saved as the dot filetype.'.format(
dot_file_path))
remove_ctr_vars = set()
ops_num = 0
for node in self.graph.nodes():
if node.is_ctrl_var():
remove_ctr_vars.add(node)
elif node.is_op():
ops_num += 1
print('Total ops num = {}.'.format(ops_num))
self.safe_remove_nodes(remove_ctr_vars)
if marked_nodes is not None:
if not isinstance(marked_nodes, set):
marked_nodes = set(marked_nodes)
marked_nodes = marked_nodes - remove_ctr_vars
if self.graph.has('__graphviz__marked_node__'):
self.graph.erase('__graphviz__marked_node__')
self.graph.set('__graphviz__marked_node__', marked_nodes)
viz_dot_path = os.path.join(save_path, name) + '.dot'
viz_pass = core.get_pass('graph_viz_pass')
viz_pass.set('graph_viz_path', viz_dot_path)
viz_pass.apply(self.graph)
_convert_to_pdf(viz_dot_path)
def to_program(self):
convert_pass = core.get_pass('graph_to_program_pass')
convert_pass.set('program', Program().desc)
convert_pass.apply(self.graph)
desc = convert_pass.get_program('program')
program = Program._construct_from_desc(desc)
return program
def _update_desc_attr(self, desc, name, val):
"""
Update the value of desc's attribute by attribute's name.
"""
if isinstance(val, Block):
desc.set_block_attr(name, val.desc)
elif isinstance(val, list) and val and all(
isinstance(v, Block) for v in val):
desc.set_blocks_attr(name, [v.desc for v in val])
elif isinstance(val, core.BlockDesc) or \
isinstance(val, core.ProgramDesc):
desc.set_serialized_attr(name, val.serialize_to_string())
else:
desc._set_attr(name, val)
class Program(object):
"""
Python Program. Beneath it is a ProgramDesc, which is used for
......@@ -1958,12 +2086,10 @@ class Program(object):
return p
@staticmethod
def construct_from_desc(desc):
def _construct_from_desc(desc):
"""
Construct a program from program desc.
Notes: All information about parameters will be lost.
Args:
desc(core.ProgramDesc): The program desc for constructing.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册