From 1c7e35dc98313b3c5c23dea3ceaf11feac4adcc1 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Mon, 13 Jun 2022 16:03:44 +0800 Subject: [PATCH] Add progress bar and speed up Quantization Pass (#43398) --- .../post_training_quantization.py | 51 +-- .../slim/quantization/quantization_pass.py | 301 ++++++++++-------- .../fluid/contrib/slim/quantization/utils.py | 27 +- 3 files changed, 223 insertions(+), 156 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index 5c16e0fe27..a4888e6f90 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -17,6 +17,10 @@ import re import logging import numpy as np import shutil +try: + from tqdm import tqdm +except: + from .utils import tqdm from inspect import isgeneratorfunction from .... import io from .... import core @@ -359,38 +363,41 @@ class PostTrainingQuantization(object): self._set_activation_persistable() if self._algo in ["KL", "hist"]: - _logger.info("Preparation stage ...") batch_id = 0 + with tqdm( + total=self._batch_nums, + bar_format= + 'Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: + for data in self._data_loader(): + self._executor.run(program=self._program, + feed=data, + fetch_list=self._fetch_list, + return_numpy=False, + scope=self._scope) + self._collect_activation_abs_min_max() + batch_id += 1 + t.update() + if self._batch_nums and batch_id >= self._batch_nums: + break + self._init_sampling_act_histogram() + + batch_id = 0 + with tqdm(total=self._batch_nums, + bar_format= + 'Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: for data in self._data_loader(): self._executor.run(program=self._program, feed=data, fetch_list=self._fetch_list, return_numpy=False, scope=self._scope) - self._collect_activation_abs_min_max() - if batch_id % 5 == 0: - _logger.info("Run batch: " + str(batch_id)) + self._sampling() batch_id += 1 + t.update() if self._batch_nums and batch_id >= self._batch_nums: break - _logger.info("Finish preparation stage, all batch:" + str(batch_id)) - self._init_sampling_act_histogram() - - _logger.info("Sampling stage ...") - batch_id = 0 - for data in self._data_loader(): - self._executor.run(program=self._program, - feed=data, - fetch_list=self._fetch_list, - return_numpy=False, - scope=self._scope) - self._sampling() - if batch_id % 5 == 0: - _logger.info("Run batch: " + str(batch_id)) - batch_id += 1 - if self._batch_nums and batch_id >= self._batch_nums: - break - _logger.info("Finish sampling stage, all batch: " + str(batch_id)) if self._algo == 'avg': for var_name in self._quantized_act_var_name: diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index eaf9bed3d6..0dd79992eb 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -14,6 +14,10 @@ import collections import numpy as np +try: + from tqdm import tqdm +except: + from .utils import tqdm from ..... import compat as cpt from .... import core from ....framework import IrGraph @@ -373,10 +377,15 @@ class QuantizationTransformPass(object): graph.out_node_mapping_table = dict() # The process of _transform_forward and _transform_backward is needed in two for loops. # The loop for transforming the forward graph: - for op in ops: - if op.name() in self._quantizable_ops: - if not self._is_skip_quant(graph, op) and _has_weight(op): - _transform_forward(graph, op) + with tqdm(total=len(ops), + bar_format= + 'Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: + for op in ops: + if op.name() in self._quantizable_ops: + if not self._is_skip_quant(graph, op) and _has_weight(op): + _transform_forward(graph, op) + t.update() # The loop for renaming the inputs of backward op. for op in ops: if op.name() in self._quantizable_grad_ops and _has_weight(op): @@ -1418,73 +1427,81 @@ class OutScaleForTrainingPass(object): for op in graph.all_op_nodes(): if op.name() in self._teller_set: target_ops.append(op) - for op in target_ops: - for output_var_name in utils._get_op_output_var_names(op): - in_node = graph._find_node_by_name(op.outputs, output_var_name) - if in_node.dtype() not in \ - [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]: - continue + with tqdm(total=len(target_ops), + bar_format='Adding OutScale op:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: + for op in target_ops: + for output_var_name in utils._get_op_output_var_names(op): + in_node = graph._find_node_by_name(op.outputs, + output_var_name) + if in_node.dtype() not in \ + [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]: + continue - scale_node = graph.create_persistable_node( - name=self._scale_name(in_node.name()), - var_type=core.VarDesc.VarType.LOD_TENSOR, - shape=[1], - var_dtype=in_node.dtype()) - data_type = 'float64' if in_node.dtype() \ - == core.VarDesc.VarType.FP64 else 'float32' - _init_var_node(scale_node, np.ones([1], dtype=data_type), - self._scope, self._place) - ins = {'X': in_node} - outs = {'OutScale': scale_node} - if not self._is_test: - state_in_node = graph.create_persistable_node( - name=unique_name.generate('scale_state@'), + scale_node = graph.create_persistable_node( + name=self._scale_name(in_node.name()), var_type=core.VarDesc.VarType.LOD_TENSOR, - var_dtype=in_node.dtype(), - shape=[1]) - _init_var_node(state_in_node, np.ones([1], dtype=data_type), + shape=[1], + var_dtype=in_node.dtype()) + data_type = 'float64' if in_node.dtype() \ + == core.VarDesc.VarType.FP64 else 'float32' + _init_var_node(scale_node, np.ones([1], dtype=data_type), self._scope, self._place) - accum_in_node = graph.create_persistable_node( - name=unique_name.generate('scale_accum@'), - var_type=core.VarDesc.VarType.LOD_TENSOR, - var_dtype=in_node.dtype(), - shape=[1]) - _init_var_node(accum_in_node, np.ones([1], dtype=data_type), - self._scope, self._place) - state_out_node = graph.create_var_node_from_desc( - state_in_node.var()) - accum_out_node = graph.create_var_node_from_desc( - accum_in_node.var()) - - ins['InState'] = state_in_node - ins['InAccum'] = accum_in_node - outs['OutState'] = state_out_node - outs['OutAccum'] = accum_out_node - - attrs = { - 'moving_rate': self._moving_rate, - 'is_test': self._is_test, - 'op_role': core.op_proto_and_checker_maker.OpRole.Forward - } - scale_op_node = graph.create_op_node( - op_type='moving_average_abs_max_scale', - attrs=attrs, - inputs=ins, - outputs=outs) - graph.link_to(in_node, scale_op_node) - graph.link_to(scale_op_node, scale_node) - if not self._is_test: - graph.link_to(state_in_node, scale_op_node) - graph.link_to(accum_in_node, scale_op_node) - graph.link_to(scale_op_node, state_out_node) - graph.link_to(scale_op_node, accum_out_node) + ins = {'X': in_node} + outs = {'OutScale': scale_node} + if not self._is_test: + state_in_node = graph.create_persistable_node( + name=unique_name.generate('scale_state@'), + var_type=core.VarDesc.VarType.LOD_TENSOR, + var_dtype=in_node.dtype(), + shape=[1]) + _init_var_node(state_in_node, + np.ones([1], dtype=data_type), + self._scope, self._place) + accum_in_node = graph.create_persistable_node( + name=unique_name.generate('scale_accum@'), + var_type=core.VarDesc.VarType.LOD_TENSOR, + var_dtype=in_node.dtype(), + shape=[1]) + _init_var_node(accum_in_node, + np.ones([1], dtype=data_type), + self._scope, self._place) + state_out_node = graph.create_var_node_from_desc( + state_in_node.var()) + accum_out_node = graph.create_var_node_from_desc( + accum_in_node.var()) + + ins['InState'] = state_in_node + ins['InAccum'] = accum_in_node + outs['OutState'] = state_out_node + outs['OutAccum'] = accum_out_node + + attrs = { + 'moving_rate': self._moving_rate, + 'is_test': self._is_test, + 'op_role': + core.op_proto_and_checker_maker.OpRole.Forward + } + scale_op_node = graph.create_op_node( + op_type='moving_average_abs_max_scale', + attrs=attrs, + inputs=ins, + outputs=outs) + graph.link_to(in_node, scale_op_node) + graph.link_to(scale_op_node, scale_node) + if not self._is_test: + graph.link_to(state_in_node, scale_op_node) + graph.link_to(accum_in_node, scale_op_node) + graph.link_to(scale_op_node, state_out_node) + graph.link_to(scale_op_node, accum_out_node) + t.update() return graph def _scale_name(self, var_name): """ Return the scale name for the var named `var_name`. """ - return "%s.scale" % (var_name) + return "%s@scale" % (var_name) class OutScaleForInferencePass(object): @@ -1544,7 +1561,7 @@ class OutScaleForInferencePass(object): """ Return the scale name for the var named `var_name`. """ - return "%s.scale" % (var_name) + return "%s@scale" % (var_name) class AddQuantDequantPass(object): @@ -1624,36 +1641,43 @@ class AddQuantDequantPass(object): # Forward stage, insert quant_dequant op all_op_nodes = graph.all_op_nodes() - for op_node in all_op_nodes: - if op_node.name() in self._quantizable_op_type: - is_skip = False - if isinstance(self._skip_pattern, list): - is_skip = op_node.op().has_attr("op_namescope") and \ - any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern) - elif isinstance(self._skip_pattern, str): - is_skip = op_node.op().has_attr("op_namescope") and \ - op_node.op().attr("op_namescope").find(self._skip_pattern) != -1 - is_quantized = op_node.op().has_attr("quantization_type") and \ - op_node.op().attr("quantization_type") == "qat_with_weight" - if is_skip or is_quantized or \ - (not _is_input_all_not_persistable(graph, op_node)): - continue + with tqdm(total=len(all_op_nodes), + bar_format= + 'Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: + for op_node in all_op_nodes: + if op_node.name() in self._quantizable_op_type: + is_skip = False + if isinstance(self._skip_pattern, list): + is_skip = op_node.op().has_attr("op_namescope") and \ + any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern) + elif isinstance(self._skip_pattern, str): + is_skip = op_node.op().has_attr("op_namescope") and \ + op_node.op().attr("op_namescope").find(self._skip_pattern) != -1 + is_quantized = op_node.op().has_attr("quantization_type") and \ + op_node.op().attr("quantization_type") == "qat_with_weight" + if is_skip or is_quantized or \ + (not _is_input_all_not_persistable(graph, op_node)): + continue - op_node.op()._set_attr("quantization_type", - "qat_without_weight") - op_node.op()._set_attr("activation_bits", self._quant_bits) - op_node.op()._set_attr("with_quant_attr", True) - arg_names = utils._get_op_input_var_names(op_node) - for arg_name in arg_names: - in_node = graph._find_node_by_name(op_node.inputs, arg_name) - if arg_name in dequantized_vars_map: - quant_var_node = dequantized_vars_map[arg_name] - else: - quant_var_node, _ = \ - self._inser_quant_dequant_moving_average_abs_max_op( - graph, in_node, self._quant_bits) - dequantized_vars_map[arg_name] = quant_var_node - graph.update_input_link(in_node, quant_var_node, op_node) + op_node.op()._set_attr("quantization_type", + "qat_without_weight") + op_node.op()._set_attr("activation_bits", self._quant_bits) + op_node.op()._set_attr("with_quant_attr", True) + arg_names = utils._get_op_input_var_names(op_node) + for arg_name in arg_names: + in_node = graph._find_node_by_name( + op_node.inputs, arg_name) + if arg_name in dequantized_vars_map: + quant_var_node = dequantized_vars_map[arg_name] + else: + quant_var_node, _ = \ + self._inser_quant_dequant_moving_average_abs_max_op( + graph, in_node, self._quant_bits) + dequantized_vars_map[arg_name] = quant_var_node + graph.update_input_link(in_node, quant_var_node, + op_node) + t.update() # Backward stage, update input link for op_node in all_op_nodes: @@ -2204,10 +2228,16 @@ class QuantizationTransformPassV2(object): graph.out_node_mapping_table = dict() # The process of _transform_forward and _transform_backward is needed in two for loops. # The loop for transforming the forward graph: - for op in ops: - if op.name() in self._quantizable_ops: - if not self._is_skip_quant(graph, op) and self._has_weight(op): - self._transform_forward(graph, op) + with tqdm(total=len(ops), + bar_format= + 'Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: + for op in ops: + if op.name() in self._quantizable_ops: + if not self._is_skip_quant(graph, + op) and self._has_weight(op): + self._transform_forward(graph, op) + t.update() # The loop for renaming the inputs of backward op. for op in ops: if op.name() in self._quantizable_grad_ops and self._has_weight(op): @@ -2310,43 +2340,50 @@ class AddQuantDequantPassV2(object): # Forward stage, insert quant_dequant op all_op_nodes = graph.all_op_nodes() - for op_node in all_op_nodes: - if op_node.name() in self._quantizable_op_type: - is_skip = False - if isinstance(self._skip_pattern, list): - is_skip = op_node.op().has_attr("op_namescope") and \ - any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern) - elif isinstance(self._skip_pattern, str): - is_skip = op_node.op().has_attr("op_namescope") and \ - op_node.op().attr("op_namescope").find(self._skip_pattern) != -1 - is_quantized = op_node.op().has_attr("quantization_type") and \ - op_node.op().attr("quantization_type") == "qat_with_weight" - if is_skip or is_quantized: - continue - - op_node.op()._set_attr("quantization_type", - "qat_without_weight") - arg_names = utils._get_op_input_var_names(op_node) - for arg_name in arg_names: - in_node = graph._find_node_by_name(op_node.inputs, arg_name) - if in_node.persistable(): + with tqdm(total=len(all_op_nodes), + bar_format= + 'Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}', + ncols=80) as t: + for op_node in all_op_nodes: + if op_node.name() in self._quantizable_op_type: + is_skip = False + if isinstance(self._skip_pattern, list): + is_skip = op_node.op().has_attr("op_namescope") and \ + any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern) + elif isinstance(self._skip_pattern, str): + is_skip = op_node.op().has_attr("op_namescope") and \ + op_node.op().attr("op_namescope").find(self._skip_pattern) != -1 + is_quantized = op_node.op().has_attr("quantization_type") and \ + op_node.op().attr("quantization_type") == "qat_with_weight" + if is_skip or is_quantized: continue - if arg_name in dequantized_vars_map: - dequant_var_node = dequantized_vars_map[arg_name] - else: - insert_quant_pass = InsertQuantizeLinear( - self._place, - self._scope, - quant_bits=self._quant_bits, - quant_axis=-1, - channel_wise=False, - is_test=self._is_test) - quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op( - graph, in_node) - dequant_var_node = insert_quant_pass.insert_dequant_op( - graph, quant_var_node, scale_var_node) - dequantized_vars_map[arg_name] = dequant_var_node - graph.update_input_link(in_node, dequant_var_node, op_node) + + op_node.op()._set_attr("quantization_type", + "qat_without_weight") + arg_names = utils._get_op_input_var_names(op_node) + for arg_name in arg_names: + in_node = graph._find_node_by_name( + op_node.inputs, arg_name) + if in_node.persistable(): + continue + if arg_name in dequantized_vars_map: + dequant_var_node = dequantized_vars_map[arg_name] + else: + insert_quant_pass = InsertQuantizeLinear( + self._place, + self._scope, + quant_bits=self._quant_bits, + quant_axis=-1, + channel_wise=False, + is_test=self._is_test) + quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op( + graph, in_node) + dequant_var_node = insert_quant_pass.insert_dequant_op( + graph, quant_var_node, scale_var_node) + dequantized_vars_map[arg_name] = dequant_var_node + graph.update_input_link(in_node, dequant_var_node, + op_node) + t.update() # Backward stage, update input link for op_node in all_op_nodes: diff --git a/python/paddle/fluid/contrib/slim/quantization/utils.py b/python/paddle/fluid/contrib/slim/quantization/utils.py index 608844dd55..b9c304df5b 100644 --- a/python/paddle/fluid/contrib/slim/quantization/utils.py +++ b/python/paddle/fluid/contrib/slim/quantization/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import numpy as np from ....framework import IrNode from ....framework import Operator @@ -52,7 +53,6 @@ _act_supported_quantizable_op_type = [ "leaky_relu", "tanh", "swish", - "scale", "transpose", "transpose2", "sigmoid", @@ -162,7 +162,6 @@ _op_real_in_out_name = { "sigmoid": [["X"], ["Out"]], "elementwise_mul": [["X", "Y"], ["Out"]], "elementwise_pow": [["X", "Y"], ["Out"]], - "scale": [["X"], ["Out"]], "hard_swish": [["X"], ["Out"]], "hard_sigmoid": [["X"], ["Out"]], "gru": [["Input", "Weight"], ["Hidden"]], @@ -414,3 +413,27 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor): cos_sim = np.inner(orig_tensor.flatten(), qdq_tensor.flatten()) \ / (np.linalg.norm(orig_tensor.flatten()) * np.linalg.norm(qdq_tensor.flatten())) return cos_sim + + +class tqdm(object): + + def __init__(self, total, bar_format='Loading|{bar}', ncols=80): + self.total = total + self.bar_format = bar_format + self.ncols = ncols + self.n = 0 + + def update(self, n=1): + self.n += n + a = "=" * round((self.n / self.total) * self.ncols) + b = " " * (self.ncols - len(a)) + prefix = self.bar_format.split('|')[0] + sys.stderr.write("\r{}|{}=>{}| {}/{}".format(prefix, a, b, self.n, + self.total)) + sys.stderr.flush() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.stderr.write('\n') -- GitLab