From 2ccbfd5e100f133bd841c854b8cdd32fd74ab132 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Mon, 25 Mar 2019 22:47:58 +0800 Subject: [PATCH] Fix some bugs for quantization passes. --- .../slim/quantization/quantization_pass.py | 181 +++++++++++------- python/paddle/fluid/framework.py | 57 ++++-- 2 files changed, 157 insertions(+), 81 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 5dcef50671..bbf3d17d19 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -14,15 +14,10 @@ import collections import numpy as np -import six from ..... import compat as cpt from .... import core -from .... import Executor from ....framework import IrGraph from ....framework import IrNode -from ....framework import Program -from ....initializer import Constant -from ....initializer import NumpyArrayInitializer from .... import unique_name __all__ = [ @@ -31,6 +26,35 @@ __all__ = [ ] +def _resolve_hazard(graph): + def _to_node(nodes, node_name): + target_node = None + for n in nodes: + if n.name() == node_name: + target_node = n.node + assert target_node is not None, "Cannot find the target node in the giving set." + return target_node + + ordered_nodes = graph.topology_sort() + var_nodes = dict() + for node in ordered_nodes: + if node.is_op() and node.op() is not None: + for each_var_name in node.op().input_arg_names(): + if each_var_name not in var_nodes: + var_nodes[each_var_name] = [ + _to_node(node.inputs, each_var_name) + ] + for each_var_name in node.op().output_arg_names(): + if each_var_name not in var_nodes: + var_nodes[each_var_name] = [ + _to_node(node.outputs, each_var_name) + ] + else: + var_nodes[each_var_name].append( + _to_node(node.outputs, each_var_name)) + graph.graph.resolve_hazard(var_nodes) + + class QuantizationTransformPass(object): def __init__(self, scope=None, @@ -107,7 +131,6 @@ class QuantizationTransformPass(object): self._window_size = window_size self._moving_rate = moving_rate - self._need_initialized = collections.OrderedDict() self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul'] self._conv_ops = ['conv2d', 'depthwise_conv2d'] self._quantizable_grad_ops = [ @@ -127,7 +150,8 @@ class QuantizationTransformPass(object): """ assert isinstance(graph, IrGraph), 'graph must be the instance of IrGraph.' - self._need_initialized.clear() + sequential_execution = core.get_pass('sequential_execution_pass') + sequential_execution.apply(graph.graph) self._is_test = graph.is_test() # marked the variable which has been dequantized. dequantized_vars = collections.OrderedDict() @@ -135,6 +159,8 @@ class QuantizationTransformPass(object): def _transform_forward(graph, op): for var_node in op.inputs: + if var_node.name() not in op.input_arg_names(): + continue if var_node.name() in dequantized_vars: dequant_var_node = dequantized_vars[var_node.name()] else: @@ -168,6 +194,8 @@ class QuantizationTransformPass(object): def _transform_backward(graph, op): no_dequanted_input_vars = True for var_node in op.inputs: + if var_node.name() not in op.input_arg_names(): + continue if var_node.name() in dequantized_vars: dequant_var_node = dequantized_vars[var_node.name()] graph.update_input_link(var_node, dequant_var_node, op) @@ -188,25 +216,7 @@ class QuantizationTransformPass(object): for op in ops: if op.name() in self._quantizable_grad_ops: _transform_backward(graph, op) - - if len(self._need_initialized) > 0: - assert self._scope is not None, \ - 'The scope cannot be set None when activation_quantize_type equals to range_abs_max.' - assert self._place is not None, \ - 'The place cannot be set None when activation_quantize_type equals to range_abs_max.' - init_program = Program() - for var_desc, initializer in six.iteritems(self._need_initialized): - var = init_program.global_block().create_var( - name=var_desc.name(), - shape=var_desc.shape(), - dtype=var_desc.dtype(), - type=var_desc.type(), - lod_level=var_desc.lod_level(), - persistable=var_desc.persistable()) - initializer(var, init_program.global_block()) - exe = Executor(self._place) - exe.run(program=init_program, scope=self._scope) - + _resolve_hazard(graph) return graph def _create_global_step(self, graph): @@ -222,8 +232,9 @@ class QuantizationTransformPass(object): var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[1], var_dtype=core.VarDesc.VarType.INT64) - self._need_initialized[global_step_in.var()] = \ - Constant(value=0, force_cpu=True) + self._init_var_node( + global_step_in, np.zeros( + [1], dtype='int64')) global_step_out = graph.create_var_node_from_desc( global_step_in.var()) # The attribute of `op_role` is needed by ParallelExecutor. @@ -300,7 +311,9 @@ class QuantizationTransformPass(object): var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[1], var_dtype=var_node.dtype()) - self._need_initialized[scale_in_node.var()] = Constant(value=0.001) + data_type = 'float64' if var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type)) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) inputs = {'X': var_node, 'InScale': scale_in_node} @@ -313,7 +326,11 @@ class QuantizationTransformPass(object): var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[self._window_size], var_dtype=var_node.dtype()) - self._need_initialized[scales_node.var()] = Constant(value=0) + data_type = 'float64' if var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + self._init_var_node( + scales_node, np.zeros( + [self._window_size], dtype=data_type)) inputs['Iter'] = self._global_step outputs['OutScales'] = scales_node attrs = { @@ -353,7 +370,9 @@ class QuantizationTransformPass(object): var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[1], var_dtype=var_node.dtype()) - self._need_initialized[scale_in_node.var()] = Constant(value=0.001) + data_type = 'float64' if var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type)) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) ins = {'X': var_node, 'InScale': scale_in_node} @@ -364,13 +383,15 @@ class QuantizationTransformPass(object): var_type=core.VarDesc.VarType.LOD_TENSOR, var_dtype=var_node.dtype(), shape=[1]) - self._need_initialized[state_in_node.var()] = Constant(value=1) + data_type = 'float64' if var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + self._init_var_node(scale_in_node, np.ones([1], dtype=data_type)) accum_in_node = graph.create_persistable_node( name=unique_name.generate('accum'), var_type=core.VarDesc.VarType.LOD_TENSOR, var_dtype=var_node.dtype(), shape=[1]) - self._need_initialized[accum_in_node.var()] = Constant(value=1) + self._init_var_node(accum_in_node, np.ones([1], dtype=data_type)) state_out_node = graph.create_var_node_from_desc(state_in_node.var( )) accum_out_node = graph.create_var_node_from_desc(accum_in_node.var( @@ -490,6 +511,16 @@ class QuantizationTransformPass(object): graph.link_to(dequant_op_node, dequant_var_node) return dequant_var_node + def _init_var_node(self, var_node, value): + assert isinstance( + value, np.ndarray), 'The type of value should be numpy array.' + assert self._scope is not None, \ + 'The scope cannot be set None when activation_quantize_type equals to range_abs_max.' + assert self._place is not None, \ + 'The place cannot be set None when activation_quantize_type equals to range_abs_max.' + tensor = self._scope.var(var_node.name()).get_tensor() + tensor.set(value, self._place) + def _quantized_var_name(self, var_name): """ Return quantized variable name for the input `var_name`. @@ -592,7 +623,8 @@ class QuantizationFreezePass(object): self._weight_bits) self._restore_var(input_arg_name, quantized_param_v) else: - scale_v = graph.var_node(op_node.output('OutScale')[0]) + scale_v = self._to_node(op_node.outputs, + op_node.output('OutScale')[0]) self._var_scale_map[input_arg_name] = scale_v ops = graph.all_op_nodes() @@ -613,10 +645,9 @@ class QuantizationFreezePass(object): for op_node in ops: # insert dequant_op after fc/conv, need to rename inputs of the followed ops for var_node in op_node.inputs: - name = var_node.name() - if name in self._op_output_rename_map: - old_in = graph.var_node(name) - new_in = self._op_output_rename_map[name] + if var_node.node in self._op_output_rename_map: + old_in = var_node + new_in = self._op_output_rename_map[var_node.node] graph.update_input_link(old_in, new_in, op_node) # remove the unused var node in the graph @@ -624,21 +655,24 @@ class QuantizationFreezePass(object): return graph def _remove_fake_quant_and_dequant_op(self, graph, op_node): - k = op_node.output('Out')[0] - v = op_node.input('X')[0] - if v not in self._op_input_rename_map: - self._op_input_rename_map[k] = v + k = self._to_node(op_node.outputs, op_node.output('Out')[0]) + v = self._to_node(op_node.inputs, op_node.input('X')[0]) + if v.node not in self._op_input_rename_map: + self._op_input_rename_map[k.node] = v else: - self._op_input_rename_map[k] = self._op_input_rename_map[v] + self._op_input_rename_map[k.node] = self._op_input_rename_map[ + v.node] graph.safe_remove_nodes(op_node) def _insert_post_channel_dequant_op(self, graph, op_node): persistable_vars = [p.name() for p in graph.all_persistable_nodes()] for var_node in op_node.inputs: name = var_node.name() - if name in self._op_input_rename_map: - old_in = graph.var_node(name) - new_in = graph.var_node(self._op_input_rename_map[name]) + if name not in op_node.input_arg_names(): + continue + if var_node.node in self._op_input_rename_map: + old_in = var_node + new_in = self._op_input_rename_map[var_node.node] new_in.clear_outputs() graph.update_input_link(old_in, new_in, op_node) original_var_name = self._original_var_name(name) @@ -653,28 +687,20 @@ class QuantizationFreezePass(object): assert isinstance(scale_v, IrNode) scale_var_node = self._var_scale_map[original_var_name] - if len(op_node.outputs) != 1: + if len(op_node.output_arg_names()) != 1: raise ValueError("Only support one output, but op %s has" " more than one output." % (op_node.name())) - output_var_node = op_node.outputs[0] + output_var_node = self._to_node(op_node.outputs, + op_node.output_arg_names()[0]) weight_scale_node = graph.create_persistable_node( name=unique_name.generate('channel_scale'), var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[channel_scale.shape[0]], var_dtype=output_var_node.dtype()) - init_program = Program() - weight_scale_var = init_program.global_block().create_var( - name=weight_scale_node.name(), - shape=weight_scale_node.shape(), - dtype=weight_scale_node.dtype(), - type=weight_scale_node.type(), - lod_level=weight_scale_node.var().lod_level(), - persistable=weight_scale_node.persistable()) - initializer = NumpyArrayInitializer(value=channel_scale) - initializer(weight_scale_var, init_program.global_block()) - exe = Executor(self._place) - exe.run(program=init_program, scope=self._scope) + data_type = 'float64' if output_var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + self._init_var_node(weight_scale_node, channel_scale.astype(data_type)) dequant_var_node = graph.create_var_node( name=self._dequantized_var_name(output_var_node.name()), var_type=output_var_node.type(), @@ -695,16 +721,18 @@ class QuantizationFreezePass(object): graph.link_to(scale_var_node, dequant_op_node) graph.link_to(weight_scale_node, dequant_op_node) graph.link_to(dequant_op_node, dequant_var_node) - self._op_output_rename_map[output_var_node.name()] = dequant_var_node + self._op_output_rename_map[output_var_node.node] = dequant_var_node return dequant_var_node def _insert_post_dequant_op(self, graph, op_node): persistable_vars = [p.name() for p in graph.all_persistable_nodes()] for var_node in op_node.inputs: name = var_node.name() - if name in self._op_input_rename_map: - old_in = graph.var_node(name) - new_in = graph.var_node(self._op_input_rename_map[name]) + if name not in op_node.input_arg_names(): + continue + if var_node.node in self._op_input_rename_map: + old_in = var_node + new_in = self._op_input_rename_map[var_node.node] new_in.clear_outputs() graph.update_input_link(old_in, new_in, op_node) original_var_name = self._original_var_name(name) @@ -720,11 +748,12 @@ class QuantizationFreezePass(object): assert isinstance(scale_v, IrNode) scale_var_node = self._var_scale_map[original_var_name] - if len(op_node.outputs) != 1: + if len(op_node.output_arg_names()) != 1: raise ValueError("Only support one output, but op %s has" " more than one output." % (op_node.name())) - output_var_node = op_node.outputs[0] + output_var_node = self._to_node(op_node.outputs, + op_node.output_arg_names()[0]) dequant_var_node = graph.create_var_node( name=self._dequantized_var_name(output_var_node.name()), var_type=output_var_node.type(), @@ -742,9 +771,27 @@ class QuantizationFreezePass(object): graph.link_to(output_var_node, dequant_op_node) graph.link_to(scale_var_node, dequant_op_node) graph.link_to(dequant_op_node, dequant_var_node) - self._op_output_rename_map[output_var_node.name()] = dequant_var_node + self._op_output_rename_map[output_var_node.node] = dequant_var_node return dequant_var_node + def _init_var_node(self, var_node, value): + assert isinstance( + value, np.ndarray), 'The type of value should be numpy array.' + assert self._scope is not None, \ + 'The scope cannot be set None when activation_quantize_type equals to range_abs_max.' + assert self._place is not None, \ + 'The place cannot be set None when activation_quantize_type equals to range_abs_max.' + tensor = self._scope.var(var_node.name()).get_tensor() + tensor.set(value, self._place) + + def _to_node(self, nodes, node_name): + target_node = None + for n in nodes: + if n.name() == node_name: + target_node = n + assert target_node is not None, "Cannot find the target node in the giving set." + return target_node + def _load_var(self, name): return np.array(self._scope.find_var(name).get_tensor()) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index e4169c247f..8097495f5b 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1964,6 +1964,28 @@ class IrOpNode(IrNode): else: desc._set_attr(name, val) + def input_arg_names(self): + """ + Return input arguments' names of this op node. + + Returns: + list(str): input arguments' names of this op node. + """ + assert self.node.op() is not None, \ + "The node operator description cannot be None." + return self.node.op().input_arg_names() + + def output_arg_names(self): + """ + Return output arguments' names of this op node. + + Returns: + list(str): output arguments' names of this op node. + """ + assert self.node.op() is not None, \ + "The node operator description cannot be None." + return self.node.op().output_arg_names() + @property def inputs(self): """ @@ -2054,31 +2076,38 @@ class IrGraph(object): """ return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()} - def var_node(self, name): + def _find_var_node(self, key): """ - Get a variable node by name from the graph. + Get a variable node by the `key` from this graph. The key + can be a node name or a node id. + + WARNS: + There are some nodes may have the same name. So, be + cautious about using this method when you find the + target var node by its name. Args: - name(str): the name of the variable node. + key(str|int): The str type denotes that the target variable node's name. + And the int type denotes that the target variable node's id. Raises: - ValueError: The If input's type is not str, or this graph - doesn't have a variable with the giving name. + ValueError: If this graph doesn't have a variable with the giving name or id. Returns: - IrVarNode: the variable node with the giving name. + IrVarNode: the variable node with the giving name or id. """ - if not isinstance(name, six.string_types): - raise TypeError( - "var require string as parameter, but get %s instead." % - (type(name))) target_var_node = None var_nodes = self.all_var_nodes() - for var_node in var_nodes: - if var_node.name() == name: - target_var_node = var_node + if isinstance(key, six.string_types): + for var_node in var_nodes: + if var_node.name() == key: + target_var_node = var_node + elif isinstance(key, int): + for var_node in var_nodes: + if var_node.id() == key: + target_var_node = var_node if target_var_node is None: - raise ValueError("var_node %s not in this graph" % name) + raise ValueError("var_node %s not in this graph" % key) return target_var_node def create_persistable_node(self, name, var_type, shape, var_dtype): -- GitLab