Unverified commit 1c7e35dc, authored by Guanghua Yu, committed by GitHub

Add progress bar and speed up Quantization Pass (#43398)
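The progress bars come from tqdm, imported with a fallback so the passes keep working when the package is not installed; the fallback is the small tqdm shim added to utils.py at the end of this diff. A minimal sketch of that optional-dependency pattern (the relative .utils path mirrors the diff, the rest is illustrative):

# Prefer the real tqdm; otherwise fall back to the lightweight shim bundled with the passes.
try:
    from tqdm import tqdm
except ImportError:
    from .utils import tqdm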

Parent: 5fcd8061
@@ -17,6 +17,10 @@ import re
import logging
import numpy as np
import shutil
try:
from tqdm import tqdm
except:
from .utils import tqdm
from inspect import isgeneratorfunction
from .... import io
from .... import core
@@ -359,38 +363,41 @@ class PostTrainingQuantization(object):
self._set_activation_persistable()
if self._algo in ["KL", "hist"]:
_logger.info("Preparation stage ...")
batch_id = 0
with tqdm(
total=self._batch_nums,
bar_format=
'Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope)
self._collect_activation_abs_min_max()
batch_id += 1
t.update()
if self._batch_nums and batch_id >= self._batch_nums:
break
self._init_sampling_act_histogram()
batch_id = 0
with tqdm(total=self._batch_nums,
bar_format=
'Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope)
self._collect_activation_abs_min_max()
if batch_id % 5 == 0:
_logger.info("Run batch: " + str(batch_id))
self._sampling()
batch_id += 1
t.update()
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("Finish preparation stage, all batch:" + str(batch_id))
self._init_sampling_act_histogram()
_logger.info("Sampling stage ...")
batch_id = 0
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope)
self._sampling()
if batch_id % 5 == 0:
_logger.info("Run batch: " + str(batch_id))
batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("Finish sampling stage, all batch: " + str(batch_id))
if self._algo == 'avg':
for var_name in self._quantized_act_var_name:
......
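In the PostTrainingQuantization hunk above, the periodic _logger.info("Run batch: ...") messages are replaced by bounded progress bars around the preparation and sampling loops. A simplified sketch of that loop shape, with run_one_batch and collect_stats as hypothetical stand-ins for the executor call and the per-batch statistics collection:

from tqdm import tqdm  # or the bundled fallback shown at the end of this diff

def calibrate(data_loader, run_one_batch, collect_stats, batch_nums):
    # Run at most batch_nums batches, advancing the bar once per batch.
    batch_id = 0
    with tqdm(total=batch_nums,
              bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
              ncols=80) as t:
        for data in data_loader():
            run_one_batch(data)   # executor.run(...) in the real pass
            collect_stats()       # e.g. self._sampling()
            batch_id += 1
            t.update()
            if batch_nums and batch_id >= batch_nums:
                break
    return batch_id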
@@ -14,6 +14,10 @@
import collections
import numpy as np
try:
from tqdm import tqdm
except:
from .utils import tqdm
from ..... import compat as cpt
from .... import core
from ....framework import IrGraph
@@ -373,10 +377,15 @@ class QuantizationTransformPass(object):
graph.out_node_mapping_table = dict()
# The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph:
for op in ops:
if op.name() in self._quantizable_ops:
if not self._is_skip_quant(graph, op) and _has_weight(op):
_transform_forward(graph, op)
with tqdm(total=len(ops),
bar_format=
'Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for op in ops:
if op.name() in self._quantizable_ops:
if not self._is_skip_quant(graph, op) and _has_weight(op):
_transform_forward(graph, op)
t.update()
# The loop for renaming the inputs of backward op.
for op in ops:
if op.name() in self._quantizable_grad_ops and _has_weight(op):
@@ -1418,73 +1427,81 @@ class OutScaleForTrainingPass(object):
for op in graph.all_op_nodes():
if op.name() in self._teller_set:
target_ops.append(op)
for op in target_ops:
for output_var_name in utils._get_op_output_var_names(op):
in_node = graph._find_node_by_name(op.outputs, output_var_name)
if in_node.dtype() not in \
[core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
continue
with tqdm(total=len(target_ops),
bar_format='Adding OutScale op:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for op in target_ops:
for output_var_name in utils._get_op_output_var_names(op):
in_node = graph._find_node_by_name(op.outputs,
output_var_name)
if in_node.dtype() not in \
[core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]:
continue
scale_node = graph.create_persistable_node(
name=self._scale_name(in_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=in_node.dtype())
data_type = 'float64' if in_node.dtype() \
== core.VarDesc.VarType.FP64 else 'float32'
_init_var_node(scale_node, np.ones([1], dtype=data_type),
self._scope, self._place)
ins = {'X': in_node}
outs = {'OutScale': scale_node}
if not self._is_test:
state_in_node = graph.create_persistable_node(
name=unique_name.generate('scale_state@'),
scale_node = graph.create_persistable_node(
name=self._scale_name(in_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=in_node.dtype(),
shape=[1])
_init_var_node(state_in_node, np.ones([1], dtype=data_type),
shape=[1],
var_dtype=in_node.dtype())
data_type = 'float64' if in_node.dtype() \
== core.VarDesc.VarType.FP64 else 'float32'
_init_var_node(scale_node, np.ones([1], dtype=data_type),
self._scope, self._place)
accum_in_node = graph.create_persistable_node(
name=unique_name.generate('scale_accum@'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=in_node.dtype(),
shape=[1])
_init_var_node(accum_in_node, np.ones([1], dtype=data_type),
self._scope, self._place)
state_out_node = graph.create_var_node_from_desc(
state_in_node.var())
accum_out_node = graph.create_var_node_from_desc(
accum_in_node.var())
ins['InState'] = state_in_node
ins['InAccum'] = accum_in_node
outs['OutState'] = state_out_node
outs['OutAccum'] = accum_out_node
attrs = {
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
}
scale_op_node = graph.create_op_node(
op_type='moving_average_abs_max_scale',
attrs=attrs,
inputs=ins,
outputs=outs)
graph.link_to(in_node, scale_op_node)
graph.link_to(scale_op_node, scale_node)
if not self._is_test:
graph.link_to(state_in_node, scale_op_node)
graph.link_to(accum_in_node, scale_op_node)
graph.link_to(scale_op_node, state_out_node)
graph.link_to(scale_op_node, accum_out_node)
ins = {'X': in_node}
outs = {'OutScale': scale_node}
if not self._is_test:
state_in_node = graph.create_persistable_node(
name=unique_name.generate('scale_state@'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=in_node.dtype(),
shape=[1])
_init_var_node(state_in_node,
np.ones([1], dtype=data_type),
self._scope, self._place)
accum_in_node = graph.create_persistable_node(
name=unique_name.generate('scale_accum@'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=in_node.dtype(),
shape=[1])
_init_var_node(accum_in_node,
np.ones([1], dtype=data_type),
self._scope, self._place)
state_out_node = graph.create_var_node_from_desc(
state_in_node.var())
accum_out_node = graph.create_var_node_from_desc(
accum_in_node.var())
ins['InState'] = state_in_node
ins['InAccum'] = accum_in_node
outs['OutState'] = state_out_node
outs['OutAccum'] = accum_out_node
attrs = {
'moving_rate': self._moving_rate,
'is_test': self._is_test,
'op_role':
core.op_proto_and_checker_maker.OpRole.Forward
}
scale_op_node = graph.create_op_node(
op_type='moving_average_abs_max_scale',
attrs=attrs,
inputs=ins,
outputs=outs)
graph.link_to(in_node, scale_op_node)
graph.link_to(scale_op_node, scale_node)
if not self._is_test:
graph.link_to(state_in_node, scale_op_node)
graph.link_to(accum_in_node, scale_op_node)
graph.link_to(scale_op_node, state_out_node)
graph.link_to(scale_op_node, accum_out_node)
t.update()
return graph
def _scale_name(self, var_name):
"""
Return the scale name for the var named `var_name`.
"""
return "%s.scale" % (var_name)
return "%s@scale" % (var_name)
class OutScaleForInferencePass(object):
@@ -1544,7 +1561,7 @@ class OutScaleForInferencePass(object):
"""
Return the scale name for the var named `var_name`.
"""
return "%s.scale" % (var_name)
return "%s@scale" % (var_name)
class AddQuantDequantPass(object):
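Both out-scale passes also switch the suffix of the generated scale variable from ".scale" to "@scale". For illustration (the activation variable name below is hypothetical):

var_name = "conv2d_0.tmp_0"
old_name = "%s.scale" % var_name   # "conv2d_0.tmp_0.scale"
new_name = "%s@scale" % var_name   # "conv2d_0.tmp_0@scale"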
@@ -1624,36 +1641,43 @@ class AddQuantDequantPass(object):
# Forward stage, insert quant_dequant op
all_op_nodes = graph.all_op_nodes()
for op_node in all_op_nodes:
if op_node.name() in self._quantizable_op_type:
is_skip = False
if isinstance(self._skip_pattern, list):
is_skip = op_node.op().has_attr("op_namescope") and \
any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
elif isinstance(self._skip_pattern, str):
is_skip = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
is_quantized = op_node.op().has_attr("quantization_type") and \
op_node.op().attr("quantization_type") == "qat_with_weight"
if is_skip or is_quantized or \
(not _is_input_all_not_persistable(graph, op_node)):
continue
with tqdm(total=len(all_op_nodes),
bar_format=
'Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for op_node in all_op_nodes:
if op_node.name() in self._quantizable_op_type:
is_skip = False
if isinstance(self._skip_pattern, list):
is_skip = op_node.op().has_attr("op_namescope") and \
any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
elif isinstance(self._skip_pattern, str):
is_skip = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
is_quantized = op_node.op().has_attr("quantization_type") and \
op_node.op().attr("quantization_type") == "qat_with_weight"
if is_skip or is_quantized or \
(not _is_input_all_not_persistable(graph, op_node)):
continue
op_node.op()._set_attr("quantization_type",
"qat_without_weight")
op_node.op()._set_attr("activation_bits", self._quant_bits)
op_node.op()._set_attr("with_quant_attr", True)
arg_names = utils._get_op_input_var_names(op_node)
for arg_name in arg_names:
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
if arg_name in dequantized_vars_map:
quant_var_node = dequantized_vars_map[arg_name]
else:
quant_var_node, _ = \
self._inser_quant_dequant_moving_average_abs_max_op(
graph, in_node, self._quant_bits)
dequantized_vars_map[arg_name] = quant_var_node
graph.update_input_link(in_node, quant_var_node, op_node)
op_node.op()._set_attr("quantization_type",
"qat_without_weight")
op_node.op()._set_attr("activation_bits", self._quant_bits)
op_node.op()._set_attr("with_quant_attr", True)
arg_names = utils._get_op_input_var_names(op_node)
for arg_name in arg_names:
in_node = graph._find_node_by_name(
op_node.inputs, arg_name)
if arg_name in dequantized_vars_map:
quant_var_node = dequantized_vars_map[arg_name]
else:
quant_var_node, _ = \
self._inser_quant_dequant_moving_average_abs_max_op(
graph, in_node, self._quant_bits)
dequantized_vars_map[arg_name] = quant_var_node
graph.update_input_link(in_node, quant_var_node,
op_node)
t.update()
# Backward stage, update input link
for op_node in all_op_nodes:
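The forward-stage loop above only wraps an op in quant/dequant logic when its name scope does not match the configured skip pattern. That check, pulled out into a standalone sketch (the op is modelled as a plain attribute dict purely for illustration):

def should_skip(op_attrs, skip_pattern):
    # True if the op's name scope matches any configured skip pattern.
    namescope = op_attrs.get("op_namescope")
    if namescope is None:
        return False
    if isinstance(skip_pattern, list):
        return any(pattern in namescope for pattern in skip_pattern)
    if isinstance(skip_pattern, str):
        return skip_pattern in namescope
    return False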
@@ -2204,10 +2228,16 @@ class QuantizationTransformPassV2(object):
graph.out_node_mapping_table = dict()
# The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph:
for op in ops:
if op.name() in self._quantizable_ops:
if not self._is_skip_quant(graph, op) and self._has_weight(op):
self._transform_forward(graph, op)
with tqdm(total=len(ops),
bar_format=
'Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for op in ops:
if op.name() in self._quantizable_ops:
if not self._is_skip_quant(graph,
op) and self._has_weight(op):
self._transform_forward(graph, op)
t.update()
# The loop for renaming the inputs of backward op.
for op in ops:
if op.name() in self._quantizable_grad_ops and self._has_weight(op):
@@ -2310,43 +2340,50 @@ class AddQuantDequantPassV2(object):
# Forward stage, insert quant_dequant op
all_op_nodes = graph.all_op_nodes()
for op_node in all_op_nodes:
if op_node.name() in self._quantizable_op_type:
is_skip = False
if isinstance(self._skip_pattern, list):
is_skip = op_node.op().has_attr("op_namescope") and \
any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
elif isinstance(self._skip_pattern, str):
is_skip = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
is_quantized = op_node.op().has_attr("quantization_type") and \
op_node.op().attr("quantization_type") == "qat_with_weight"
if is_skip or is_quantized:
continue
op_node.op()._set_attr("quantization_type",
"qat_without_weight")
arg_names = utils._get_op_input_var_names(op_node)
for arg_name in arg_names:
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
if in_node.persistable():
with tqdm(total=len(all_op_nodes),
bar_format=
'Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t:
for op_node in all_op_nodes:
if op_node.name() in self._quantizable_op_type:
is_skip = False
if isinstance(self._skip_pattern, list):
is_skip = op_node.op().has_attr("op_namescope") and \
any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern)
elif isinstance(self._skip_pattern, str):
is_skip = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
is_quantized = op_node.op().has_attr("quantization_type") and \
op_node.op().attr("quantization_type") == "qat_with_weight"
if is_skip or is_quantized:
continue
if arg_name in dequantized_vars_map:
dequant_var_node = dequantized_vars_map[arg_name]
else:
insert_quant_pass = InsertQuantizeLinear(
self._place,
self._scope,
quant_bits=self._quant_bits,
quant_axis=-1,
channel_wise=False,
is_test=self._is_test)
quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
graph, in_node)
dequant_var_node = insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node)
dequantized_vars_map[arg_name] = dequant_var_node
graph.update_input_link(in_node, dequant_var_node, op_node)
op_node.op()._set_attr("quantization_type",
"qat_without_weight")
arg_names = utils._get_op_input_var_names(op_node)
for arg_name in arg_names:
in_node = graph._find_node_by_name(
op_node.inputs, arg_name)
if in_node.persistable():
continue
if arg_name in dequantized_vars_map:
dequant_var_node = dequantized_vars_map[arg_name]
else:
insert_quant_pass = InsertQuantizeLinear(
self._place,
self._scope,
quant_bits=self._quant_bits,
quant_axis=-1,
channel_wise=False,
is_test=self._is_test)
quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
graph, in_node)
dequant_var_node = insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node)
dequantized_vars_map[arg_name] = dequant_var_node
graph.update_input_link(in_node, dequant_var_node,
op_node)
t.update()
# Backward stage, update input link
for op_node in all_op_nodes:
......
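AddQuantDequantPass and AddQuantDequantPassV2 insert "fake" quantize/dequantize pairs in front of activation inputs: the tensor is rounded onto a quant_bits integer grid and immediately mapped back to floating point, so downstream ops see the quantization error while the graph still runs in float. A NumPy sketch of that transform with a per-tensor abs-max scale (illustrative only; in the real passes the scale is a persistable variable maintained by ops inserted into the graph):

import numpy as np

def fake_quant_dequant(x, quant_bits=8):
    # Quantize-then-dequantize x with a symmetric per-tensor abs-max scale.
    bnt = (1 << (quant_bits - 1)) - 1        # e.g. 127 for 8 bits
    scale = float(np.max(np.abs(x))) or 1.0  # guard against an all-zero tensor
    q = np.round(x / scale * bnt)            # project onto the integer grid
    return q * scale / bnt                   # back to float, keeping the rounding error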
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
from ....framework import IrNode
from ....framework import Operator
@@ -52,7 +53,6 @@ _act_supported_quantizable_op_type = [
"leaky_relu",
"tanh",
"swish",
"scale",
"transpose",
"transpose2",
"sigmoid",
@@ -162,7 +162,6 @@ _op_real_in_out_name = {
"sigmoid": [["X"], ["Out"]],
"elementwise_mul": [["X", "Y"], ["Out"]],
"elementwise_pow": [["X", "Y"], ["Out"]],
"scale": [["X"], ["Out"]],
"hard_swish": [["X"], ["Out"]],
"hard_sigmoid": [["X"], ["Out"]],
"gru": [["Input", "Weight"], ["Hidden"]],
@@ -414,3 +413,27 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor):
cos_sim = np.inner(orig_tensor.flatten(), qdq_tensor.flatten()) \
/ (np.linalg.norm(orig_tensor.flatten()) * np.linalg.norm(qdq_tensor.flatten()))
return cos_sim
class tqdm(object):
def __init__(self, total, bar_format='Loading|{bar}', ncols=80):
self.total = total
self.bar_format = bar_format
self.ncols = ncols
self.n = 0
def update(self, n=1):
self.n += n
a = "=" * round((self.n / self.total) * self.ncols)
b = " " * (self.ncols - len(a))
prefix = self.bar_format.split('|')[0]
sys.stderr.write("\r{}|{}=>{}| {}/{}".format(prefix, a, b, self.n,
self.total))
sys.stderr.flush()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
sys.stderr.write('\n')
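A quick usage sketch for the fallback shim above, exercising the same context-manager protocol the passes rely on (the total of 10 and the sleep are arbitrary, for demonstration only):

import time

with tqdm(total=10, bar_format='Demo:|{bar}| {n_fmt}/{total_fmt}', ncols=40) as t:
    for _ in range(10):
        time.sleep(0.1)   # stand-in for per-batch work
        t.update()        # redraw the bar on stderr with the count advanced by one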