未验证 提交 abb0b2d6 编写于 作者: G Guanghua Yu 提交者: GitHub

[cherry-pick]Add progress bar and speed up Quantization Pass (#43454)

* Add progress bar and speed up Quantization Pass

* fix typo
上级 7e940b84
......@@ -17,6 +17,10 @@ import re
import logging
import numpy as np
import shutil
try:
from tqdm import tqdm
except:
from .utils import tqdm
from inspect import isgeneratorfunction
from .... import io
from .... import core
......@@ -357,8 +361,11 @@ class PostTrainingQuantization(object):
self._set_activation_persistable()
if self._algo in ["KL", "hist"]:
_logger.info("Preparation stage ...")
batch_id = 0
with tqdm(
total=self._batch_nums,
bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
......@@ -366,16 +373,17 @@ class PostTrainingQuantization(object):
return_numpy=False,
scope=self._scope)
self._collect_activation_abs_min_max()
if batch_id % 5 == 0:
_logger.info("Run batch: " + str(batch_id))
batch_id += 1
t.update()
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("Finish preparation stage, all batch:" + str(batch_id))
self._init_sampling_act_histogram()
_logger.info("Sampling stage ...")
batch_id = 0
with tqdm(
total=self._batch_nums,
bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
......@@ -383,12 +391,10 @@ class PostTrainingQuantization(object):
return_numpy=False,
scope=self._scope)
self._sampling()
if batch_id % 5 == 0:
_logger.info("Run batch: " + str(batch_id))
batch_id += 1
t.update()
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("Finish sampling stage, all batch: " + str(batch_id))
if self._algo == 'avg':
for var_name in self._quantized_act_var_name:
......@@ -823,8 +829,9 @@ class PostTrainingQuantization(object):
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
if var_name not in self._sampling_act_abs_min_max:
self._sampling_act_abs_min_max[
var_name] = [min_value, max_value]
self._sampling_act_abs_min_max[var_name] = [
min_value, max_value
]
else:
if min_value < self._sampling_act_abs_min_max[var_name][0]:
self._sampling_act_abs_min_max[var_name][0] = min_value
......
......@@ -14,6 +14,10 @@
import collections
import numpy as np
try:
from tqdm import tqdm
except:
from .utils import tqdm
from ..... import compat as cpt
from .... import core
from ....framework import IrGraph
......@@ -372,10 +376,15 @@ class QuantizationTransformPass(object):
graph.out_node_mapping_table = dict()
# The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph:
with tqdm(
total=len(ops),
bar_format='Adding quant op for weight:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for op in ops:
if op.name() in self._quantizable_ops:
if not self._is_skip_quant(graph, op) and _has_weight(op):
_transform_forward(graph, op)
t.update()
# The loop for renaming the inputs of backward op.
for op in ops:
if op.name() in self._quantizable_grad_ops and _has_weight(op):
......@@ -1427,9 +1436,14 @@ class OutScaleForTrainingPass(object):
for op in graph.all_op_nodes():
if op.name() in self._teller_set:
target_ops.append(op)
with tqdm(
total=len(target_ops),
bar_format='Adding OutScale op:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for op in target_ops:
for output_var_name in utils._get_op_output_var_names(op):
in_node = graph._find_node_by_name(op.outputs, output_var_name)
in_node = graph._find_node_by_name(op.outputs,
output_var_name)
if in_node.dtype() not in \
[core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
continue
......@@ -1485,7 +1499,8 @@ class OutScaleForTrainingPass(object):
attrs = {
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
'op_role':
core.op_proto_and_checker_maker.OpRole.Forward
}
scale_op_node = graph.create_op_node(
op_type='moving_average_abs_max_scale',
......@@ -1499,13 +1514,14 @@ class OutScaleForTrainingPass(object):
graph.link_to(accum_in_node, scale_op_node)
graph.link_to(scale_op_node, state_out_node)
graph.link_to(scale_op_node, accum_out_node)
t.update()
return graph
def _scale_name(self, var_name):
"""
Return the scale name for the var named `var_name`.
"""
return "%s.scale" % (var_name)
return "%s@scale" % (var_name)
class OutScaleForInferencePass(object):
......@@ -1564,7 +1580,7 @@ class OutScaleForInferencePass(object):
"""
Return the scale name for the var named `var_name`.
"""
return "%s.scale" % (var_name)
return "%s@scale" % (var_name)
class AddQuantDequantPass(object):
......@@ -1644,6 +1660,10 @@ class AddQuantDequantPass(object):
# Forward stage, insert quant_dequant op
all_op_nodes = graph.all_op_nodes()
with tqdm(
total=len(all_op_nodes),
bar_format='Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for op_node in all_op_nodes:
if op_node.name() in self._quantizable_op_type:
is_skip = False
......@@ -1665,7 +1685,8 @@ class AddQuantDequantPass(object):
op_node.op()._set_attr("with_quant_attr", True)
arg_names = utils._get_op_input_var_names(op_node)
for arg_name in arg_names:
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
in_node = graph._find_node_by_name(op_node.inputs,
arg_name)
if arg_name in dequantized_vars_map:
quant_var_node = dequantized_vars_map[arg_name]
else:
......@@ -1673,7 +1694,9 @@ class AddQuantDequantPass(object):
self._inser_quant_dequant_moving_average_abs_max_op(
graph, in_node, self._quant_bits)
dequantized_vars_map[arg_name] = quant_var_node
graph.update_input_link(in_node, quant_var_node, op_node)
graph.update_input_link(in_node, quant_var_node,
op_node)
t.update()
# Backward stage, update input link
for op_node in all_op_nodes:
......@@ -2240,10 +2263,16 @@ class QuantizationTransformPassV2(object):
graph.out_node_mapping_table = dict()
# The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph:
with tqdm(
total=len(ops),
bar_format='Adding quant op for weight:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for op in ops:
if op.name() in self._quantizable_ops:
if not self._is_skip_quant(graph, op) and self._has_weight(op):
if not self._is_skip_quant(graph,
op) and self._has_weight(op):
self._transform_forward(graph, op)
t.update()
# The loop for renaming the inputs of backward op.
for op in ops:
if op.name() in self._quantizable_grad_ops and self._has_weight(op):
......@@ -2346,6 +2375,10 @@ class AddQuantDequantPassV2(object):
# Forward stage, insert quant_dequant op
all_op_nodes = graph.all_op_nodes()
with tqdm(
total=len(all_op_nodes),
bar_format='Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for op_node in all_op_nodes:
if op_node.name() in self._quantizable_op_type:
is_skip = False
......@@ -2364,7 +2397,8 @@ class AddQuantDequantPassV2(object):
"qat_without_weight")
arg_names = utils._get_op_input_var_names(op_node)
for arg_name in arg_names:
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
in_node = graph._find_node_by_name(op_node.inputs,
arg_name)
if in_node.persistable():
continue
if arg_name in dequantized_vars_map:
......@@ -2382,7 +2416,9 @@ class AddQuantDequantPassV2(object):
dequant_var_node = insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node)
dequantized_vars_map[arg_name] = dequant_var_node
graph.update_input_link(in_node, dequant_var_node, op_node)
graph.update_input_link(in_node, dequant_var_node,
op_node)
t.update()
# Backward stage, update input link
for op_node in all_op_nodes:
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
from ....framework import IrNode
from ....framework import Operator
......@@ -52,7 +53,6 @@ _act_supported_quantizable_op_type = [
"leaky_relu",
"tanh",
"swish",
"scale",
"transpose",
"transpose2",
"sigmoid",
......@@ -162,7 +162,6 @@ _op_real_in_out_name = {
"sigmoid": [["X"], ["Out"]],
"elementwise_mul": [["X", "Y"], ["Out"]],
"elementwise_pow": [["X", "Y"], ["Out"]],
"scale": [["X"], ["Out"]],
"hard_swish": [["X"], ["Out"]],
"hard_sigmoid": [["X"], ["Out"]],
"gru": [["Input", "Weight"], ["Hidden"]],
......@@ -414,3 +413,26 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor):
cos_sim = np.inner(orig_tensor.flatten(), qdq_tensor.flatten()) \
/ (np.linalg.norm(orig_tensor.flatten()) * np.linalg.norm(qdq_tensor.flatten()))
return cos_sim
class tqdm(object):
def __init__(self, total, bar_format='Loading|{bar}', ncols=80):
self.total = total
self.bar_format = bar_format
self.ncols = ncols
self.n = 0
def update(self, n=1):
self.n += n
a = "=" * round((self.n / self.total) * self.ncols)
b = " " * (self.ncols - len(a))
prefix = self.bar_format.split('|')[0]
sys.stderr.write("\r{}|{}=>{}| {}/{}".format(prefix, a, b, self.n,
self.total))
sys.stderr.flush()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
sys.stderr.write('\n')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册