Unverified commit abb0b2d6, authored by Guanghua Yu, committed by GitHub

[cherry-pick] Add progress bar and speed up Quantization Pass (#43454)

* Add progress bar and speed up Quantization Pass

* fix typo
Parent 7e940b84
@@ -17,6 +17,10 @@ import re
 import logging
 import numpy as np
 import shutil
+try:
+    from tqdm import tqdm
+except:
+    from .utils import tqdm
 from inspect import isgeneratorfunction
 from .... import io
 from .... import core
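The import hunk above prefers the real tqdm package and falls back to the small stub added to utils.py when tqdm is not installed. A minimal, standalone sketch of the same fallback pattern (the `_SimpleBar` name is illustrative, not part of the patch):

# Standalone sketch of the graceful-fallback import used above.
try:
    from tqdm import tqdm              # real progress bar, if available
except ImportError:
    class _SimpleBar(object):          # hypothetical minimal substitute
        def __init__(self, total, bar_format='Loading|{bar}', ncols=80):
            self.total, self.n = total, 0
        def update(self, n=1):
            self.n += n
        def __enter__(self):
            return self
        def __exit__(self, *exc):
            pass
    tqdm = _SimpleBar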
@@ -357,38 +361,40 @@ class PostTrainingQuantization(object):
         self._set_activation_persistable()
         if self._algo in ["KL", "hist"]:
-            _logger.info("Preparation stage ...")
             batch_id = 0
+            with tqdm(
+                    total=self._batch_nums,
+                    bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
+                    ncols=80) as t:
+                for data in self._data_loader():
+                    self._executor.run(program=self._program,
+                                       feed=data,
+                                       fetch_list=self._fetch_list,
+                                       return_numpy=False,
+                                       scope=self._scope)
+                    self._collect_activation_abs_min_max()
+                    batch_id += 1
+                    t.update()
+                    if self._batch_nums and batch_id >= self._batch_nums:
+                        break
+            self._init_sampling_act_histogram()
+
+        batch_id = 0
+        with tqdm(
+                total=self._batch_nums,
+                bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
+                ncols=80) as t:
             for data in self._data_loader():
                 self._executor.run(program=self._program,
                                    feed=data,
                                    fetch_list=self._fetch_list,
                                    return_numpy=False,
                                    scope=self._scope)
-                self._collect_activation_abs_min_max()
+                self._sampling()
-                if batch_id % 5 == 0:
-                    _logger.info("Run batch: " + str(batch_id))
                 batch_id += 1
+                t.update()
                 if self._batch_nums and batch_id >= self._batch_nums:
                     break
-            _logger.info("Finish preparation stage, all batch:" + str(batch_id))
-            self._init_sampling_act_histogram()
-
-        _logger.info("Sampling stage ...")
-        batch_id = 0
-        for data in self._data_loader():
-            self._executor.run(program=self._program,
-                               feed=data,
-                               fetch_list=self._fetch_list,
-                               return_numpy=False,
-                               scope=self._scope)
-            self._sampling()
-            if batch_id % 5 == 0:
-                _logger.info("Run batch: " + str(batch_id))
-            batch_id += 1
-            if self._batch_nums and batch_id >= self._batch_nums:
-                break
-        _logger.info("Finish sampling stage, all batch: " + str(batch_id))
         if self._algo == 'avg':
             for var_name in self._quantized_act_var_name:
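The hunk above replaces the periodic `_logger.info("Run batch: ...")` messages with a tqdm bar that ticks once per calibration batch. A minimal, self-contained sketch of the same loop shape, with hypothetical `loader` and `run_batch` stand-ins for the executor call and statistics collection:

from tqdm import tqdm  # assumes the tqdm package is installed

def calibrate(loader, run_batch, batch_nums):
    # Mirror of the sampling loop above: one bar tick per batch, and an early
    # stop once `batch_nums` batches have been consumed.
    batch_id = 0
    with tqdm(total=batch_nums,
              bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
              ncols=80) as t:
        for data in loader():
            run_batch(data)  # e.g. executor.run(...) followed by sampling
            batch_id += 1
            t.update()
            if batch_nums and batch_id >= batch_nums:
                break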
@@ -823,8 +829,9 @@ class PostTrainingQuantization(object):
             min_value = float(np.min(var_tensor))
             max_value = float(np.max(var_tensor))
             if var_name not in self._sampling_act_abs_min_max:
-                self._sampling_act_abs_min_max[
-                    var_name] = [min_value, max_value]
+                self._sampling_act_abs_min_max[var_name] = [
+                    min_value, max_value
+                ]
             else:
                 if min_value < self._sampling_act_abs_min_max[var_name][0]:
                     self._sampling_act_abs_min_max[var_name][0] = min_value
...
@@ -14,6 +14,10 @@
 import collections
 import numpy as np
+try:
+    from tqdm import tqdm
+except:
+    from .utils import tqdm
 from ..... import compat as cpt
 from .... import core
 from ....framework import IrGraph
@@ -294,10 +298,10 @@ class QuantizationTransformPass(object):
                 else False
             # if var node is weight and weight_preprocess_func is not None,
             # will insert weight preprocess func
             # to preorocess weight before quantization
             # if var node is activation and act_preprocess_func is not None,
             # will insert activation preprocess func
             # to preorocess activation before quantization
             if is_weight and self._weight_preprocess_func is not None:
                 var_node = self._insert_func(
@@ -372,10 +376,15 @@ class QuantizationTransformPass(object):
         graph.out_node_mapping_table = dict()
         # The process of _transform_forward and _transform_backward is needed in two for loops.
         # The loop for transforming the forward graph:
-        for op in ops:
-            if op.name() in self._quantizable_ops:
-                if not self._is_skip_quant(graph, op) and _has_weight(op):
-                    _transform_forward(graph, op)
+        with tqdm(
+                total=len(ops),
+                bar_format='Adding quant op for weight:|{bar}| {n_fmt}/{total_fmt}',
+                ncols=80) as t:
+            for op in ops:
+                if op.name() in self._quantizable_ops:
+                    if not self._is_skip_quant(graph, op) and _has_weight(op):
+                        _transform_forward(graph, op)
+                t.update()
         # The loop for renaming the inputs of backward op.
         for op in ops:
             if op.name() in self._quantizable_grad_ops and _has_weight(op):
@@ -1427,85 +1436,92 @@ class OutScaleForTrainingPass(object):
         for op in graph.all_op_nodes():
             if op.name() in self._teller_set:
                 target_ops.append(op)
-        for op in target_ops:
-            for output_var_name in utils._get_op_output_var_names(op):
-                in_node = graph._find_node_by_name(op.outputs, output_var_name)
-                if in_node.dtype() not in \
-                    [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
-                    continue
+        with tqdm(
+                total=len(target_ops),
+                bar_format='Adding OutScale op:|{bar}| {n_fmt}/{total_fmt}',
+                ncols=80) as t:
+            for op in target_ops:
+                for output_var_name in utils._get_op_output_var_names(op):
+                    in_node = graph._find_node_by_name(op.outputs,
+                                                       output_var_name)
+                    if in_node.dtype() not in \
+                        [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
+                        continue
                     scale_node = graph.create_persistable_node(
                         name=self._scale_name(in_node.name()),
                         var_type=core.VarDesc.VarType.LOD_TENSOR,
                         shape=[1],
                         var_dtype=in_node.dtype())
                     data_type = 'float64' if in_node.dtype() \
                         == core.VarDesc.VarType.FP64 else 'float32'
                     _init_var_node(
                         scale_node,
                         np.ones(
                             [1], dtype=data_type),
                         self._scope,
                         self._place)
                     ins = {'X': in_node}
                     outs = {'OutScale': scale_node}
                     if not self._is_test:
                         state_in_node = graph.create_persistable_node(
                             name=unique_name.generate('scale_state@'),
                             var_type=core.VarDesc.VarType.LOD_TENSOR,
                             var_dtype=in_node.dtype(),
                             shape=[1])
                         _init_var_node(
                             state_in_node,
                             np.ones(
                                 [1], dtype=data_type),
                             self._scope,
                             self._place)
                         accum_in_node = graph.create_persistable_node(
                             name=unique_name.generate('scale_accum@'),
                             var_type=core.VarDesc.VarType.LOD_TENSOR,
                             var_dtype=in_node.dtype(),
                             shape=[1])
                         _init_var_node(
                             accum_in_node,
                             np.ones(
                                 [1], dtype=data_type),
                             self._scope,
                             self._place)
                         state_out_node = graph.create_var_node_from_desc(
                             state_in_node.var())
                         accum_out_node = graph.create_var_node_from_desc(
                             accum_in_node.var())
                         ins['InState'] = state_in_node
                         ins['InAccum'] = accum_in_node
                         outs['OutState'] = state_out_node
                         outs['OutAccum'] = accum_out_node
-                attrs = {
-                    'moving_rate': self._moving_rate,
-                    'is_test': self._is_test,
-                    'op_role': core.op_proto_and_checker_maker.OpRole.Forward
-                }
+                    attrs = {
+                        'moving_rate': self._moving_rate,
+                        'is_test': self._is_test,
+                        'op_role':
+                        core.op_proto_and_checker_maker.OpRole.Forward
+                    }
                     scale_op_node = graph.create_op_node(
                         op_type='moving_average_abs_max_scale',
                         attrs=attrs,
                         inputs=ins,
                         outputs=outs)
                     graph.link_to(in_node, scale_op_node)
                     graph.link_to(scale_op_node, scale_node)
                     if not self._is_test:
                         graph.link_to(state_in_node, scale_op_node)
                         graph.link_to(accum_in_node, scale_op_node)
                         graph.link_to(scale_op_node, state_out_node)
                         graph.link_to(scale_op_node, accum_out_node)
+                t.update()
         return graph

     def _scale_name(self, var_name):
         """
         Return the scale name for the var named `var_name`.
         """
-        return "%s.scale" % (var_name)
+        return "%s@scale" % (var_name)
@@ -1564,7 +1580,7 @@ class OutScaleForInferencePass(object):
         """
         Return the scale name for the var named `var_name`.
         """
-        return "%s.scale" % (var_name)
+        return "%s@scale" % (var_name)
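Both `_scale_name` hunks switch the generated scale-variable suffix from a dot to an at-sign. A quick illustrative check of the new convention (the variable name below is made up):

def _scale_name(var_name):
    # New naming from the hunks above: "<var>@scale" instead of "<var>.scale".
    return "%s@scale" % (var_name)

assert _scale_name("conv2d_0.tmp_0") == "conv2d_0.tmp_0@scale"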
@@ -1644,36 +1660,43 @@ class AddQuantDequantPass(object):
         # Forward stage, insert quant_dequant op
         all_op_nodes = graph.all_op_nodes()
-        for op_node in all_op_nodes:
-            if op_node.name() in self._quantizable_op_type:
-                is_skip = False
+        with tqdm(
+                total=len(all_op_nodes),
+                bar_format='Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
+                ncols=80) as t:
+            for op_node in all_op_nodes:
+                if op_node.name() in self._quantizable_op_type:
+                    is_skip = False
                     if isinstance(self._skip_pattern, list):
                         is_skip = op_node.op().has_attr("op_namescope") and \
                             any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
                     elif isinstance(self._skip_pattern, str):
                         is_skip = op_node.op().has_attr("op_namescope") and \
                             op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
                     is_quantized = op_node.op().has_attr("quantization_type") and \
                         op_node.op().attr("quantization_type") == "qat_with_weight"
                     if is_skip or is_quantized or \
                         (not _is_input_all_not_persistable(graph, op_node)):
                         continue
                     op_node.op()._set_attr("quantization_type",
                                            "qat_without_weight")
                     op_node.op()._set_attr("activation_bits", self._quant_bits)
                     op_node.op()._set_attr("with_quant_attr", True)
                     arg_names = utils._get_op_input_var_names(op_node)
                     for arg_name in arg_names:
-                    in_node = graph._find_node_by_name(op_node.inputs, arg_name)
+                        in_node = graph._find_node_by_name(op_node.inputs,
+                                                           arg_name)
                         if arg_name in dequantized_vars_map:
                             quant_var_node = dequantized_vars_map[arg_name]
                         else:
                             quant_var_node, _ = \
                                 self._inser_quant_dequant_moving_average_abs_max_op(
                                     graph, in_node, self._quant_bits)
                             dequantized_vars_map[arg_name] = quant_var_node
-                    graph.update_input_link(in_node, quant_var_node, op_node)
+                        graph.update_input_link(in_node, quant_var_node,
+                                                op_node)
+                t.update()
         # Backward stage, update input link
         for op_node in all_op_nodes:
@@ -2122,10 +2145,10 @@ class QuantizationTransformPassV2(object):
                 else False
             # if var node is weight and weight_preprocess_func is not None,
             # will insert weight preprocess func
             # to preorocess weight before quantization
             # if var node is activation and act_preprocess_func is not None,
             # will insert activation preprocess func
             # to preorocess activation before quantization
             if is_weight and self._weight_preprocess_func is not None:
                 var_node = self._insert_func(
@@ -2240,10 +2263,16 @@ class QuantizationTransformPassV2(object):
         graph.out_node_mapping_table = dict()
         # The process of _transform_forward and _transform_backward is needed in two for loops.
         # The loop for transforming the forward graph:
-        for op in ops:
-            if op.name() in self._quantizable_ops:
-                if not self._is_skip_quant(graph, op) and self._has_weight(op):
-                    self._transform_forward(graph, op)
+        with tqdm(
+                total=len(ops),
+                bar_format='Adding quant op for weight:|{bar}| {n_fmt}/{total_fmt}',
+                ncols=80) as t:
+            for op in ops:
+                if op.name() in self._quantizable_ops:
+                    if not self._is_skip_quant(graph,
+                                               op) and self._has_weight(op):
+                        self._transform_forward(graph, op)
+                t.update()
         # The loop for renaming the inputs of backward op.
         for op in ops:
             if op.name() in self._quantizable_grad_ops and self._has_weight(op):
@@ -2346,43 +2375,50 @@ class AddQuantDequantPassV2(object):
         # Forward stage, insert quant_dequant op
         all_op_nodes = graph.all_op_nodes()
-        for op_node in all_op_nodes:
-            if op_node.name() in self._quantizable_op_type:
-                is_skip = False
+        with tqdm(
+                total=len(all_op_nodes),
+                bar_format='Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
+                ncols=80) as t:
+            for op_node in all_op_nodes:
+                if op_node.name() in self._quantizable_op_type:
+                    is_skip = False
                     if isinstance(self._skip_pattern, list):
                         is_skip = op_node.op().has_attr("op_namescope") and \
                             any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
                     elif isinstance(self._skip_pattern, str):
                         is_skip = op_node.op().has_attr("op_namescope") and \
                             op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
                     is_quantized = op_node.op().has_attr("quantization_type") and \
                         op_node.op().attr("quantization_type") == "qat_with_weight"
                     if is_skip or is_quantized:
                         continue
                     op_node.op()._set_attr("quantization_type",
                                            "qat_without_weight")
                     arg_names = utils._get_op_input_var_names(op_node)
                     for arg_name in arg_names:
-                    in_node = graph._find_node_by_name(op_node.inputs, arg_name)
+                        in_node = graph._find_node_by_name(op_node.inputs,
+                                                           arg_name)
                         if in_node.persistable():
                             continue
                         if arg_name in dequantized_vars_map:
                             dequant_var_node = dequantized_vars_map[arg_name]
                         else:
                             insert_quant_pass = InsertQuantizeLinear(
                                 self._place,
                                 self._scope,
                                 quant_bits=self._quant_bits,
                                 quant_axis=-1,
                                 channel_wise=False,
                                 is_test=self._is_test)
                             quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
                                 graph, in_node)
                             dequant_var_node = insert_quant_pass.insert_dequant_op(
                                 graph, quant_var_node, scale_var_node)
                             dequantized_vars_map[arg_name] = dequant_var_node
-                    graph.update_input_link(in_node, dequant_var_node, op_node)
+                        graph.update_input_link(in_node, dequant_var_node,
+                                                op_node)
+                t.update()
         # Backward stage, update input link
         for op_node in all_op_nodes:
...
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
 import numpy as np
 from ....framework import IrNode
 from ....framework import Operator
@@ -52,7 +53,6 @@ _act_supported_quantizable_op_type = [
     "leaky_relu",
     "tanh",
     "swish",
-    "scale",
     "transpose",
     "transpose2",
     "sigmoid",
@@ -162,7 +162,6 @@ _op_real_in_out_name = {
     "sigmoid": [["X"], ["Out"]],
     "elementwise_mul": [["X", "Y"], ["Out"]],
     "elementwise_pow": [["X", "Y"], ["Out"]],
-    "scale": [["X"], ["Out"]],
     "hard_swish": [["X"], ["Out"]],
     "hard_sigmoid": [["X"], ["Out"]],
     "gru": [["Input", "Weight"], ["Hidden"]],
@@ -414,3 +413,26 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor):
     cos_sim = np.inner(orig_tensor.flatten(), qdq_tensor.flatten()) \
               / (np.linalg.norm(orig_tensor.flatten()) * np.linalg.norm(qdq_tensor.flatten()))
     return cos_sim
+
+
+class tqdm(object):
+
+    def __init__(self, total, bar_format='Loading|{bar}', ncols=80):
+        self.total = total
+        self.bar_format = bar_format
+        self.ncols = ncols
+        self.n = 0
+
+    def update(self, n=1):
+        self.n += n
+        a = "=" * round((self.n / self.total) * self.ncols)
+        b = " " * (self.ncols - len(a))
+        prefix = self.bar_format.split('|')[0]
+        sys.stderr.write("\r{}|{}=>{}| {}/{}".format(prefix, a, b, self.n,
+                                                     self.total))
+        sys.stderr.flush()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        sys.stderr.write('\n')
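The fallback `tqdm` added above implements only the subset of the real API that the passes rely on: `total`, `bar_format`, `ncols`, `update()`, and the context-manager protocol. As the code shows, it keeps only the text before the first `|` of `bar_format` as a prefix and draws its own `=>` bar on stderr. A hypothetical usage, outside the patch:

ops = range(200)  # stand-in for the op nodes being processed
with tqdm(total=len(ops),
          bar_format='Adding quant op for weight:|{bar}| {n_fmt}/{total_fmt}',
          ncols=80) as t:
    for _ in ops:
        t.update()  # advances the bar by one op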