From 183bacebe3d822776abdaa93a7f1765dcc0ade54 Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Wed, 27 Mar 2019 16:46:39 +0800
Subject: [PATCH] clean codes and fix some bugs. test=develop

---
 .../slim/quantization/quantization_pass.py     | 120 ++++++++++--------
 .../quantization/quantization_strategy.py      |  16 ++-
 .../slim/tests/quantization/compress.yaml      |   2 +
 .../slim/tests/test_quantization_pass.py       |   3 -
 python/paddle/fluid/framework.py               |  80 ++++--------
 5 files changed, 103 insertions(+), 118 deletions(-)

diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index ab3bd8bd18..3809e32794 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -26,6 +26,17 @@ __all__ = [
 ]
 
 
+def _init_var_node(var_node, value, scope, place):
+    assert isinstance(value,
+                      np.ndarray), 'The type of value should be numpy array.'
+    assert scope is not None, \
+        'The scope cannot be set None.'
+    assert place is not None, \
+        'The place cannot be set None.'
+    tensor = scope.var(var_node.name()).get_tensor()
+    tensor.set(value, place)
+
+
 class QuantizationTransformPass(object):
     def __init__(self,
                  scope=None,
@@ -88,14 +99,14 @@ class QuantizationTransformPass(object):
         assert activation_quantize_type != 'channel_wise_abs_max', "The activation quantization type does not support 'channel_wise_abs_max'."
         if activation_quantize_type not in quant_type:
             raise ValueError(
-                "Unknown activation_quantize_type : '%s'. It can only be ",
-                "'abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
-                str(activation_quantize_type))
+                "Unknown activation_quantize_type : '%s'. It can only be "
+                "'abs_max' or 'range_abs_max' or 'moving_average_abs_max'." %
+                (str(activation_quantize_type)))
         if weight_quantize_type not in quant_type:
             raise ValueError(
-                "Unknown weight_quantize_type: '%s'. It can only be ",
-                "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
-                str(weight_quantize_type))
+                "Unknown weight_quantize_type: '%s'. It can only be "
+                "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
+                % (str(weight_quantize_type)))
 
         self._activation_quantize_type = activation_quantize_type
         self._weight_quantize_type = weight_quantize_type
@@ -121,8 +132,6 @@ class QuantizationTransformPass(object):
         """
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
-        #sequential_execution = core.get_pass('sequential_execution_pass')
-        #sequential_execution.apply(graph.graph)
         self._is_test = graph.is_test()
         # marked the variable which has been dequantized.
         dequantized_vars = collections.OrderedDict()
@@ -203,9 +212,12 @@ class QuantizationTransformPass(object):
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 shape=[1],
                 var_dtype=core.VarDesc.VarType.INT64)
-            self._init_var_node(
-                global_step_in, np.zeros(
-                    [1], dtype='int64'))
+            _init_var_node(
+                global_step_in,
+                np.zeros(
+                    [1], dtype='int64'),
+                self._scope,
+                self._place)
             global_step_out = graph.create_var_node_from_desc(
                 global_step_in.var())
             # The attribute of `op_role` is needed by ParallelExecutor.
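
Note on the ValueError hunk above: the original code passed the format string and its arguments to ValueError as separate positional arguments, so the raised exception carried a tuple and the '%s' placeholder was never substituted. The fix concatenates the adjacent string literals and applies the '%' operator. A minimal illustration of the two behaviors (not part of the patch; "foo" stands in for the rejected quantize type):

    >>> raise ValueError("Unknown type: '%s'. It can only be ", "'abs_max'.", "foo")
    ValueError: ("Unknown type: '%s'. It can only be ", "'abs_max'.", 'foo')

    >>> raise ValueError("Unknown type: '%s'. It can only be " "'abs_max'." % "foo")
    ValueError: Unknown type: 'foo'. It can only be 'abs_max'.
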
@@ -284,7 +296,12 @@ class QuantizationTransformPass(object):
             var_dtype=var_node.dtype())
         data_type = 'float64' if var_node.dtype(
         ) == core.VarDesc.VarType.FP64 else 'float32'
-        self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))
+        _init_var_node(
+            scale_in_node,
+            np.array(
+                [0.001], dtype=data_type),
+            self._scope,
+            self._place)
 
         scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
         inputs = {'X': var_node, 'InScale': scale_in_node}
@@ -299,9 +316,13 @@ class QuantizationTransformPass(object):
                 var_dtype=var_node.dtype())
             data_type = 'float64' if var_node.dtype(
             ) == core.VarDesc.VarType.FP64 else 'float32'
-            self._init_var_node(
-                scales_node, np.zeros(
-                    [self._window_size], dtype=data_type))
+            _init_var_node(
+                scales_node,
+                np.zeros(
+                    [self._window_size], dtype=data_type),
+                self._scope,
+                self._place)
+
             inputs['Iter'] = self._global_step
             outputs['OutScales'] = scales_node
         attrs = {
@@ -343,7 +364,12 @@ class QuantizationTransformPass(object):
             var_dtype=var_node.dtype())
         data_type = 'float64' if var_node.dtype(
         ) == core.VarDesc.VarType.FP64 else 'float32'
-        self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))
+        _init_var_node(
+            scale_in_node,
+            np.array(
+                [0.001], dtype=data_type),
+            self._scope,
+            self._place)
 
         scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
         ins = {'X': var_node, 'InScale': scale_in_node}
@@ -356,13 +382,23 @@ class QuantizationTransformPass(object):
                 shape=[1])
             data_type = 'float64' if var_node.dtype(
             ) == core.VarDesc.VarType.FP64 else 'float32'
-            self._init_var_node(scale_in_node, np.ones([1], dtype=data_type))
+            _init_var_node(
+                scale_in_node,
+                np.ones(
+                    [1], dtype=data_type),
+                self._scope,
+                self._place)
             accum_in_node = graph.create_persistable_node(
                 name=unique_name.generate('accum'),
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 var_dtype=var_node.dtype(),
                 shape=[1])
-            self._init_var_node(accum_in_node, np.ones([1], dtype=data_type))
+            _init_var_node(
+                accum_in_node,
+                np.ones(
+                    [1], dtype=data_type),
+                self._scope,
+                self._place)
             state_out_node = graph.create_var_node_from_desc(state_in_node.var(
             ))
             accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
             ))
@@ -482,16 +518,6 @@ class QuantizationTransformPass(object):
             graph.link_to(dequant_op_node, dequant_var_node)
         return dequant_var_node
 
-    def _init_var_node(self, var_node, value):
-        assert isinstance(
-            value, np.ndarray), 'The type of value should be numpy array.'
-        assert self._scope is not None, \
-            'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
-        assert self._place is not None, \
-            'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
-        tensor = self._scope.var(var_node.name()).get_tensor()
-        tensor.set(value, self._place)
-
     def _quantized_var_name(self, var_name):
         """
         Return quantized variable name for the input `var_name`.
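
With the private `QuantizationTransformPass._init_var_node` method removed here (and its twin in QuantizationFreezePass removed further down), all call sites go through the module-level `_init_var_node(var_node, value, scope, place)` helper added at the top of the file; the scope and place now arrive as explicit arguments instead of being read from `self`. A rough sketch of what the helper does, assuming a Paddle 1.x build where `paddle.fluid.core` exposes `Scope` and `CPUPlace`:

    import numpy as np
    from paddle.fluid import core

    scope = core.Scope()      # container that owns the variables' tensors
    place = core.CPUPlace()   # device the numpy buffer is copied to

    # _init_var_node(var_node, value, scope, place) boils down to this,
    # with 'global_step' standing in for var_node.name():
    tensor = scope.var('global_step').get_tensor()
    tensor.set(np.zeros([1], dtype='int64'), place)
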
@@ -594,8 +620,8 @@ class QuantizationFreezePass(object):
                         self._weight_bits)
                     self._restore_var(input_arg_name, quantized_param_v)
                 else:
-                    scale_v = self._to_node(op_node.outputs,
-                                            op_node.output('OutScale')[0])
+                    scale_v = graph._find_node_by_name(
+                        op_node.outputs, op_node.output('OutScale')[0])
                     self._var_scale_map[input_arg_name] = scale_v
 
         ops = graph.all_op_nodes()
@@ -627,8 +653,8 @@ class QuantizationFreezePass(object):
         return graph
 
     def _remove_fake_quant_and_dequant_op(self, graph, op_node):
-        k = self._to_node(op_node.outputs, op_node.output('Out')[0])
-        v = self._to_node(op_node.inputs, op_node.input('X')[0])
+        k = graph._find_node_by_name(op_node.outputs, op_node.output('Out')[0])
+        v = graph._find_node_by_name(op_node.inputs, op_node.input('X')[0])
         if v.node not in self._op_input_rename_map:
             self._op_input_rename_map[k.node] = v
         else:
@@ -663,8 +689,8 @@ class QuantizationFreezePass(object):
             raise ValueError("Only support one output, but op %s has"
                              " more than one output." % (op_node.name()))
 
-        output_var_node = self._to_node(op_node.outputs,
-                                        op_node.output_arg_names()[0])
+        output_var_node = graph._find_node_by_name(
+            op_node.outputs, op_node.output_arg_names()[0])
         weight_scale_node = graph.create_persistable_node(
             name=unique_name.generate('channel_scale'),
             var_type=core.VarDesc.VarType.LOD_TENSOR,
@@ -672,7 +698,9 @@ class QuantizationFreezePass(object):
             var_dtype=output_var_node.dtype())
         data_type = 'float64' if output_var_node.dtype(
         ) == core.VarDesc.VarType.FP64 else 'float32'
-        self._init_var_node(weight_scale_node, channel_scale.astype(data_type))
+        _init_var_node(weight_scale_node,
+                       channel_scale.astype(data_type), self._scope,
+                       self._place)
         dequant_var_node = graph.create_var_node(
             name=self._dequantized_var_name(output_var_node.name()),
             var_type=output_var_node.type(),
@@ -724,8 +752,8 @@ class QuantizationFreezePass(object):
             raise ValueError("Only support one output, but op %s has"
                              " more than one output." % (op_node.name()))
 
-        output_var_node = self._to_node(op_node.outputs,
-                                        op_node.output_arg_names()[0])
+        output_var_node = graph._find_node_by_name(
+            op_node.outputs, op_node.output_arg_names()[0])
         dequant_var_node = graph.create_var_node(
             name=self._dequantized_var_name(output_var_node.name()),
             var_type=output_var_node.type(),
@@ -746,24 +774,6 @@ class QuantizationFreezePass(object):
         self._op_output_rename_map[output_var_node.node] = dequant_var_node
         return dequant_var_node
 
-    def _init_var_node(self, var_node, value):
-        assert isinstance(
-            value, np.ndarray), 'The type of value should be numpy array.'
-        assert self._scope is not None, \
-            'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
-        assert self._place is not None, \
-            'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
-        tensor = self._scope.var(var_node.name()).get_tensor()
-        tensor.set(value, self._place)
-
-    def _to_node(self, nodes, node_name):
-        target_node = None
-        for n in nodes:
-            if n.name() == node_name:
-                target_node = n
-        assert target_node is not None, "Cannot find the target node in the giving set."
-        return target_node
-
     def _load_var(self, name):
         return np.array(self._scope.find_var(name).get_tensor())
 
diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py b/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py
index 6812b4c633..da3510de39 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py
@@ -45,13 +45,14 @@ class QuantizationStrategy(Strategy):
                  activation_bits=8,
                  weight_bits=8,
                  activation_quantize_type='abs_max',
+                 weight_quantize_type='abs_max',
                  save_in_nodes=None,
                  save_out_nodes=None):
         """
         Args:
             start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. default: 0
             end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. default: 0
-            float_model_save_path(str): The path to save model with float weights. 
+            float_model_save_path(str): The path to save model with float weights.
                                 None means it doesn't save float model. defalut: None.
             mobile_model_save_path(str): The path to save model for paddle-mobile execution.
                                 None means it doesn't save mobile model. defalut: None.
@@ -66,9 +67,11 @@ class QuantizationStrategy(Strategy):
                                 dynamically each step in both training and testing period. If use
                                 'range_abs_max', a static quantization scale will be calculated
                                 during training and used in inference.
-            save_in_nodes(list): A list of variable names used to prune graph 
+            weight_quantize_type (str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'.
+                                The 'range_abs_max' usually is not used for weight, since weights are fixed once the model is well trained.
+            save_in_nodes(list): A list of variable names used to prune graph
                                  for saving inference model.
-            save_out_nodes(list): A list of variable names used to prune graph 
+            save_out_nodes(list): A list of variable names used to prune graph
                                  for saving inference model.
         """
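
The hunks below thread the new `weight_quantize_type` argument through to QuantizationTransformPass and QuantizationFreezePass. A hedged construction example, mirroring the keys used in the test config later in this patch (in practice these keyword arguments usually come from a compress.yaml strategy section rather than direct instantiation):

    strategy = QuantizationStrategy(
        start_epoch=0,
        end_epoch=0,
        float_model_save_path='./output/float',
        weight_bits=8,
        activation_bits=8,
        activation_quantize_type='abs_max',
        weight_quantize_type='abs_max')
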
""" @@ -81,6 +84,7 @@ class QuantizationStrategy(Strategy): self.activation_bits = activation_bits self.weight_bits = weight_bits self.activation_quantize_type = activation_quantize_type + self.weight_quantize_type = weight_quantize_type self.save_out_nodes = save_out_nodes self.save_in_nodes = save_in_nodes @@ -100,7 +104,8 @@ class QuantizationStrategy(Strategy): place=context.place, weight_bits=self.weight_bits, activation_bits=self.activation_bits, - activation_quantize_type=self.activation_quantize_type) + activation_quantize_type=self.activation_quantize_type, + weight_quantize_type=self.weight_quantize_type) transform_pass.apply(train_ir_graph) transform_pass.apply(test_ir_graph) @@ -134,7 +139,8 @@ class QuantizationStrategy(Strategy): scope=context.scope, place=context.place, weight_bits=self.weight_bits, - activation_bits=self.activation_bits) + activation_bits=self.activation_bits, + weight_quantize_type=self.weight_quantize_type) freeze_pass.apply(test_ir_graph) # for other strategies diff --git a/python/paddle/fluid/contrib/slim/tests/quantization/compress.yaml b/python/paddle/fluid/contrib/slim/tests/quantization/compress.yaml index f29eb53f88..a3a5a724fb 100644 --- a/python/paddle/fluid/contrib/slim/tests/quantization/compress.yaml +++ b/python/paddle/fluid/contrib/slim/tests/quantization/compress.yaml @@ -35,6 +35,8 @@ strategies: start_epoch: 0 end_epoch: 0 float_model_save_path: './output/float' + mobile_model_save_path: './output/mobile' + int8_model_save_path: './output/int8' weight_bits: 8 activation_bits: 8 weight_quantize_type: 'abs_max' diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py index c7feca0b82..e896f8bb42 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py @@ -256,8 +256,6 @@ class TestQuantizationFreezePass(unittest.TestCase): place=place, activation_quantize_type=activation_quant_type, weight_quantize_type=weight_quant_type) - #transform_pass = QuantizationTransformPass( - # scope=scope, place=place, activation_quantize_type=activation_quant_type) transform_pass.apply(main_graph) transform_pass.apply(test_graph) dev_name = '_gpu_' if use_cuda else '_cpu_' @@ -315,7 +313,6 @@ class TestQuantizationFreezePass(unittest.TestCase): # Freeze graph for inference, but the weight of fc/conv is still float type. freeze_pass = QuantizationFreezePass( scope=scope, place=place, weight_quantize_type=weight_quant_type) - #freeze_pass = QuantizationFreezePass(scope=scope, place=place) freeze_pass.apply(test_graph) if not for_ci: marked_nodes = set() diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 5ac2b50a99..a209f389f3 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -104,14 +104,14 @@ def cuda_places(device_ids=None): :code:`FLAGS_selected_gpus=0,1,2`, the returned list would be [fluid.CUDAPlace(0), fluid.CUDAPlace(1), fluid.CUDAPlace(2)]. If :code:`FLAGS_selected_gpus` is not set, all visible - gpu places would be returned. + gpu places would be returned. If :code:`device_ids` is not None, it should be the device - ids of gpus. For example, if :code:`device_ids=[0,1,2]`, - the returned list would be + ids of gpus. For example, if :code:`device_ids=[0,1,2]`, + the returned list would be [fluid.CUDAPlace(0), fluid.CUDAPlace(1), fluid.CUDAPlace(2)]. 
-    
-    Args: 
+
+    Args:
         device_ids (None|list(int)|tuple(int)): gpu device id list.
 
     Returns:
@@ -133,11 +133,11 @@ def cuda_places(device_ids=None):
 def cpu_places(device_count=None):
     '''
     Create a list of :code:`fluid.CPUPlace` objects.
-    
+
     If :code:`device_count` is None, the device count would
-    be determined by environment variable :code:`CPU_NUM`. 
+    be determined by environment variable :code:`CPU_NUM`.
     If :code:`CPU_NUM` is not set, the device count would
-    be determined by :code:`multiprocessing.cpu_count()`. 
+    be determined by :code:`multiprocessing.cpu_count()`.
 
     Args:
         device_count (None|int): device number.
@@ -155,9 +155,9 @@ def cuda_pinned_places(device_count=None):
     Create a list of :code:`fluid.CUDAPinnedPlace` objects.
 
     If :code:`device_count` is None, the device count would
-    be determined by environment variable :code:`CPU_NUM`. 
+    be determined by environment variable :code:`CPU_NUM`.
     If :code:`CPU_NUM` is not set, the device count would
-    be determined by :code:`multiprocessing.cpu_count()`. 
+    be determined by :code:`multiprocessing.cpu_count()`.
 
     Args:
         device_count (None|int): device number.
@@ -2164,40 +2164,6 @@ class IrGraph(object):
         """
         return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}
 
-    def _find_var_node(self, key):
-        """
-        Get a variable node by the `key` from this graph. The key
-        can be a node name or a node id.
-
-        WARNS:
-            There are some nodes may have the same name. So, be
-            cautious about using this method when you find the
-            target var node by its name.
-
-        Args:
-            key(str|int): The str type denotes that the target variable node's name.
-            And the int type denotes that the target variable node's id.
-
-        Raises:
-            ValueError: If this graph doesn't have a variable with the giving name or id.
-
-        Returns:
-            IrVarNode: the variable node with the giving name or id.
-        """
-        target_var_node = None
-        var_nodes = self.all_var_nodes()
-        if isinstance(key, six.string_types):
-            for var_node in var_nodes:
-                if var_node.name() == key:
-                    target_var_node = var_node
-        elif isinstance(key, int):
-            for var_node in var_nodes:
-                if var_node.id() == key:
-                    target_var_node = var_node
-        if target_var_node is None:
-            raise ValueError("var_node %s not in this graph" % key)
-        return target_var_node
-
     def create_persistable_node(self, name, var_type, shape, var_dtype):
         """
         Create a persistable variable node in the graph. In IrGraph,
@@ -2342,14 +2308,6 @@ class IrGraph(object):
         core.graph_safe_remove_nodes(self.graph, original_nodes)
 
     def resolve_hazard(self):
-        def _to_node(nodes, node_name):
-            target_node = None
-            for n in nodes:
-                if n.name() == node_name:
-                    target_node = n
-            assert target_node is not None, "Cannot find the target node in the giving set."
-            return target_node
-
         ordered_nodes = core.topology_sort(self.graph)
         var_nodes = dict()
         for node in ordered_nodes:
@@ -2357,16 +2315,17 @@ class IrGraph(object):
                 for each_var_name in node.op().input_arg_names():
                     if each_var_name not in var_nodes:
                         var_nodes[each_var_name] = [
-                            _to_node(node.inputs, each_var_name)
+                            self._find_node_by_name(node.inputs, each_var_name)
                         ]
                 for each_var_name in node.op().output_arg_names():
                     if each_var_name not in var_nodes:
                         var_nodes[each_var_name] = [
-                            _to_node(node.outputs, each_var_name)
+                            self._find_node_by_name(node.outputs, each_var_name)
                         ]
                     else:
                         var_nodes[each_var_name].append(
-                            _to_node(node.outputs, each_var_name))
+                            self._find_node_by_name(node.outputs,
+                                                    each_var_name))
         self.graph.resolve_hazard(var_nodes)
 
     def has_circle(self):
@@ -2479,6 +2438,17 @@ class IrGraph(object):
             program = Program._construct_from_desc(desc)
         return program
 
+    def _find_node_by_name(self, nodes, node_name):
+        """
+        Find a node in the giving nodes set by the name.
+        """
+        target_node = None
+        for n in nodes:
+            if n.name() == node_name:
+                target_node = n
+        assert target_node is not None, "Cannot find the target node in the giving set."
+        return target_node
+
     def _update_desc_attr(self, desc, name, val):
         """
         Update the value of desc's attribute by attribute's name.
-- 
GitLab
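
The patch collapses three duplicated lookup helpers (the two `_to_node` methods in quantization_pass.py and the nested `_to_node` in `resolve_hazard`) into the single `IrGraph._find_node_by_name` added above, which returns the last node whose name matches and asserts if none does. A hypothetical call from a pass, assuming `graph` is an IrGraph and `op_node` is one of its operator nodes:

    # Look up the output variable node carrying the op's 'Out' argument.
    out_node = graph._find_node_by_name(op_node.outputs,
                                        op_node.output('Out')[0])
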