Commit b913463e authored by WangZhen, committed by root

Update according to the reviewers' suggestions. test=develop

Parent 3ce61720
@@ -148,8 +148,8 @@ void BindNode(py::module *m) {
          })
      .def("outputs_append",
           [](Node &self, Node &node) { self.outputs.push_back(&node); })
-     .def_readonly("inputs", &Node::inputs)
-     .def_readonly("outputs", &Node::outputs);
+     .def_readwrite("inputs", &Node::inputs)
+     .def_readwrite("outputs", &Node::outputs);
  py::enum_<Node::Type>(node, "Type")
      .value("Operation", Node::Type::kOperation)
...
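With `inputs` and `outputs` now bound via `def_readwrite` rather than `def_readonly`, Python code can mutate a node's adjacency lists as well as read them. A hedged sketch (`var_node` and `op_node` are hypothetical `core.Node` handles, e.g. taken from a graph's node set):

```python
# var_node / op_node are assumed to come from graph.all_vars() / graph.all_ops().
for n in op_node.inputs:          # reading the list worked before and still does
    print(n.name())
var_node.outputs_append(op_node)  # wire an edge using the helpers bound above
op_node.inputs_append(var_node)
```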
@@ -797,17 +797,17 @@ All parameter, weight, gradient are variables in Paddle.
  py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
  pass.def(py::init())
      .def("has", &ir::Pass::Has)
-     .def("set_program",
+     .def("set",
           [](ir::Pass &self, const std::string &attr_name,
              const ProgramDesc &attr) {
             return self.Set(attr_name, new ProgramDesc(attr));
           })
      .def(
-         "set_str",
+         "set",
          [](ir::Pass &self, const std::string &name, const std::string &attr) {
            self.Set<std::string>(name, new std::string(attr));
          })
-     .def("set_int", [](ir::Pass &self, const std::string &name,
+     .def("set", [](ir::Pass &self, const std::string &name,
                         int val) { self.Set<const int>(name, new int(val)); })
      .def("get_program", &ir::Pass::Get<ProgramDesc>)
      .def("type", &ir::Pass::Type)
...
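The three typed setters `set_program`, `set_str`, and `set_int` collapse into a single overloaded `set`; pybind11 dispatches on the Python argument's type. A short sketch mirroring the calls that appear later in this commit:

```python
from paddle.fluid import core
from paddle.fluid.framework import Program

viz_pass = core.get_pass('graph_viz_pass')
viz_pass.set('graph_viz_path', '/tmp/graph.dot')  # picks the std::string overload

convert_pass = core.get_pass('graph_to_program_pass')
convert_pass.set('program', Program().desc)       # picks the ProgramDesc overload
```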
@@ -18,140 +18,7 @@ from ....framework import Program
 from ....framework import Block
 from .... import core
 
-__all__ = ['Graph', 'ImitationGraph', 'IRGraph', 'PyGraph']
+__all__ = ['Graph', 'ImitationGraph', 'IRGraph']
 
 
-class PyGraph(object):
-    """
-    PyGraph uses core.Graph as the delegation to accomplish the manipulation.
-    """
-
-    def __init__(self, graph, for_test=False):
-        """
-        Construct the PyGraph using core.Graph.
-        Args:
-            graph(core.Graph): C++ Graph.
-            for_test(bool): True for the test graph and false for the train graph.
-        """
-        assert isinstance(
-            graph, core.Graph), 'graph must be the instance of core.Graph.'
-        self.graph = graph
-        self.for_test = for_test
-
-    def is_test(self):
-        return self.for_test
-
-    def all_parameters(self):
-        param_nodes = set()
-        for node in self.graph.nodes():
-            if node.is_var() and node.var() is not None and node.var(
-            ).persistable():
-                param_nodes.add(node)
-        return param_nodes
-
-    def all_vars(self):
-        return {node for node in self.graph.nodes() if node.is_var()}
-
-    def all_ops(self):
-        return {node for node in self.graph.nodes() if node.is_op()}
-
-    def create_param_node(self, name, var_type, shape, var_dtype):
-        var_desc = core.VarDesc(name)
-        var_desc.set_type(var_type)
-        var_desc.set_shape(shape)
-        var_desc.set_dtype(var_dtype)
-        var_desc.set_persistable(True)
-        return self.graph.create_var_node(var_desc)
-
-    def create_var_node(self, name, var_type, shape, var_dtype):
-        var_desc = core.VarDesc(name)
-        var_desc.set_type(var_type)
-        var_desc.set_shape(shape)
-        var_desc.set_dtype(var_dtype)
-        return self.graph.create_var_node(var_desc)
-
-    def create_var_node_from_desc(self, var_desc):
-        return self.graph.create_var_node(var_desc)
-
-    def create_op_node(self, op_type, attrs, inputs, outputs):
-        op_desc = core.OpDesc()
-        op_desc.set_type(op_type)
-        for attr, value in attrs.iteritems():
-            self._update_desc_attr(op_desc, attr, value)
-        for input_name, var_nodes in inputs.iteritems():
-            if not isinstance(var_nodes, list):
-                var_nodes = [var_nodes]
-            op_desc.set_input(input_name,
-                              [var_node.name() for var_node in var_nodes])
-        for output_name, var_nodes in outputs.iteritems():
-            if not isinstance(var_nodes, list):
-                var_nodes = [var_nodes]
-            op_desc.set_output(output_name,
-                               [var_node.name() for var_node in var_nodes])
-        return self.graph.create_op_node(op_desc)
-
-    def create_op_node_from_desc(self, op_desc):
-        return self.graph.create_op_node(op_desc)
-
-    def _update_desc_attr(self, desc, name, val):
-        """
-        Update the value of desc's attribute by attribute's name.
-        """
-        if isinstance(val, Block):
-            desc.set_block_attr(name, val.desc)
-        elif isinstance(val, list) and val and all(
-                isinstance(v, Block) for v in val):
-            desc.set_blocks_attr(name, [v.desc for v in val])
-        elif isinstance(val, core.BlockDesc) or \
-                isinstance(val, core.ProgramDesc):
-            desc.set_serialized_attr(name, val.serialize_to_string())
-        else:
-            desc._set_attr(name, val)
-
-    def safe_remove_nodes(self, remove_nodes):
-        if not isinstance(remove_nodes, set):
-            remove_nodes = set(remove_nodes)
-        core.graph_safe_remove_nodes(self.graph, remove_nodes)
-
-    def draw(self, save_path, name, marked_nodes=None):
-        def _convert_to_pdf(dot_file_path):
-            pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
-            exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
-                                          + ' -o ' + pdf_save_path, shell=True)
-            if exited_code != 0:
-                print('The dot command is needed for creating pdf files.')
-                print('The {} is saved as the dot filetype.'.format(
-                    dot_file_path))
-
-        remove_ctr_vars = set()
-        ops_num = 0
-        for node in self.graph.nodes():
-            if node.is_ctrl_var():
-                remove_ctr_vars.add(node)
-            elif node.is_op():
-                ops_num += 1
-        print('Total ops num = {}.'.format(ops_num))
-        self.safe_remove_nodes(remove_ctr_vars)
-
-        if marked_nodes is not None:
-            if not isinstance(marked_nodes, set):
-                marked_nodes = set(marked_nodes)
-            marked_nodes = marked_nodes - remove_ctr_vars
-            if self.graph.has('__graphviz__marked_node__'):
-                self.graph.erase('__graphviz__marked_node__')
-            self.graph.set('__graphviz__marked_node__', marked_nodes)
-        viz_dot_path = os.path.join(save_path, name) + '.dot'
-        viz_pass = core.get_pass('graph_viz_pass')
-        viz_pass.set_str('graph_viz_path', viz_dot_path)
-        viz_pass.apply(self.graph)
-        _convert_to_pdf(viz_dot_path)
-
-    def to_program(self):
-        convert_pass = core.get_pass('graph_to_program_pass')
-        convert_pass.set_program('program', Program().desc)
-        convert_pass.apply(self.graph)
-        desc = convert_pass.get_program('program')
-        program = Program.construct_from_desc(desc)
-        return program
-
-
 class Graph(object):
...
@@ -13,13 +13,12 @@
 # limitations under the License.
 
 import collections
-import numpy as np
 from .... import core
+from ....framework import IrGraph
 from ....framework import Program
 from ....framework import Variable
 from ....initializer import Constant
 from .... import unique_name
-from ..graph import PyGraph
 
 __all__ = ['QuantizationTransformPass']
@@ -34,7 +33,7 @@ class QuantizationTransformPass(object):
                  weight_quantize_type='abs_max',
                  window_size=10000):
         """
-        Convert and rewrite the PyGraph according to weight and
+        Convert and rewrite the IrGraph according to weight and
         activation quantization type.
         Args:
             weight_bits (int): quantization bit number for weights,
@@ -56,19 +55,19 @@ class QuantizationTransformPass(object):
             import paddle.fluid as fluid
             from paddle.fluid.contrib.slim.quantization \
                 import QuantizationTransformPass
-            from paddle.fluid.contrib.slim.graph import PyGraph
+            from paddle.fluid.contrib.slim.graph import IrGraph
             from paddle.fluid import core
 
-            graph = PyGraph(core.Graph(program.desc), for_test=False)
+            graph = IrGraph(core.Graph(program.desc), for_test=False)
             exe = fluid.Executor(fluid.CPUPlace())
             transform_pass = QuantizationTransformPass(fluid.global_scope(),
                 exe)
             transform_pass.apply(graph)
         """
-        self.scope = scope
-        self.program_exe = program_exe
-        self.weight_bits = weight_bits
-        self.activation_bits = activation_bits
+        self._scope = scope
+        self._program_exe = program_exe
+        self._weight_bits = weight_bits
+        self._activation_bits = activation_bits
 
         quant_type = ['abs_max', 'range_abs_max']
         if activation_quantize_type not in quant_type:
@@ -80,27 +79,27 @@ class QuantizationTransformPass(object):
                 "Unknown weight_quantize_type: '%s'. It can only be ",
                 "'abs_max' or 'range_abs_max'.", str(weight_quantize_type))
 
-        self.activation_quantize_type = activation_quantize_type
-        self.weight_quantize_type = weight_quantize_type
-        self.window_size = window_size
+        self._activation_quantize_type = activation_quantize_type
+        self._weight_quantize_type = weight_quantize_type
+        self._window_size = window_size
 
-        self.need_initialized = collections.OrderedDict()
-        self.quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
-        self.quantizable_grad_ops = [
-            '%s_grad' % (op) for op in self.quantizable_ops
+        self._need_initialized = collections.OrderedDict()
+        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
+        self._quantizable_grad_ops = [
+            '%s_grad' % (op) for op in self._quantizable_ops
         ]
-        self.fake_quant_op_types = [
+        self._fake_quant_op_types = [
            'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
        ]
-        self.fake_dequant_op_types = ['fake_dequantize_max_abs']
-        self.is_test = None
-        self.global_step = None
+        self._fake_dequant_op_types = ['fake_dequantize_max_abs']
+        self._is_test = None
+        self._global_step = None
 
     def apply(self, graph):
         assert isinstance(graph,
-                          PyGraph), 'graph must be the instance of PyGraph.'
-        self.need_initialized.clear()
-        self.is_test = graph.is_test()
+                          IrGraph), 'graph must be the instance of IrGraph.'
+        self._need_initialized.clear()
+        self._is_test = graph.is_test()
         # marked the variable which has been dequantized.
         dequantized_vars = collections.OrderedDict()
         params = [p.name() for p in graph.all_parameters()]
@@ -110,72 +109,69 @@ class QuantizationTransformPass(object):
                 if var_node.name() in dequantized_vars:
                     dequant_var_node = dequantized_vars[var_node.name()]
                 else:
-                    quant_bits = self.weight_bits if var_node.name() in params \
-                        else self.activation_bits
-                    quant_type = self.weight_quantize_type if var_node.name() \
-                        in params else self.activation_quantize_type
+                    quant_bits = self._weight_bits if var_node.name() in params \
+                        else self._activation_bits
+                    quant_type = self._weight_quantize_type if var_node.name() \
+                        in params else self._activation_quantize_type
                     quant_var_node, scale_var_node = self._insert_quant_op(
                         graph, var_node, quant_bits, quant_type)
                     dequant_var_node = self._insert_dequant_op(
                         graph, quant_var_node, scale_var_node, quant_bits)
                     dequantized_vars[var_node.name()] = dequant_var_node
-                self._update_input(var_node, dequant_var_node, op)
-                op.op()._rename_input(var_node.name(), dequant_var_node.name())
+                graph.update_input_link(var_node, dequant_var_node, op)
 
         def _transform_backward(graph, op):
             no_dequanted_input_vars = True
             for var_node in op.inputs:
                 if var_node.name() in dequantized_vars:
                     dequant_var_node = dequantized_vars[var_node.name()]
-                    self._update_input(var_node, dequant_var_node, op)
-                    op.op()._rename_input(var_node.name(),
-                                          dequant_var_node.name())
+                    graph.update_input_link(var_node, dequant_var_node, op)
                     no_dequanted_input_vars = False
             if no_dequanted_input_vars:
                 raise ValueError("There is no dequanted inputs for op %s." %
                                  (op.name()))
 
-        if not self.is_test:
+        if not self._is_test:
             self._create_global_step(graph)
         ops = graph.all_ops()
         # The process of _transform_forward and _transform_backward is needed in two for loops.
         # The loop for transforming the forward graph:
         for op in ops:
-            if op.name() in self.quantizable_ops:
+            if op.name() in self._quantizable_ops:
                 _transform_forward(graph, op)
         # The loop for renaming the inputs of backward op.
         for op in ops:
-            if op.name() in self.quantizable_grad_ops:
+            if op.name() in self._quantizable_grad_ops:
                 _transform_backward(graph, op)
 
-        if len(self.need_initialized) > 0:
-            assert self.scope is not None, \
+        if len(self._need_initialized) > 0:
+            assert self._scope is not None, \
                 'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
-            assert self.program_exe is not None, \
+            assert self._program_exe is not None, \
                 'The program_exe cannot be set None when activation_quantize_type equals to range_abs_max.'
             init_program = Program()
-            for var_desc, initializer in self.need_initialized.iteritems():
-                var = Variable.construct_from_desc(init_program.global_block(),
-                                                   var_desc)
+            for var_desc, initializer in self._need_initialized.iteritems():
+                var = Variable(init_program.global_block())
+                var._set_desc(var_desc)
                 initializer(var, init_program.global_block())
-            self.program_exe.run(program=init_program, scope=self.scope)
+            self._program_exe.run(program=init_program, scope=self._scope)
 
         return graph
 
     def _create_global_step(self, graph):
-        if self.weight_quantize_type == 'range_abs_max' or \
-                self.activation_quantize_type == 'range_abs_max':
+        if self._weight_quantize_type == 'range_abs_max' or \
+                self._activation_quantize_type == 'range_abs_max':
             counter_name = '@STEP_COUNTER@'
             for node in graph.all_vars():
                 if node.name() == counter_name:
-                    self.global_step = node
-            if self.global_step is None:
+                    self._global_step = node
+            if self._global_step is None:
                 global_step_in = graph.create_param_node(
                     name=counter_name,
                     var_type=core.VarDesc.VarType.LOD_TENSOR,
                    shape=[1],
                    var_dtype=core.VarDesc.VarType.INT64)
-                self.need_initialized[global_step_in.var()] = \
+                self._need_initialized[global_step_in.var()] = \
                     Constant(value=0, force_cpu=True)
                 global_step_out = graph.create_var_node_from_desc(
                     global_step_in.var())
@@ -184,9 +180,9 @@ class QuantizationTransformPass(object):
                 attrs={'step': 1.0},
                 inputs={'X': global_step_in},
                 outputs={'Out': global_step_out})
-            self._link_to(global_step_in, increment_op)
-            self._link_to(increment_op, global_step_out)
-            self.global_step = global_step_out
+            graph.link_to(global_step_in, increment_op)
+            graph.link_to(increment_op, global_step_out)
+            self._global_step = global_step_out
 
     def _insert_quant_op(self, graph, var_node, quant_bits, quant_type):
         """
@@ -220,9 +216,9 @@ class QuantizationTransformPass(object):
             inputs={'X': var_node},
             outputs={'Out': quant_var_node,
                      'OutScale': scale_var_node})
-        self._link_to(var_node, quant_op_node)
-        self._link_to(quant_op_node, quant_var_node)
-        self._link_to(quant_op_node, scale_var_node)
+        graph.link_to(var_node, quant_op_node)
+        graph.link_to(quant_op_node, quant_var_node)
+        graph.link_to(quant_op_node, scale_var_node)
         return quant_var_node, scale_var_node
 
     def _insert_quant_range_abs_max_op(self, graph, var_node, quant_bits):
@@ -242,26 +238,26 @@ class QuantizationTransformPass(object):
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
             var_dtype=var_node.var().dtype())
-        self.need_initialized[scale_in_node.var()] = Constant(value=0.001)
+        self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
         scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
         inputs = {'X': var_node, 'InScale': scale_in_node}
         outputs = {'Out': quant_var_node, 'OutScale': scale_out_node}
-        if not self.is_test:
+        if not self._is_test:
             # The name of scales_var_node maybe 'scales_0', 'scales_1', etc.
             scales_node = graph.create_param_node(
                 name=unique_name.generate('scales'),
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
-                shape=[self.window_size],
+                shape=[self._window_size],
                 var_dtype=var_node.var().dtype())
-            self.need_initialized[scales_node.var()] = Constant(value=0)
-            inputs['Iter'] = self.global_step
+            self._need_initialized[scales_node.var()] = Constant(value=0)
+            inputs['Iter'] = self._global_step
             outputs['OutScales'] = scales_node
         attrs = {
-            'window_size': self.window_size,
+            'window_size': self._window_size,
             'bit_length': quant_bits,
-            'is_test': self.is_test
+            'is_test': self._is_test
         }
         quant_op_node = graph.create_op_node(
             op_type='fake_quantize_range_abs_max',
@@ -269,14 +265,14 @@ class QuantizationTransformPass(object):
             inputs=inputs,
             outputs=outputs)
 
-        self._link_to(var_node, quant_op_node)
-        self._link_to(scale_in_node, quant_op_node)
-        self._link_to(quant_op_node, quant_var_node)
-        self._link_to(quant_op_node, scale_out_node)
+        graph.link_to(var_node, quant_op_node)
+        graph.link_to(scale_in_node, quant_op_node)
+        graph.link_to(quant_op_node, quant_var_node)
+        graph.link_to(quant_op_node, scale_out_node)
 
-        if not self.is_test:
-            self._link_to(self.global_step, quant_op_node)
-            self._link_to(quant_op_node, scales_node)
+        if not self._is_test:
+            graph.link_to(self._global_step, quant_op_node)
+            graph.link_to(quant_op_node, scales_node)
 
         return quant_var_node, scale_out_node
 
@@ -298,21 +294,11 @@ class QuantizationTransformPass(object):
             inputs={'X': var_node,
                     'Scale': scale_var_node},
             outputs={'Out': dequant_var_node})
-        self._link_to(var_node, dequant_op_node)
-        self._link_to(scale_var_node, dequant_op_node)
-        self._link_to(dequant_op_node, dequant_var_node)
+        graph.link_to(var_node, dequant_op_node)
+        graph.link_to(scale_var_node, dequant_op_node)
+        graph.link_to(dequant_op_node, dequant_var_node)
        return dequant_var_node
 
-    def _update_input(self, old_input_node, new_input_node, op_node):
-        old_input_node.outputs_remove(op_node)
-        op_node.inputs_remove(old_input_node)
-        new_input_node.outputs_append(op_node)
-        op_node.inputs_append(new_input_node)
-
-    def _link_to(self, node_in, node_out):
-        node_in.outputs_append(node_out)
-        node_out.inputs_append(node_in)
-
     def _quantized_var_name(self, var_name):
         """
         Return quantized variable name for the input `var_name`.
@@ -330,25 +316,3 @@ class QuantizationTransformPass(object):
         Return quantized variable name for the input `var_name`.
         """
         return "%s.scale" % (var_name)
-
-    def _original_var_name(self, var_name):
-        """
-        Return the original variable name.
-        """
-        if var_name.endswith('.quantized.dequantized'):
-            return var_name[:-len('.quantized.dequantized')]
-        if var_name.endswith('.quantized'):
-            return var_name[:-len('.quantized')]
-        if var_name.endswith('.dequantized'):
-            return var_name[:-len('.dequantized')]
-        if var_name.endswith('.scale'):
-            return var_name[:-len('.scale')]
-        else:
-            return var_name
-
-    def _is_float(self, v):
-        return isinstance(v, float) or isinstance(v, np.float32)
-
-    def _quant(self, x, scale, num_bits):
-        y = np.round(x / scale * ((1 << (num_bits - 1)) - 1))
-        return y
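Putting the renamed pieces together: a minimal end-to-end sketch assembled from the docstring above and the test below. It assumes `program` is a `fluid.Program` with a network and loss already built on it:

```python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

graph = IrGraph(core.Graph(program.desc), for_test=False)  # wrap the C++ graph
exe = fluid.Executor(fluid.CPUPlace())
transform_pass = QuantizationTransformPass(
    scope=fluid.global_scope(), program_exe=exe)
transform_pass.apply(graph)                 # insert fake quant/dequant ops
quantized_program = graph.to_program()      # convert the graph back to a Program
```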
@@ -18,8 +18,8 @@ import numpy as np
 import paddle.fluid as fluid
 import six
 from paddle.fluid.framework import Program
+from paddle.fluid.framework import IrGraph
 from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
-from paddle.fluid.contrib.slim.graph import PyGraph
 from paddle.fluid import core
@@ -106,7 +106,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
             opt = fluid.optimizer.Adam(learning_rate=0.001)
             opt.minimize(loss)
         exe = fluid.Executor(fluid.CPUPlace())
-        graph = PyGraph(core.Graph(main.desc), for_test=False)
+        graph = IrGraph(core.Graph(main.desc), for_test=False)
         transform_pass = QuantizationTransformPass(
             scope=fluid.global_scope(),
             program_exe=exe,
@@ -119,7 +119,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
         graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
         program = graph.to_program()
         self.check_program(transform_pass, program)
-        val_graph = PyGraph(core.Graph(program.desc), for_test=False)
+        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
         val_marked_nodes = set()
         for op in val_graph.all_ops():
             if op.name().find('quantize') > -1:
@@ -142,7 +142,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
             opt = fluid.optimizer.Adam(learning_rate=0.001)
             opt.minimize(loss)
         exe = fluid.Executor(fluid.CPUPlace())
-        graph = PyGraph(core.Graph(main.desc), for_test=False)
+        graph = IrGraph(core.Graph(main.desc), for_test=False)
         transform_pass = QuantizationTransformPass(
             scope=fluid.global_scope(),
             program_exe=exe,
@@ -155,7 +155,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
         graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)
         program = graph.to_program()
         self.check_program(transform_pass, program)
-        val_graph = PyGraph(core.Graph(program.desc), for_test=False)
+        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
         val_marked_nodes = set()
         for op in val_graph.all_ops():
             if op.name().find('quantize') > -1:
...
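The marked-nodes inspection pattern the tests rely on, as a standalone sketch (`graph` and `quant_type` as in the tests; the PDF step additionally needs the graphviz `dot` binary on the PATH):

```python
# Collect every op whose type mentions 'quantize', then render the graph with
# those nodes highlighted; draw() writes a .dot file and tries to convert it.
marked_nodes = set()
for op in graph.all_ops():
    if op.name().find('quantize') > -1:
        marked_nodes.add(op)
graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
```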
@@ -23,6 +23,7 @@ import traceback
 import six
 import numpy as np
+import subprocess
 
 from .. import compat as cpt
 from .proto import framework_pb2
@@ -381,27 +382,6 @@ class Variable(object):
             self._ivar.desc = self.desc
             self._ivar.stop_gradient = stop_gradient
 
-    @staticmethod
-    def construct_from_desc(block, desc):
-        """
-        Construct a Variable from variable desc.
-        Args:
-            desc(core.VarDesc): The variable desc for constructing.
-        Returns:
-            Variable: A variable.
-        """
-        v = Variable(
-            block=block,
-            type=desc.type(),
-            name=desc.name(),
-            shape=desc.shape(),
-            dtype=desc.dtype(),
-            lod_level=desc.lod_level(),
-            persistable=desc.persistable())
-        v.desc = desc
-        return v
-
     def _numpy(self):
         tensor = self._ivar.value().get_tensor()
         return np.array(tensor)
@@ -1533,6 +1513,154 @@ class Block(object):
         return ret_var
 
+
+class IrGraph(object):
+    """
+    IrGraph uses core.Graph as the delegation to accomplish the manipulation.
+    """
+
+    def __init__(self, graph, for_test=False):
+        """
+        Construct the IrGraph using core.Graph.
+        Args:
+            graph(core.Graph): C++ Graph.
+            for_test(bool): True for the test graph and false for the train graph.
+        """
+        assert isinstance(
+            graph, core.Graph), 'graph must be the instance of core.Graph.'
+        self.graph = graph
+        self._for_test = for_test
+
+    def is_test(self):
+        return self._for_test
+
+    def all_parameters(self):
+        param_nodes = set()
+        for node in self.graph.nodes():
+            if node.is_var() and node.var() is not None and node.var(
+            ).persistable():
+                param_nodes.add(node)
+        return param_nodes
+
+    def all_vars(self):
+        return {node for node in self.graph.nodes() if node.is_var()}
+
+    def all_ops(self):
+        return {node for node in self.graph.nodes() if node.is_op()}
+
+    def create_param_node(self, name, var_type, shape, var_dtype):
+        var_desc = core.VarDesc(name)
+        var_desc.set_type(var_type)
+        var_desc.set_shape(shape)
+        var_desc.set_dtype(var_dtype)
+        var_desc.set_persistable(True)
+        return self.graph.create_var_node(var_desc)
+
+    def create_var_node(self, name, var_type, shape, var_dtype):
+        var_desc = core.VarDesc(name)
+        var_desc.set_type(var_type)
+        var_desc.set_shape(shape)
+        var_desc.set_dtype(var_dtype)
+        return self.graph.create_var_node(var_desc)
+
+    def create_var_node_from_desc(self, var_desc):
+        return self.graph.create_var_node(var_desc)
+
+    def create_op_node(self, op_type, attrs, inputs, outputs):
+        op_desc = core.OpDesc()
+        op_desc.set_type(op_type)
+        for attr, value in attrs.iteritems():
+            self._update_desc_attr(op_desc, attr, value)
+        for input_name, var_nodes in inputs.iteritems():
+            if not isinstance(var_nodes, list):
+                var_nodes = [var_nodes]
+            op_desc.set_input(input_name,
+                              [var_node.name() for var_node in var_nodes])
+        for output_name, var_nodes in outputs.iteritems():
+            if not isinstance(var_nodes, list):
+                var_nodes = [var_nodes]
+            op_desc.set_output(output_name,
+                               [var_node.name() for var_node in var_nodes])
+        return self.graph.create_op_node(op_desc)
+
+    def create_op_node_from_desc(self, op_desc):
+        return self.graph.create_op_node(op_desc)
+    def update_input_link(self, old_input_node, new_input_node, op_node):
+        assert old_input_node in self.graph.nodes() and new_input_node in self.graph.nodes() and \
+            op_node in self.graph.nodes(), 'The three arguments must be in the graph nodes.'
+        old_input_node.outputs_remove(op_node)
+        op_node.inputs_remove(old_input_node)
+        new_input_node.outputs_append(op_node)
+        op_node.inputs_append(new_input_node)
+        op_node.op()._rename_input(old_input_node.name(), new_input_node.name())
+
+    def link_to(self, node_in, node_out):
+        assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \
+            'The two arguments must be in the graph nodes.'
+        node_in.outputs_append(node_out)
+        node_out.inputs_append(node_in)
+    def safe_remove_nodes(self, remove_nodes):
+        if not isinstance(remove_nodes, set):
+            remove_nodes = set(remove_nodes)
+        core.graph_safe_remove_nodes(self.graph, remove_nodes)
+
+    def draw(self, save_path, name, marked_nodes=None):
+        def _convert_to_pdf(dot_file_path):
+            pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
+            exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
+                                          + ' -o ' + pdf_save_path, shell=True)
+            if exited_code != 0:
+                print('The dot command is needed for creating pdf files.')
+                print('The {} is saved as the dot filetype.'.format(
+                    dot_file_path))
+
+        remove_ctr_vars = set()
+        ops_num = 0
+        for node in self.graph.nodes():
+            if node.is_ctrl_var():
+                remove_ctr_vars.add(node)
+            elif node.is_op():
+                ops_num += 1
+        print('Total ops num = {}.'.format(ops_num))
+        self.safe_remove_nodes(remove_ctr_vars)
+
+        if marked_nodes is not None:
+            if not isinstance(marked_nodes, set):
+                marked_nodes = set(marked_nodes)
+            marked_nodes = marked_nodes - remove_ctr_vars
+            if self.graph.has('__graphviz__marked_node__'):
+                self.graph.erase('__graphviz__marked_node__')
+            self.graph.set('__graphviz__marked_node__', marked_nodes)
+        viz_dot_path = os.path.join(save_path, name) + '.dot'
+        viz_pass = core.get_pass('graph_viz_pass')
+        viz_pass.set('graph_viz_path', viz_dot_path)
+        viz_pass.apply(self.graph)
+        _convert_to_pdf(viz_dot_path)
+
+    def to_program(self):
+        convert_pass = core.get_pass('graph_to_program_pass')
+        convert_pass.set('program', Program().desc)
+        convert_pass.apply(self.graph)
+        desc = convert_pass.get_program('program')
+        program = Program._construct_from_desc(desc)
+        return program
+
+    def _update_desc_attr(self, desc, name, val):
+        """
+        Update the value of desc's attribute by attribute's name.
+        """
+        if isinstance(val, Block):
+            desc.set_block_attr(name, val.desc)
+        elif isinstance(val, list) and val and all(
+                isinstance(v, Block) for v in val):
+            desc.set_blocks_attr(name, [v.desc for v in val])
+        elif isinstance(val, core.BlockDesc) or \
+                isinstance(val, core.ProgramDesc):
+            desc.set_serialized_attr(name, val.serialize_to_string())
+        else:
+            desc._set_attr(name, val)
+
+
 class Program(object):
     """
     Python Program. Beneath it is a ProgramDesc, which is used for
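For reference, a hedged sketch of how the new `IrGraph` editing helpers compose. The variable names and the choice of a `scale` op are illustrative only (not part of this commit); the quantization pass above uses the same create-then-link pattern:

```python
from paddle.fluid import core
from paddle.fluid.framework import IrGraph

graph = IrGraph(core.Graph(program.desc), for_test=True)  # `program` assumed given
x = graph.create_var_node(
    name='x', var_type=core.VarDesc.VarType.LOD_TENSOR,
    shape=[1], var_dtype=core.VarDesc.VarType.FP32)
y = graph.create_var_node(
    name='x.scaled', var_type=core.VarDesc.VarType.LOD_TENSOR,
    shape=[1], var_dtype=core.VarDesc.VarType.FP32)
scale_op = graph.create_op_node(
    op_type='scale', attrs={'scale': 2.0},
    inputs={'X': x}, outputs={'Out': y})
graph.link_to(x, scale_op)   # create_op_node fills the descs; link_to wires edges
graph.link_to(scale_op, y)
```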
@@ -1958,12 +2086,10 @@ class Program(object):
         return p
 
     @staticmethod
-    def construct_from_desc(desc):
+    def _construct_from_desc(desc):
         """
         Construct a program from program desc.
 
-        Notes: All information about parameters will be lost.
-
         Args:
             desc(core.ProgramDesc): The program desc for constructing.
...