Unverified · Commit 1c7e35dc authored by Guanghua Yu, committed by GitHub

Add progress bar and speed up Quantization Pass (#43398)

Parent 5fcd8061
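The diff below wraps the post-training calibration loops and the graph-pass op loops in tqdm progress bars (falling back to a small stub in utils.py when tqdm is not installed), replacing the periodic `_logger.info("Run batch: ...")` messages. As a minimal sketch of the pattern this commit applies — `data_loader`, `run_one_batch` and `batch_nums` below are illustrative stand-ins, not identifiers from the patch:

```python
# Sketch only: progress reporting for a calibration loop in the style of this
# commit. `data_loader`, `run_one_batch`, `batch_nums` are hypothetical
# stand-ins for self._data_loader, self._executor.run(...), self._batch_nums.
from tqdm import tqdm  # the real patch falls back to a stub in utils.py


def calibrate(data_loader, run_one_batch, batch_nums):
    batch_id = 0
    with tqdm(total=batch_nums,
              bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
              ncols=80) as t:
        for data in data_loader():
            run_one_batch(data)  # e.g. one executor.run(...) call
            batch_id += 1
            t.update()           # advance the bar by one batch
            if batch_nums and batch_id >= batch_nums:
                break
```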
@@ -17,6 +17,10 @@ import re
 import logging
 import numpy as np
 import shutil
+try:
+    from tqdm import tqdm
+except:
+    from .utils import tqdm
 from inspect import isgeneratorfunction
 from .... import io
 from .... import core
@@ -359,38 +363,41 @@ class PostTrainingQuantization(object):
         self._set_activation_persistable()
         if self._algo in ["KL", "hist"]:
-            _logger.info("Preparation stage ...")
             batch_id = 0
+            with tqdm(
+                    total=self._batch_nums,
+                    bar_format=
+                    'Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
+                    ncols=80) as t:
+                for data in self._data_loader():
+                    self._executor.run(program=self._program,
+                                       feed=data,
+                                       fetch_list=self._fetch_list,
+                                       return_numpy=False,
+                                       scope=self._scope)
+                    self._collect_activation_abs_min_max()
+                    batch_id += 1
+                    t.update()
+                    if self._batch_nums and batch_id >= self._batch_nums:
+                        break
+            self._init_sampling_act_histogram()
+        batch_id = 0
+        with tqdm(total=self._batch_nums,
+                  bar_format=
+                  'Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
+                  ncols=80) as t:
             for data in self._data_loader():
                 self._executor.run(program=self._program,
                                    feed=data,
                                    fetch_list=self._fetch_list,
                                    return_numpy=False,
                                    scope=self._scope)
-                self._collect_activation_abs_min_max()
-                if batch_id % 5 == 0:
-                    _logger.info("Run batch: " + str(batch_id))
+                self._sampling()
                 batch_id += 1
+                t.update()
                 if self._batch_nums and batch_id >= self._batch_nums:
                     break
-            _logger.info("Finish preparation stage, all batch:" + str(batch_id))
-            self._init_sampling_act_histogram()
-        _logger.info("Sampling stage ...")
-        batch_id = 0
-        for data in self._data_loader():
-            self._executor.run(program=self._program,
-                               feed=data,
-                               fetch_list=self._fetch_list,
-                               return_numpy=False,
-                               scope=self._scope)
-            self._sampling()
-            if batch_id % 5 == 0:
-                _logger.info("Run batch: " + str(batch_id))
-            batch_id += 1
-            if self._batch_nums and batch_id >= self._batch_nums:
-                break
-        _logger.info("Finish sampling stage, all batch: " + str(batch_id))
         if self._algo == 'avg':
             for var_name in self._quantized_act_var_name:
...
@@ -14,6 +14,10 @@
 import collections
 import numpy as np
+try:
+    from tqdm import tqdm
+except:
+    from .utils import tqdm
 from ..... import compat as cpt
 from .... import core
 from ....framework import IrGraph
@@ -373,10 +377,15 @@ class QuantizationTransformPass(object):
         graph.out_node_mapping_table = dict()
         # The process of _transform_forward and _transform_backward is needed in two for loops.
         # The loop for transforming the forward graph:
-        for op in ops:
-            if op.name() in self._quantizable_ops:
-                if not self._is_skip_quant(graph, op) and _has_weight(op):
-                    _transform_forward(graph, op)
+        with tqdm(total=len(ops),
+                  bar_format=
+                  'Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}',
+                  ncols=80) as t:
+            for op in ops:
+                if op.name() in self._quantizable_ops:
+                    if not self._is_skip_quant(graph, op) and _has_weight(op):
+                        _transform_forward(graph, op)
+                t.update()
         # The loop for renaming the inputs of backward op.
         for op in ops:
             if op.name() in self._quantizable_grad_ops and _has_weight(op):
@@ -1418,73 +1427,81 @@ class OutScaleForTrainingPass(object):
         for op in graph.all_op_nodes():
             if op.name() in self._teller_set:
                 target_ops.append(op)
-        for op in target_ops:
-            for output_var_name in utils._get_op_output_var_names(op):
-                in_node = graph._find_node_by_name(op.outputs, output_var_name)
-                if in_node.dtype() not in \
-                    [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
-                    continue
-                scale_node = graph.create_persistable_node(
-                    name=self._scale_name(in_node.name()),
-                    var_type=core.VarDesc.VarType.LOD_TENSOR,
-                    shape=[1],
-                    var_dtype=in_node.dtype())
-                data_type = 'float64' if in_node.dtype() \
-                    == core.VarDesc.VarType.FP64 else 'float32'
-                _init_var_node(scale_node, np.ones([1], dtype=data_type),
-                               self._scope, self._place)
-                ins = {'X': in_node}
-                outs = {'OutScale': scale_node}
-                if not self._is_test:
-                    state_in_node = graph.create_persistable_node(
-                        name=unique_name.generate('scale_state@'),
-                        var_type=core.VarDesc.VarType.LOD_TENSOR,
-                        var_dtype=in_node.dtype(),
-                        shape=[1])
-                    _init_var_node(state_in_node, np.ones([1], dtype=data_type),
-                                   self._scope, self._place)
-                    accum_in_node = graph.create_persistable_node(
-                        name=unique_name.generate('scale_accum@'),
-                        var_type=core.VarDesc.VarType.LOD_TENSOR,
-                        var_dtype=in_node.dtype(),
-                        shape=[1])
-                    _init_var_node(accum_in_node, np.ones([1], dtype=data_type),
-                                   self._scope, self._place)
-                    state_out_node = graph.create_var_node_from_desc(
-                        state_in_node.var())
-                    accum_out_node = graph.create_var_node_from_desc(
-                        accum_in_node.var())
-                    ins['InState'] = state_in_node
-                    ins['InAccum'] = accum_in_node
-                    outs['OutState'] = state_out_node
-                    outs['OutAccum'] = accum_out_node
-                attrs = {
-                    'moving_rate': self._moving_rate,
-                    'is_test': self._is_test,
-                    'op_role': core.op_proto_and_checker_maker.OpRole.Forward
-                }
-                scale_op_node = graph.create_op_node(
-                    op_type='moving_average_abs_max_scale',
-                    attrs=attrs,
-                    inputs=ins,
-                    outputs=outs)
-                graph.link_to(in_node, scale_op_node)
-                graph.link_to(scale_op_node, scale_node)
-                if not self._is_test:
-                    graph.link_to(state_in_node, scale_op_node)
-                    graph.link_to(accum_in_node, scale_op_node)
-                    graph.link_to(scale_op_node, state_out_node)
-                    graph.link_to(scale_op_node, accum_out_node)
+        with tqdm(total=len(target_ops),
+                  bar_format='Adding OutScale op:|{bar}| {n_fmt}/{total_fmt}',
+                  ncols=80) as t:
+            for op in target_ops:
+                for output_var_name in utils._get_op_output_var_names(op):
+                    in_node = graph._find_node_by_name(op.outputs,
+                                                       output_var_name)
+                    if in_node.dtype() not in \
+                        [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
+                        continue
+                    scale_node = graph.create_persistable_node(
+                        name=self._scale_name(in_node.name()),
+                        var_type=core.VarDesc.VarType.LOD_TENSOR,
+                        shape=[1],
+                        var_dtype=in_node.dtype())
+                    data_type = 'float64' if in_node.dtype() \
+                        == core.VarDesc.VarType.FP64 else 'float32'
+                    _init_var_node(scale_node, np.ones([1], dtype=data_type),
+                                   self._scope, self._place)
+                    ins = {'X': in_node}
+                    outs = {'OutScale': scale_node}
+                    if not self._is_test:
+                        state_in_node = graph.create_persistable_node(
+                            name=unique_name.generate('scale_state@'),
+                            var_type=core.VarDesc.VarType.LOD_TENSOR,
+                            var_dtype=in_node.dtype(),
+                            shape=[1])
+                        _init_var_node(state_in_node,
+                                       np.ones([1], dtype=data_type),
+                                       self._scope, self._place)
+                        accum_in_node = graph.create_persistable_node(
+                            name=unique_name.generate('scale_accum@'),
+                            var_type=core.VarDesc.VarType.LOD_TENSOR,
+                            var_dtype=in_node.dtype(),
+                            shape=[1])
+                        _init_var_node(accum_in_node,
+                                       np.ones([1], dtype=data_type),
+                                       self._scope, self._place)
+                        state_out_node = graph.create_var_node_from_desc(
+                            state_in_node.var())
+                        accum_out_node = graph.create_var_node_from_desc(
+                            accum_in_node.var())
+                        ins['InState'] = state_in_node
+                        ins['InAccum'] = accum_in_node
+                        outs['OutState'] = state_out_node
+                        outs['OutAccum'] = accum_out_node
+                    attrs = {
+                        'moving_rate': self._moving_rate,
+                        'is_test': self._is_test,
+                        'op_role':
+                        core.op_proto_and_checker_maker.OpRole.Forward
+                    }
+                    scale_op_node = graph.create_op_node(
+                        op_type='moving_average_abs_max_scale',
+                        attrs=attrs,
+                        inputs=ins,
+                        outputs=outs)
+                    graph.link_to(in_node, scale_op_node)
+                    graph.link_to(scale_op_node, scale_node)
+                    if not self._is_test:
+                        graph.link_to(state_in_node, scale_op_node)
+                        graph.link_to(accum_in_node, scale_op_node)
+                        graph.link_to(scale_op_node, state_out_node)
+                        graph.link_to(scale_op_node, accum_out_node)
+                t.update()
         return graph
 
     def _scale_name(self, var_name):
         """
         Return the scale name for the var named `var_name`.
         """
-        return "%s.scale" % (var_name)
+        return "%s@scale" % (var_name)
 
 
 class OutScaleForInferencePass(object):
@@ -1544,7 +1561,7 @@ class OutScaleForInferencePass(object):
         """
         Return the scale name for the var named `var_name`.
         """
-        return "%s.scale" % (var_name)
+        return "%s@scale" % (var_name)
 
 
 class AddQuantDequantPass(object):
@@ -1624,36 +1641,43 @@ class AddQuantDequantPass(object):
         # Forward stage, insert quant_dequant op
         all_op_nodes = graph.all_op_nodes()
-        for op_node in all_op_nodes:
-            if op_node.name() in self._quantizable_op_type:
-                is_skip = False
-                if isinstance(self._skip_pattern, list):
-                    is_skip = op_node.op().has_attr("op_namescope") and \
-                        any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
-                elif isinstance(self._skip_pattern, str):
-                    is_skip = op_node.op().has_attr("op_namescope") and \
-                        op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
-                is_quantized = op_node.op().has_attr("quantization_type") and \
-                    op_node.op().attr("quantization_type") == "qat_with_weight"
-                if is_skip or is_quantized or \
-                    (not _is_input_all_not_persistable(graph, op_node)):
-                    continue
-                op_node.op()._set_attr("quantization_type",
-                                       "qat_without_weight")
-                op_node.op()._set_attr("activation_bits", self._quant_bits)
-                op_node.op()._set_attr("with_quant_attr", True)
-                arg_names = utils._get_op_input_var_names(op_node)
-                for arg_name in arg_names:
-                    in_node = graph._find_node_by_name(op_node.inputs, arg_name)
-                    if arg_name in dequantized_vars_map:
-                        quant_var_node = dequantized_vars_map[arg_name]
-                    else:
-                        quant_var_node, _ = \
-                            self._inser_quant_dequant_moving_average_abs_max_op(
-                                graph, in_node, self._quant_bits)
-                        dequantized_vars_map[arg_name] = quant_var_node
-                    graph.update_input_link(in_node, quant_var_node, op_node)
+        with tqdm(total=len(all_op_nodes),
+                  bar_format=
+                  'Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
+                  ncols=80) as t:
+            for op_node in all_op_nodes:
+                if op_node.name() in self._quantizable_op_type:
+                    is_skip = False
+                    if isinstance(self._skip_pattern, list):
+                        is_skip = op_node.op().has_attr("op_namescope") and \
+                            any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
+                    elif isinstance(self._skip_pattern, str):
+                        is_skip = op_node.op().has_attr("op_namescope") and \
+                            op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
+                    is_quantized = op_node.op().has_attr("quantization_type") and \
+                        op_node.op().attr("quantization_type") == "qat_with_weight"
+                    if is_skip or is_quantized or \
+                        (not _is_input_all_not_persistable(graph, op_node)):
+                        continue
+                    op_node.op()._set_attr("quantization_type",
+                                           "qat_without_weight")
+                    op_node.op()._set_attr("activation_bits", self._quant_bits)
+                    op_node.op()._set_attr("with_quant_attr", True)
+                    arg_names = utils._get_op_input_var_names(op_node)
+                    for arg_name in arg_names:
+                        in_node = graph._find_node_by_name(
+                            op_node.inputs, arg_name)
+                        if arg_name in dequantized_vars_map:
+                            quant_var_node = dequantized_vars_map[arg_name]
+                        else:
+                            quant_var_node, _ = \
+                                self._inser_quant_dequant_moving_average_abs_max_op(
+                                    graph, in_node, self._quant_bits)
+                            dequantized_vars_map[arg_name] = quant_var_node
+                        graph.update_input_link(in_node, quant_var_node,
+                                                op_node)
+                t.update()
 
         # Backward stage, update input link
         for op_node in all_op_nodes:
@@ -2204,10 +2228,16 @@ class QuantizationTransformPassV2(object):
         graph.out_node_mapping_table = dict()
         # The process of _transform_forward and _transform_backward is needed in two for loops.
         # The loop for transforming the forward graph:
-        for op in ops:
-            if op.name() in self._quantizable_ops:
-                if not self._is_skip_quant(graph, op) and self._has_weight(op):
-                    self._transform_forward(graph, op)
+        with tqdm(total=len(ops),
+                  bar_format=
+                  'Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}',
+                  ncols=80) as t:
+            for op in ops:
+                if op.name() in self._quantizable_ops:
+                    if not self._is_skip_quant(graph,
+                                               op) and self._has_weight(op):
+                        self._transform_forward(graph, op)
+                t.update()
         # The loop for renaming the inputs of backward op.
         for op in ops:
             if op.name() in self._quantizable_grad_ops and self._has_weight(op):
@@ -2310,43 +2340,50 @@ class AddQuantDequantPassV2(object):
         # Forward stage, insert quant_dequant op
         all_op_nodes = graph.all_op_nodes()
-        for op_node in all_op_nodes:
-            if op_node.name() in self._quantizable_op_type:
-                is_skip = False
-                if isinstance(self._skip_pattern, list):
-                    is_skip = op_node.op().has_attr("op_namescope") and \
-                        any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
-                elif isinstance(self._skip_pattern, str):
-                    is_skip = op_node.op().has_attr("op_namescope") and \
-                        op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
-                is_quantized = op_node.op().has_attr("quantization_type") and \
-                    op_node.op().attr("quantization_type") == "qat_with_weight"
-                if is_skip or is_quantized:
-                    continue
-                op_node.op()._set_attr("quantization_type",
-                                       "qat_without_weight")
-                arg_names = utils._get_op_input_var_names(op_node)
-                for arg_name in arg_names:
-                    in_node = graph._find_node_by_name(op_node.inputs, arg_name)
-                    if in_node.persistable():
-                        continue
-                    if arg_name in dequantized_vars_map:
-                        dequant_var_node = dequantized_vars_map[arg_name]
-                    else:
-                        insert_quant_pass = InsertQuantizeLinear(
-                            self._place,
-                            self._scope,
-                            quant_bits=self._quant_bits,
-                            quant_axis=-1,
-                            channel_wise=False,
-                            is_test=self._is_test)
-                        quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
-                            graph, in_node)
-                        dequant_var_node = insert_quant_pass.insert_dequant_op(
-                            graph, quant_var_node, scale_var_node)
-                        dequantized_vars_map[arg_name] = dequant_var_node
-                    graph.update_input_link(in_node, dequant_var_node, op_node)
+        with tqdm(total=len(all_op_nodes),
+                  bar_format=
+                  'Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
+                  ncols=80) as t:
+            for op_node in all_op_nodes:
+                if op_node.name() in self._quantizable_op_type:
+                    is_skip = False
+                    if isinstance(self._skip_pattern, list):
+                        is_skip = op_node.op().has_attr("op_namescope") and \
+                            any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
+                    elif isinstance(self._skip_pattern, str):
+                        is_skip = op_node.op().has_attr("op_namescope") and \
+                            op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
+                    is_quantized = op_node.op().has_attr("quantization_type") and \
+                        op_node.op().attr("quantization_type") == "qat_with_weight"
+                    if is_skip or is_quantized:
+                        continue
+                    op_node.op()._set_attr("quantization_type",
+                                           "qat_without_weight")
+                    arg_names = utils._get_op_input_var_names(op_node)
+                    for arg_name in arg_names:
+                        in_node = graph._find_node_by_name(
+                            op_node.inputs, arg_name)
+                        if in_node.persistable():
+                            continue
+                        if arg_name in dequantized_vars_map:
+                            dequant_var_node = dequantized_vars_map[arg_name]
+                        else:
+                            insert_quant_pass = InsertQuantizeLinear(
+                                self._place,
+                                self._scope,
+                                quant_bits=self._quant_bits,
+                                quant_axis=-1,
+                                channel_wise=False,
+                                is_test=self._is_test)
+                            quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
+                                graph, in_node)
+                            dequant_var_node = insert_quant_pass.insert_dequant_op(
+                                graph, quant_var_node, scale_var_node)
+                            dequantized_vars_map[arg_name] = dequant_var_node
+                        graph.update_input_link(in_node, dequant_var_node,
+                                                op_node)
+                t.update()
 
         # Backward stage, update input link
         for op_node in all_op_nodes:
...
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
 import numpy as np
 from ....framework import IrNode
 from ....framework import Operator
@@ -52,7 +53,6 @@ _act_supported_quantizable_op_type = [
     "leaky_relu",
     "tanh",
     "swish",
-    "scale",
     "transpose",
     "transpose2",
     "sigmoid",
@@ -162,7 +162,6 @@ _op_real_in_out_name = {
     "sigmoid": [["X"], ["Out"]],
     "elementwise_mul": [["X", "Y"], ["Out"]],
     "elementwise_pow": [["X", "Y"], ["Out"]],
-    "scale": [["X"], ["Out"]],
     "hard_swish": [["X"], ["Out"]],
     "hard_sigmoid": [["X"], ["Out"]],
     "gru": [["Input", "Weight"], ["Hidden"]],
@@ -414,3 +413,27 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor):
     cos_sim = np.inner(orig_tensor.flatten(), qdq_tensor.flatten()) \
               / (np.linalg.norm(orig_tensor.flatten()) * np.linalg.norm(qdq_tensor.flatten()))
     return cos_sim
+
+
+class tqdm(object):
+
+    def __init__(self, total, bar_format='Loading|{bar}', ncols=80):
+        self.total = total
+        self.bar_format = bar_format
+        self.ncols = ncols
+        self.n = 0
+
+    def update(self, n=1):
+        self.n += n
+        a = "=" * round((self.n / self.total) * self.ncols)
+        b = " " * (self.ncols - len(a))
+        prefix = self.bar_format.split('|')[0]
+        sys.stderr.write("\r{}|{}=>{}| {}/{}".format(prefix, a, b, self.n,
+                                                     self.total))
+        sys.stderr.flush()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        sys.stderr.write('\n')
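The fallback `tqdm` class added to utils.py above implements only the subset of the real tqdm API that these passes rely on: a context manager whose `update()` redraws a fixed-width `=` bar on stderr, using just the text before the first `|` of `bar_format` as the prefix (the `{n_fmt}`/`{total_fmt}` placeholders are ignored by the stub). A hedged usage sketch — the loop below is illustrative, not from the commit:

```python
import time

# Assumes the fallback class above is importable as `tqdm`
# (e.g. `from .utils import tqdm`); the same code works with the real
# tqdm package, which additionally interprets the bar_format placeholders.
with tqdm(total=10,
          bar_format='Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}',
          ncols=80) as t:
    for _ in range(10):
        time.sleep(0.05)  # stand-in for transforming one op
        t.update()        # rewrites "\rAdding quant op with weight:|===..." on stderr
```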