Unverified commit abb0b2d6, authored by Guanghua Yu, committed by GitHub

[cherry-pick] Add progress bar and speed up Quantization Pass (#43454)

* Add progress bar and speed up Quantization Pass

* fix typo
Parent 7e940b84
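The whole change follows one idiom: import the real tqdm when it is available, fall back to the small shim added in utils.py otherwise, and replace the per-batch "Run batch: N" log lines with a progress bar that ticks once per batch. Below is a minimal, self-contained sketch of that idiom; the calibrate and run_batch names are placeholders for illustration only and are not part of the Paddle code.

from tqdm import tqdm  # the commit falls back to a local shim in utils.py when tqdm is missing


def calibrate(data_loader, batch_nums, run_batch):
    """Run at most batch_nums calibration batches, one progress tick per batch."""
    batch_id = 0
    with tqdm(total=batch_nums,
              bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
              ncols=80) as t:
        for data in data_loader():
            run_batch(data)  # stands in for executor.run(...) plus sampling
            batch_id += 1
            t.update()  # replaces the old log emitted every fifth batch
            if batch_nums and batch_id >= batch_nums:
                break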
@@ -17,6 +17,10 @@ import re
 import logging
 import numpy as np
 import shutil
+try:
+    from tqdm import tqdm
+except:
+    from .utils import tqdm
 from inspect import isgeneratorfunction
 from .... import io
 from .... import core
@@ -357,8 +361,11 @@ class PostTrainingQuantization(object):
         self._set_activation_persistable()

         if self._algo in ["KL", "hist"]:
-            _logger.info("Preparation stage ...")
             batch_id = 0
+            with tqdm(
+                    total=self._batch_nums,
+                    bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
+                    ncols=80) as t:
                 for data in self._data_loader():
                     self._executor.run(program=self._program,
                                        feed=data,
@@ -366,16 +373,17 @@ class PostTrainingQuantization(object):
                                        return_numpy=False,
                                        scope=self._scope)
                     self._collect_activation_abs_min_max()
-                if batch_id % 5 == 0:
-                    _logger.info("Run batch: " + str(batch_id))
                     batch_id += 1
+                    t.update()
                     if self._batch_nums and batch_id >= self._batch_nums:
                         break
-            _logger.info("Finish preparation stage, all batch:" + str(batch_id))
             self._init_sampling_act_histogram()

-            _logger.info("Sampling stage ...")
             batch_id = 0
+            with tqdm(
+                    total=self._batch_nums,
+                    bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
+                    ncols=80) as t:
                 for data in self._data_loader():
                     self._executor.run(program=self._program,
                                        feed=data,
@@ -383,12 +391,10 @@ class PostTrainingQuantization(object):
                                        return_numpy=False,
                                        scope=self._scope)
                     self._sampling()
-                if batch_id % 5 == 0:
-                    _logger.info("Run batch: " + str(batch_id))
                     batch_id += 1
+                    t.update()
                     if self._batch_nums and batch_id >= self._batch_nums:
                         break
-            _logger.info("Finish sampling stage, all batch: " + str(batch_id))

         if self._algo == 'avg':
             for var_name in self._quantized_act_var_name:
@@ -823,8 +829,9 @@ class PostTrainingQuantization(object):
                 min_value = float(np.min(var_tensor))
                 max_value = float(np.max(var_tensor))
                 if var_name not in self._sampling_act_abs_min_max:
-                    self._sampling_act_abs_min_max[
-                        var_name] = [min_value, max_value]
+                    self._sampling_act_abs_min_max[var_name] = [
+                        min_value, max_value
+                    ]
                 else:
                     if min_value < self._sampling_act_abs_min_max[var_name][0]:
                         self._sampling_act_abs_min_max[var_name][0] = min_value
......
@@ -14,6 +14,10 @@
 import collections
 import numpy as np
+try:
+    from tqdm import tqdm
+except:
+    from .utils import tqdm
 from ..... import compat as cpt
 from .... import core
 from ....framework import IrGraph
@@ -372,10 +376,15 @@ class QuantizationTransformPass(object):
         graph.out_node_mapping_table = dict()
         # The process of _transform_forward and _transform_backward is needed in two for loops.
         # The loop for transforming the forward graph:
+        with tqdm(
+                total=len(ops),
+                bar_format='Adding quant op for weight:|{bar}| {n_fmt}/{total_fmt}',
+                ncols=80) as t:
             for op in ops:
                 if op.name() in self._quantizable_ops:
                     if not self._is_skip_quant(graph, op) and _has_weight(op):
                         _transform_forward(graph, op)
+                t.update()
         # The loop for renaming the inputs of backward op.
         for op in ops:
             if op.name() in self._quantizable_grad_ops and _has_weight(op):
@@ -1427,9 +1436,14 @@ class OutScaleForTrainingPass(object):
         for op in graph.all_op_nodes():
             if op.name() in self._teller_set:
                 target_ops.append(op)
+        with tqdm(
+                total=len(target_ops),
+                bar_format='Adding OutScale op:|{bar}| {n_fmt}/{total_fmt}',
+                ncols=80) as t:
             for op in target_ops:
                 for output_var_name in utils._get_op_output_var_names(op):
-                    in_node = graph._find_node_by_name(op.outputs, output_var_name)
+                    in_node = graph._find_node_by_name(op.outputs,
+                                                       output_var_name)
                     if in_node.dtype() not in \
                         [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
                         continue
@@ -1485,7 +1499,8 @@ class OutScaleForTrainingPass(object):
                 attrs = {
                     'moving_rate': self._moving_rate,
                     'is_test': self._is_test,
-                    'op_role': core.op_proto_and_checker_maker.OpRole.Forward
+                    'op_role':
+                    core.op_proto_and_checker_maker.OpRole.Forward
                 }
                 scale_op_node = graph.create_op_node(
                     op_type='moving_average_abs_max_scale',
@@ -1499,13 +1514,14 @@ class OutScaleForTrainingPass(object):
                 graph.link_to(accum_in_node, scale_op_node)
                 graph.link_to(scale_op_node, state_out_node)
                 graph.link_to(scale_op_node, accum_out_node)
+                t.update()
         return graph

     def _scale_name(self, var_name):
         """
         Return the scale name for the var named `var_name`.
         """
-        return "%s.scale" % (var_name)
+        return "%s@scale" % (var_name)


 class OutScaleForInferencePass(object):
@@ -1564,7 +1580,7 @@ class OutScaleForInferencePass(object):
         """
         Return the scale name for the var named `var_name`.
         """
-        return "%s.scale" % (var_name)
+        return "%s@scale" % (var_name)


 class AddQuantDequantPass(object):
@@ -1644,6 +1660,10 @@ class AddQuantDequantPass(object):
         # Forward stage, insert quant_dequant op
         all_op_nodes = graph.all_op_nodes()
+        with tqdm(
+                total=len(all_op_nodes),
+                bar_format='Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
+                ncols=80) as t:
             for op_node in all_op_nodes:
                 if op_node.name() in self._quantizable_op_type:
                     is_skip = False
@@ -1665,7 +1685,8 @@ class AddQuantDequantPass(object):
                     op_node.op()._set_attr("with_quant_attr", True)
                     arg_names = utils._get_op_input_var_names(op_node)
                     for arg_name in arg_names:
-                        in_node = graph._find_node_by_name(op_node.inputs, arg_name)
+                        in_node = graph._find_node_by_name(op_node.inputs,
+                                                           arg_name)
                         if arg_name in dequantized_vars_map:
                             quant_var_node = dequantized_vars_map[arg_name]
                         else:
@@ -1673,7 +1694,9 @@ class AddQuantDequantPass(object):
                             self._inser_quant_dequant_moving_average_abs_max_op(
                                 graph, in_node, self._quant_bits)
                             dequantized_vars_map[arg_name] = quant_var_node
-                        graph.update_input_link(in_node, quant_var_node, op_node)
+                        graph.update_input_link(in_node, quant_var_node,
+                                                op_node)
+                t.update()

         # Backward stage, update input link
         for op_node in all_op_nodes:
@@ -2240,10 +2263,16 @@ class QuantizationTransformPassV2(object):
         graph.out_node_mapping_table = dict()
         # The process of _transform_forward and _transform_backward is needed in two for loops.
         # The loop for transforming the forward graph:
+        with tqdm(
+                total=len(ops),
+                bar_format='Adding quant op for weight:|{bar}| {n_fmt}/{total_fmt}',
+                ncols=80) as t:
             for op in ops:
                 if op.name() in self._quantizable_ops:
-                    if not self._is_skip_quant(graph, op) and self._has_weight(op):
+                    if not self._is_skip_quant(graph,
+                                               op) and self._has_weight(op):
                         self._transform_forward(graph, op)
+                t.update()
         # The loop for renaming the inputs of backward op.
         for op in ops:
             if op.name() in self._quantizable_grad_ops and self._has_weight(op):
@@ -2346,6 +2375,10 @@ class AddQuantDequantPassV2(object):
         # Forward stage, insert quant_dequant op
         all_op_nodes = graph.all_op_nodes()
+        with tqdm(
+                total=len(all_op_nodes),
+                bar_format='Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
+                ncols=80) as t:
             for op_node in all_op_nodes:
                 if op_node.name() in self._quantizable_op_type:
                     is_skip = False
@@ -2364,7 +2397,8 @@ class AddQuantDequantPassV2(object):
                             "qat_without_weight")
                     arg_names = utils._get_op_input_var_names(op_node)
                     for arg_name in arg_names:
-                        in_node = graph._find_node_by_name(op_node.inputs, arg_name)
+                        in_node = graph._find_node_by_name(op_node.inputs,
+                                                           arg_name)
                         if in_node.persistable():
                             continue
                         if arg_name in dequantized_vars_map:
@@ -2382,7 +2416,9 @@ class AddQuantDequantPassV2(object):
                             dequant_var_node = insert_quant_pass.insert_dequant_op(
                                 graph, quant_var_node, scale_var_node)
                             dequantized_vars_map[arg_name] = dequant_var_node
-                        graph.update_input_link(in_node, dequant_var_node, op_node)
+                        graph.update_input_link(in_node, dequant_var_node,
+                                                op_node)
+                t.update()

         # Backward stage, update input link
         for op_node in all_op_nodes:
......
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import sys
 import numpy as np
 from ....framework import IrNode
 from ....framework import Operator
@@ -52,7 +53,6 @@ _act_supported_quantizable_op_type = [
     "leaky_relu",
     "tanh",
     "swish",
-    "scale",
     "transpose",
     "transpose2",
     "sigmoid",
@@ -162,7 +162,6 @@ _op_real_in_out_name = {
     "sigmoid": [["X"], ["Out"]],
     "elementwise_mul": [["X", "Y"], ["Out"]],
     "elementwise_pow": [["X", "Y"], ["Out"]],
-    "scale": [["X"], ["Out"]],
     "hard_swish": [["X"], ["Out"]],
     "hard_sigmoid": [["X"], ["Out"]],
     "gru": [["Input", "Weight"], ["Hidden"]],
@@ -414,3 +413,26 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor):
     cos_sim = np.inner(orig_tensor.flatten(), qdq_tensor.flatten()) \
               / (np.linalg.norm(orig_tensor.flatten()) * np.linalg.norm(qdq_tensor.flatten()))
     return cos_sim
+
+
+class tqdm(object):
+    def __init__(self, total, bar_format='Loading|{bar}', ncols=80):
+        self.total = total
+        self.bar_format = bar_format
+        self.ncols = ncols
+        self.n = 0
+
+    def update(self, n=1):
+        self.n += n
+        a = "=" * round((self.n / self.total) * self.ncols)
+        b = " " * (self.ncols - len(a))
+        prefix = self.bar_format.split('|')[0]
+        sys.stderr.write("\r{}|{}=>{}| {}/{}".format(prefix, a, b, self.n,
+                                                     self.total))
+        sys.stderr.flush()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        sys.stderr.write('\n')
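The fallback tqdm class added above implements just enough of the real tqdm interface for the passes: a context manager exposing update(), which redraws the bar on stderr using the text before the first "|" of bar_format as the prefix. A rough usage sketch, assuming the class above is in scope; the loop body is only a placeholder for one unit of work:

import time

with tqdm(total=10, bar_format='Sampling stage, Run batch:|{bar}', ncols=40) as t:
    for _ in range(10):
        time.sleep(0.1)  # placeholder for one calibration batch
        t.update()  # rewrites "\rSampling stage, Run batch:|===...| n/10" on stderr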