提交 a7efab7e 编写于 作者: W WangZhen

add comments for public API. test=develop

上级 0db41a9c
...@@ -39,7 +39,13 @@ class QuantizationTransformPass(object): ...@@ -39,7 +39,13 @@ class QuantizationTransformPass(object):
""" """
Convert and rewrite the IrGraph according to weight and Convert and rewrite the IrGraph according to weight and
activation quantization type. activation quantization type.
Args: Args:
scope(fluid.Scope): When activation use 'range_abs_max' as the quantize
type, this pass will create some new parameters. The scope is used to
initialize these new parameters.
program_exe(fluid.Executor): program_exe is used to initialize new
parameters described above.
weight_bits (int): quantization bit number for weights, weight_bits (int): quantization bit number for weights,
the bias is not quantized. the bias is not quantized.
activation_bits (int): quantization bit number for activation. activation_bits (int): quantization bit number for activation.
...@@ -53,6 +59,7 @@ class QuantizationTransformPass(object): ...@@ -53,6 +59,7 @@ class QuantizationTransformPass(object):
support 'abs_max'. The 'range_abs_max' usually is not used for support 'abs_max'. The 'range_abs_max' usually is not used for
weight, since weights are fixed once the model is well trained. weight, since weights are fixed once the model is well trained.
window_size (int): the window size for 'range_abs_max' quantization. window_size (int): the window size for 'range_abs_max' quantization.
Examples: Examples:
.. code-block:: python .. code-block:: python
# The original graph will be rewrite. # The original graph will be rewrite.
...@@ -96,6 +103,14 @@ class QuantizationTransformPass(object): ...@@ -96,6 +103,14 @@ class QuantizationTransformPass(object):
self._global_step = None self._global_step = None
def apply(self, graph): def apply(self, graph):
"""
Quantize the graph for training process. According to weight and
activation quantization type, the graph will be added some fake
quantize operators and fake dequantize operators.
Args:
graph(IrGraph): the applied graph.
"""
assert isinstance(graph, assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.' IrGraph), 'graph must be the instance of IrGraph.'
self._need_initialized.clear() self._need_initialized.clear()
...@@ -336,6 +351,23 @@ class QuantizationTransformPass(object): ...@@ -336,6 +351,23 @@ class QuantizationTransformPass(object):
class QuantizationFreezePass(object): class QuantizationFreezePass(object):
"""
The freeze pass is used to adjust the quantize operator order, for example:
1) `activation -> quant -> dequant -> conv2d` will be freezed into
`activation -> quant -> conv2d -> dequant`
2) `weight -> quant -> dequant -> conv2d` will be freezed into `weight -> conv2d`,
and weight will be sacled offline.
Args:
scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors.
weight_bits (int): quantization bit number for weights.
activation_bits (int): quantization bit number for activation.
weight_quantize_type (str): quantization type for weights, support 'abs_max'.
The 'range_abs_max' usually is not used for weight, since weights are fixed once the
model is well trained.
"""
def __init__(self, def __init__(self,
scope, scope,
place, place,
...@@ -361,6 +393,12 @@ class QuantizationFreezePass(object): ...@@ -361,6 +393,12 @@ class QuantizationFreezePass(object):
self._var_scale_map = collections.OrderedDict() self._var_scale_map = collections.OrderedDict()
def apply(self, graph): def apply(self, graph):
"""
Adjust quantize/dequantize operators order for the inference process.
Args:
graph(IrGraph): the applied graph.
"""
persistable_vars = [p.name() for p in graph.all_persistable_vars()] persistable_vars = [p.name() for p in graph.all_persistable_vars()]
ops = graph.all_ops() ops = graph.all_ops()
for op_node in ops: for op_node in ops:
...@@ -518,6 +556,15 @@ class QuantizationFreezePass(object): ...@@ -518,6 +556,15 @@ class QuantizationFreezePass(object):
class ConvertToInt8Pass(object): class ConvertToInt8Pass(object):
"""
Convert the weights into int8_t type.
Args:
scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
8bits weight tensors.
"""
def __init__(self, scope, place): def __init__(self, scope, place):
assert scope is not None, \ assert scope is not None, \
'The scope cannot be set None.' 'The scope cannot be set None.'
...@@ -528,6 +575,13 @@ class ConvertToInt8Pass(object): ...@@ -528,6 +575,13 @@ class ConvertToInt8Pass(object):
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
def apply(self, graph): def apply(self, graph):
"""
Convert weights' tpye of the graph. After that, the data type of the
graph weigths is int8_t.
Args:
graph(IrGraph): the applied graph.
"""
persistable_vars = [p.name() for p in graph.all_persistable_vars()] persistable_vars = [p.name() for p in graph.all_persistable_vars()]
ops = graph.all_ops() ops = graph.all_ops()
input_map = {} input_map = {}
...@@ -581,6 +635,10 @@ class ConvertToInt8Pass(object): ...@@ -581,6 +635,10 @@ class ConvertToInt8Pass(object):
class TransformForMobilePass(object): class TransformForMobilePass(object):
"""
This pass is used to convert the freezed graph for paddle-mobile execution.
"""
def __init__(self): def __init__(self):
self._fake_quant_op_names = [ self._fake_quant_op_names = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max' 'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
...@@ -588,6 +646,14 @@ class TransformForMobilePass(object): ...@@ -588,6 +646,14 @@ class TransformForMobilePass(object):
self._fake_dequant_op_names = ['fake_dequantize_max_abs'] self._fake_dequant_op_names = ['fake_dequantize_max_abs']
def apply(self, graph): def apply(self, graph):
"""
Because paddle-mobile use `quantize` an `dequantize` as the names of
quantize operator and dequantize operator, the `apply` function just
realize this logic.
Args:
graph(IrGraph): the graph will be transformed.
"""
ops = graph.all_ops() ops = graph.all_ops()
for op_node in ops: for op_node in ops:
name = op_node.name() name = op_node.name()
......
...@@ -248,8 +248,8 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -248,8 +248,8 @@ class TestQuantizationFreezePass(unittest.TestCase):
quantized_main_program = main_graph.to_program() quantized_main_program = main_graph.to_program()
quantized_test_program = test_graph.to_program() quantized_test_program = test_graph.to_program()
iters = 10 iters = 5
batch_size = 128 batch_size = 16
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
main_program=quantized_main_program, main_program=quantized_main_program,
...@@ -271,7 +271,7 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -271,7 +271,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
# fetch_list=[loss]) # fetch_list=[loss])
loss_v = train_exe.run(feed=feeder.feed(data), loss_v = train_exe.run(feed=feeder.feed(data),
fetch_list=[loss.name]) fetch_list=[loss.name])
print('{}: {}'.format('loss' + dev_name + quant_type, loss_v)) #print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
test_data = next(test_reader()) test_data = next(test_reader())
with fluid.program_guard(quantized_test_program): with fluid.program_guard(quantized_test_program):
...@@ -299,15 +299,15 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -299,15 +299,15 @@ class TestQuantizationFreezePass(unittest.TestCase):
feed=feeder.feed(test_data), feed=feeder.feed(test_data),
fetch_list=[loss]) fetch_list=[loss])
self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3) self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1)) #print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1))
print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2)) #print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2))
w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor()) w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
# Maybe failed, this is due to the calculation precision # Maybe failed, this is due to the calculation precision
self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant)) # self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
print('{}: {}'.format('w_freeze' + dev_name + quant_type, #print('{}: {}'.format('w_freeze' + dev_name + quant_type,
np.sum(w_freeze))) # np.sum(w_freeze)))
print('{}: {}'.format('w_quant' + dev_name + quant_type, #print('{}: {}'.format('w_quant' + dev_name + quant_type,
np.sum(w_quant))) # np.sum(w_quant)))
# Convert parameter to 8-bit. # Convert parameter to 8-bit.
convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place) convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
...@@ -330,9 +330,9 @@ class TestQuantizationFreezePass(unittest.TestCase): ...@@ -330,9 +330,9 @@ class TestQuantizationFreezePass(unittest.TestCase):
w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor()) w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
self.assertEqual(w_8bit.dtype, np.int8) self.assertEqual(w_8bit.dtype, np.int8)
self.assertEqual(np.sum(w_8bit), np.sum(w_freeze)) self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit))) #print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit)))
print('{}: {}'.format('w_freeze' + dev_name + quant_type, #print('{}: {}'.format('w_freeze' + dev_name + quant_type,
np.sum(w_freeze))) # np.sum(w_freeze)))
mobile_pass = TransformForMobilePass() mobile_pass = TransformForMobilePass()
mobile_pass.apply(test_graph) mobile_pass.apply(test_graph)
......
...@@ -1516,12 +1516,16 @@ class Block(object): ...@@ -1516,12 +1516,16 @@ class Block(object):
class IrGraph(object): class IrGraph(object):
""" """
IrGraph uses core.Graph as the delegation to accomplish the manipulation. Python IrGraph. Beneath it is a core.Graph, which is used for
create a c++ Ir Pass Graph. An IrGraph is just a graph view of
a Program. In an IrGraph, both Variables and Operators are graph
nodes.
""" """
def __init__(self, graph, for_test=False): def __init__(self, graph, for_test=False):
""" """
Construct the IrGraph using core.Graph. Construct an IrGraph using core.Graph.
Args: Args:
graph(core.Graph): C++ Graph. graph(core.Graph): C++ Graph.
for_test(bool): True for the test graph and false for the train graph. for_test(bool): True for the test graph and false for the train graph.
...@@ -1532,15 +1536,27 @@ class IrGraph(object): ...@@ -1532,15 +1536,27 @@ class IrGraph(object):
self._for_test = for_test self._for_test = for_test
def is_test(self): def is_test(self):
"""
If the graph is used for testing, the function returns true. Otherwise, returns false.
"""
return self._for_test return self._for_test
def all_nodes(self): def all_nodes(self):
"""
Return all nodes included in the graph as a set.
"""
return {node for node in self.graph.nodes()} return {node for node in self.graph.nodes()}
def all_vars(self): def all_vars(self):
"""
Return all variable nodes included in the graph as a set.
"""
return {node for node in self.graph.nodes() if node.is_var()} return {node for node in self.graph.nodes() if node.is_var()}
def all_persistable_vars(self): def all_persistable_vars(self):
"""
Return all persistable variable nodes included in the graph as a set.
"""
persistable_nodes = set() persistable_nodes = set()
for node in self.graph.nodes(): for node in self.graph.nodes():
if node.is_var() and node.var() is not None and node.var( if node.is_var() and node.var() is not None and node.var(
...@@ -1549,18 +1565,24 @@ class IrGraph(object): ...@@ -1549,18 +1565,24 @@ class IrGraph(object):
return persistable_nodes return persistable_nodes
def all_ops(self): def all_ops(self):
"""
Return all operator nodes included in the graph as a set.
"""
return {node for node in self.graph.nodes() if node.is_op()} return {node for node in self.graph.nodes() if node.is_op()}
def var_node(self, name): def var_node(self, name):
""" """
Get a variable node by name from this graph. Get a variable node by name from the graph.
Args: Args:
name(str): the name of the variable node. name(str): the name of the variable node.
Raises: Raises:
ValueError: The If input's type is not str, or this graph ValueError: The If input's type is not str, or this graph
doesn't have a variable with the giving name. doesn't have a variable with the giving name.
Returns: Returns:
Node: the variable node with the giving name. core.Node: the variable node with the giving name.
""" """
if not isinstance(name, six.string_types): if not isinstance(name, six.string_types):
raise TypeError( raise TypeError(
...@@ -1576,6 +1598,19 @@ class IrGraph(object): ...@@ -1576,6 +1598,19 @@ class IrGraph(object):
return target_var_node return target_var_node
def create_param_node(self, name, var_type, shape, var_dtype): def create_param_node(self, name, var_type, shape, var_dtype):
"""
Create a persistable variable node in the graph. In IrGraph,
it can not distinguish between persistable variables and parameters.
Args:
name(str): the name of the persistable variable node.
vart_type(core.VarDesc.VarType): the type of the persistable variable node.
shape(list): the shape of the persistable variable node.
var_dtype(core.VarDesc.VarType): the data type of the persistable variable node.
Returns:
core.Node: the created persistable variable node.
"""
var_desc = core.VarDesc(name) var_desc = core.VarDesc(name)
var_desc.set_type(var_type) var_desc.set_type(var_type)
var_desc.set_shape(shape) var_desc.set_shape(shape)
...@@ -1584,6 +1619,20 @@ class IrGraph(object): ...@@ -1584,6 +1619,20 @@ class IrGraph(object):
return self.graph.create_var_node(var_desc) return self.graph.create_var_node(var_desc)
def create_var_node(self, name, var_type, shape, var_dtype): def create_var_node(self, name, var_type, shape, var_dtype):
"""
Create a variable node in the graph. The created variable node is
not persistable.
Args:
name(str): the name of the variable node.
vart_type(core.VarDesc.VarType): the type of the variable node.
shape(list): the shape of the variable node.
var_dtype(core.VarDesc.VarType): the data type of the variable node.
Returns:
core.Node: the created variable node.
"""
var_desc = core.VarDesc(name) var_desc = core.VarDesc(name)
var_desc.set_type(var_type) var_desc.set_type(var_type)
var_desc.set_shape(shape) var_desc.set_shape(shape)
...@@ -1591,9 +1640,31 @@ class IrGraph(object): ...@@ -1591,9 +1640,31 @@ class IrGraph(object):
return self.graph.create_var_node(var_desc) return self.graph.create_var_node(var_desc)
def create_var_node_from_desc(self, var_desc): def create_var_node_from_desc(self, var_desc):
"""
Create a variable node by using an existing VarDesc in the graph.
Depend on the giving VarDesc, the created variable node may be persistable.
Args:
var_desc(core.VarDesc): the giving variable description.
Returns:
core.Node: the created variable node.
"""
return self.graph.create_var_node(var_desc) return self.graph.create_var_node(var_desc)
def create_op_node(self, op_type, attrs, inputs, outputs): def create_op_node(self, op_type, attrs, inputs, outputs):
"""
Create a operator node in the graph.
Args:
op_type(str): the type of the operator node.
attrs(dict): the attributes of the operator node.
inputs(dict): the inputs of the operator node.
outputs(dict): the outpus of the operator node.
Returns:
core.Node: the created operator node.
"""
op_desc = core.OpDesc() op_desc = core.OpDesc()
op_desc.set_type(op_type) op_desc.set_type(op_type)
for attr, value in attrs.iteritems(): for attr, value in attrs.iteritems():
...@@ -1611,9 +1682,26 @@ class IrGraph(object): ...@@ -1611,9 +1682,26 @@ class IrGraph(object):
return self.graph.create_op_node(op_desc) return self.graph.create_op_node(op_desc)
def create_op_node_from_desc(self, op_desc): def create_op_node_from_desc(self, op_desc):
"""
Create a operator node by using an existing OpDesc in the graph.
Args:
op_desc(core.VarDesc): the giving operator description.
Returns:
core.Node: the created operator node.
"""
return self.graph.create_op_node(op_desc) return self.graph.create_op_node(op_desc)
def update_input_link(self, old_input_node, new_input_node, op_node): def update_input_link(self, old_input_node, new_input_node, op_node):
"""
Update the input's link of a operator node.
Args:
old_input_node(core.Node): the old input node of the giving op_node.
new_input_node(core.Node): the new input node of the giving op_node.
op_node(core.Node): the operator node that is needed to update input's link.
"""
assert old_input_node in self.graph.nodes() and new_input_node in \ assert old_input_node in self.graph.nodes() and new_input_node in \
self.graph.nodes() and op_node in self.graph.nodes(), \ self.graph.nodes() and op_node in self.graph.nodes(), \
'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.' 'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.'
...@@ -1624,12 +1712,26 @@ class IrGraph(object): ...@@ -1624,12 +1712,26 @@ class IrGraph(object):
op_node.op()._rename_input(old_input_node.name(), new_input_node.name()) op_node.op()._rename_input(old_input_node.name(), new_input_node.name())
def link_to(self, node_in, node_out): def link_to(self, node_in, node_out):
"""
Connect two nodes.
Args:
node_in(core.Node): the input node.
node_out(core.Node): the output node.
"""
assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \ assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \
'The two arguments(node_in&node_out) must be in the graph nodes.' 'The two arguments(node_in&node_out) must be in the graph nodes.'
node_in.outputs_append(node_out) node_in.outputs_append(node_out)
node_out.inputs_append(node_in) node_out.inputs_append(node_in)
def safe_remove_nodes(self, remove_nodes): def safe_remove_nodes(self, remove_nodes):
"""
Remove nodes safely since links connected to these removed nodes are
also removed.
Args:
remove_nodes(set): the nodes prepared to be removed.
"""
if not isinstance(remove_nodes, set): if not isinstance(remove_nodes, set):
if isinstance(remove_nodes, Iterable): if isinstance(remove_nodes, Iterable):
remove_nodes = set(remove_nodes) remove_nodes = set(remove_nodes)
...@@ -1638,18 +1740,57 @@ class IrGraph(object): ...@@ -1638,18 +1740,57 @@ class IrGraph(object):
core.graph_safe_remove_nodes(self.graph, remove_nodes) core.graph_safe_remove_nodes(self.graph, remove_nodes)
def has_circle(self): def has_circle(self):
"""
Check if the graph has a circle.
Returns:
bool: True if the graph has a circle else False.
"""
return core.has_circle(self.graph) return core.has_circle(self.graph)
def graph_num(self): def graph_num(self):
"""
Count the number of unconnected graphs in this graph.
Returns:
int: the number of unconnected graphs.
"""
return core.graph_num(self.graph) return core.graph_num(self.graph)
def topology_sort(self): def topology_sort(self):
"""
Perform the topology sort operation on the graph.
Notes: the `graph` cannot contain a circle.
Returns:
set(core.Node): nodes in topology order.
"""
return core.topology_sort(self.graph) return core.topology_sort(self.graph)
def build_adjacency_list(self): def build_adjacency_list(self):
"""
Build an adjacency list of operations for the `graph`.
Returns:
dict{core.Node: set(core.Node)}: the adjacency list.
"""
return core.build_adjacency_list(self.graph) return core.build_adjacency_list(self.graph)
def draw(self, save_path, name, marked_nodes=None): def draw(self, save_path, name, marked_nodes=None, remove_ctr_var=True):
"""
Draw the graph. If `dot` command is installed, the drawn graph
will be saved as pdf file type, otherwise dot file type is used.
Args:
save_path(str): the save path of drawn graph.
name(str): the name of drawn graph.
marked_nodes(set(core.Node)): nodes that are needed to be marked.
Default value is None.
remove_ctr_var(bool): If it is set True, all control variable nodes
in the graph will be removed. Default value is True.
"""
def _convert_to_pdf(dot_file_path): def _convert_to_pdf(dot_file_path):
pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf' pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \ exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
...@@ -1659,15 +1800,17 @@ class IrGraph(object): ...@@ -1659,15 +1800,17 @@ class IrGraph(object):
print('The {} is saved as the dot filetype.'.format( print('The {} is saved as the dot filetype.'.format(
dot_file_path)) dot_file_path))
remove_ctr_vars = set() if remove_ctr_var:
remove_ctr_vars = set()
for node in self.graph.nodes():
if node.is_ctrl_var():
remove_ctr_vars.add(node)
self.safe_remove_nodes(remove_ctr_vars)
ops_num = 0 ops_num = 0
for node in self.graph.nodes(): for node in self.graph.nodes():
if node.is_ctrl_var(): if node.is_op():
remove_ctr_vars.add(node)
elif node.is_op():
ops_num += 1 ops_num += 1
print('Total ops num = {}.'.format(ops_num)) print('Total ops num = {}.'.format(ops_num))
self.safe_remove_nodes(remove_ctr_vars)
if marked_nodes is not None: if marked_nodes is not None:
if not isinstance(marked_nodes, set): if not isinstance(marked_nodes, set):
marked_nodes = set(marked_nodes) marked_nodes = set(marked_nodes)
...@@ -1682,6 +1825,16 @@ class IrGraph(object): ...@@ -1682,6 +1825,16 @@ class IrGraph(object):
_convert_to_pdf(viz_dot_path) _convert_to_pdf(viz_dot_path)
def to_program(self): def to_program(self):
"""
Convert the graph into a Program.
Notes: When the graph includes backward operator nodes, the
conversion process may be failed. Usually, this function is
only used to convert a test graph.
Returns:
Program: a program converted from the graph.
"""
convert_pass = core.get_pass('graph_to_program_pass') convert_pass = core.get_pass('graph_to_program_pass')
desc = core.ProgramDesc() desc = core.ProgramDesc()
convert_pass.set_not_owned('program', desc) convert_pass.set_not_owned('program', desc)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册