未验证 提交 5eff6f01 编写于 作者: G Guanghua Yu 提交者: GitHub

support conv1d quant & skip calibrate zero-size tensor (#48912)

上级 5d49e3e9
......@@ -398,6 +398,9 @@ class PostTrainingQuantization:
self._best_calibration_loss = {}
# The threshold for algo = abs_max, mse or avg
self._quantized_threshold = {}
# If the tensor is zero-size during any calibration step,
# it will be stored in self._zero_size_var_names
self._zero_size_var_names = set()
self._same_scale_tensor_list = same_scale_tensor_list
self._freeze_model = freeze_model
self._scale_dict = scale_dict
......@@ -465,9 +468,12 @@ class PostTrainingQuantization:
if self._algo == 'avg':
for var_name in self._quantized_act_var_name:
if var_name not in self._quantized_var_avg:
continue
self._quantized_threshold[var_name] = np.array(
self._quantized_var_avg[var_name]
).mean()
if self._algo in ["KL", "hist"]:
self._calculate_kl_hist_threshold()
......@@ -741,6 +747,9 @@ class PostTrainingQuantization:
_logger.info("MSE searching stage ...")
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
var_tensor = var_tensor.flatten()
abs_max_value = float(np.max(np.abs(var_tensor)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
......@@ -792,6 +801,9 @@ class PostTrainingQuantization:
_logger.info("EMD searching stage ...")
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
var_tensor = var_tensor.flatten()
abs_max_value = float(np.max(np.abs(var_tensor)))
abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
......@@ -845,6 +857,9 @@ class PostTrainingQuantization:
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
abs_max_value = float(np.max(np.abs(var_tensor)))
if var_name not in self._quantized_var_avg:
self._quantized_var_avg[var_name] = []
......@@ -857,7 +872,6 @@ class PostTrainingQuantization:
)
)
self._quantized_var_avg[var_name].append(abs_avg_value)
continue
def _sample_abs_max(self):
if self._quantized_threshold == {}:
......@@ -884,6 +898,9 @@ class PostTrainingQuantization:
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
abs_max_value = float(np.max(np.abs(var_tensor)))
if (var_name not in self._quantized_threshold) or (
abs_max_value > self._quantized_threshold[var_name]
......@@ -916,6 +933,9 @@ class PostTrainingQuantization:
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
if (var_name not in self._quantized_var_min) or (
......@@ -930,6 +950,11 @@ class PostTrainingQuantization:
def _sample_histogram(self):
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if (not var_tensor.any()) or (
var_name not in self._sampling_act_histogram
):
self._zero_size_var_names.add(var_name)
continue
var_tensor_abs = np.abs(var_tensor)
bins = self._sampling_act_histogram[var_name][1]
hist, _ = np.histogram(var_tensor_abs, bins=bins)
......@@ -964,6 +989,9 @@ class PostTrainingQuantization:
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
abs_max_value = float(np.max(np.abs(var_tensor)))
q_max = 2 ** (self._activation_bits - 1) - 1
scale8 = abs_max_value / q_max
......@@ -1020,6 +1048,9 @@ class PostTrainingQuantization:
'''
for var_name in self._quantized_act_var_name:
var_tensor = utils.load_variable_data(self._scope, var_name)
if not var_tensor.any():
self._zero_size_var_names.add(var_name)
continue
var_tensor = np.abs(var_tensor)
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
......@@ -1039,6 +1070,10 @@ class PostTrainingQuantization:
Based on the min/max value, init the sampling_act_histogram.
'''
for var_name in self._quantized_act_var_name:
if (var_name in self._zero_size_var_names) and (
var_name not in self._sampling_act_abs_min_max
):
continue
if var_name not in self._sampling_act_histogram:
min_val = self._sampling_act_abs_min_max[var_name][0]
max_val = self._sampling_act_abs_min_max[var_name][1]
......@@ -1077,6 +1112,10 @@ class PostTrainingQuantization:
self._quantized_var_threshold[var_name] = weight_threshold
for var_name in self._quantized_act_var_name:
if (var_name in self._zero_size_var_names) and (
var_name not in self._sampling_act_histogram
):
continue
hist, hist_edeges = self._sampling_act_histogram[var_name]
if self._algo == "KL":
bin_width = hist_edeges[1] - hist_edeges[0]
......@@ -1162,7 +1201,6 @@ class PostTrainingQuantization:
if self._same_scale_tensor_list is not None:
for tensor_list in self._same_scale_tensor_list:
max_scale = None
tmp_tensor_list = []
for tensor_name in tensor_list:
if '#' in tensor_name:
real_tensor_name, opera, scalar = tensor_name.split(
......@@ -1261,8 +1299,23 @@ class PostTrainingQuantization:
self._calibration_scales = {}
def save_info(
op_node, out_var_name, threshold_map, out_info_name, quantized_type
op_node,
out_var_name,
threshold_map,
out_info_name,
argname_index,
quantized_type,
):
if (out_var_name in self._zero_size_var_names) and (
out_var_name not in threshold_map
):
_logger.warning(
"{} is zero-size tensor and unable to calibrate, so skip quant it.".format(
out_var_name
)
)
return
else:
assert (
out_var_name in threshold_map
), "The output ({}) of {} node does not have threshold.".format(
......@@ -1270,12 +1323,16 @@ class PostTrainingQuantization:
)
if self._onnx_format:
# For easy extension, every var_node set a dict to save parameters of quant.
self._calibration_scales[var_name] = {}
self._calibration_scales[var_name]['scale'] = threshold_map[
var_name
self._calibration_scales[out_var_name] = {}
self._calibration_scales[out_var_name]['scale'] = threshold_map[
out_var_name
]
else:
op_node._set_attr(out_info_name, threshold_map[var_name])
op_node._set_attr(out_info_name, threshold_map[out_var_name])
op_node._set_attr(
argname_index[0] + str(argname_index[1]) + "_threshold",
threshold_map[out_var_name],
)
op_node._set_attr("with_quant_attr", True)
if op_node.type in self._quantizable_op_type:
op._set_attr("quantization_type", quantized_type)
......@@ -1285,52 +1342,23 @@ class PostTrainingQuantization:
assert argname_index is not None, (
out_var_name + " is not the output of the op"
)
if self._algo == "KL":
# For compatibility, we save output threshold by two methods.
save_info(
op_node,
out_var_name,
self._quantized_var_threshold,
"out_threshold",
"post_kl",
)
save_info(
op_node,
out_var_name,
self._quantized_var_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_kl",
)
elif self._algo == "hist":
if self._algo in ["KL", "hist"]:
# For compatibility, we save output threshold by two methods.
save_info(
op_node,
out_var_name,
self._quantized_var_threshold,
"out_threshold",
"post_hist",
argname_index,
"post_" + str(self._algo).lower(),
)
save_info(
op_node,
out_var_name,
self._quantized_var_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
"post_hist",
)
elif self._algo in ["avg", "abs_max", "mse", "emd", "ptf"]:
save_info(
op_node,
out_var_name,
self._quantized_threshold,
"out_threshold",
"post_" + str(self._algo),
)
save_info(
op_node,
out_var_name,
self._quantized_threshold,
argname_index[0] + str(argname_index[1]) + "_threshold",
argname_index,
"post_" + str(self._algo),
)
elif self._algo == "min_max":
......@@ -1339,6 +1367,7 @@ class PostTrainingQuantization:
out_var_name,
self._quantized_var_min,
"out_min",
argname_index,
"post_min_max",
)
save_info(
......@@ -1346,6 +1375,7 @@ class PostTrainingQuantization:
out_var_name,
self._quantized_var_max,
"out_max",
argname_index,
"post_min_max",
)
......
......@@ -2134,7 +2134,9 @@ class InsertQuantizeLinear:
self._moving_rate = moving_rate
self._scale_dict = scale_dict
def insert_quant_op(self, graph, var_node, var_name=None):
def insert_quant_op(
self, graph, var_node, var_name=None, scale_var_node=None
):
assert var_node.is_var(), '{} is not a var'.format(var_node.name())
var_name = var_node.name() if not var_name else var_name
quant_var_node = graph.create_var_node(
......@@ -2143,6 +2145,7 @@ class InsertQuantizeLinear:
shape=var_node.shape(),
var_dtype=var_node.dtype(),
)
if not scale_var_node:
data_type = (
'float64'
if var_node.dtype() == core.VarDesc.VarType.FP64
......@@ -2153,12 +2156,15 @@ class InsertQuantizeLinear:
scale_var_shape = var_node.shape()[self.quant_axis]
scale_var_type = core.VarDesc.VarType.LOD_TENSOR
init_scale_value = (
np.ones(scale_var_shape, dtype=data_type) * _SCALE_DEFAULT_VALUE
np.ones(scale_var_shape, dtype=data_type)
* _SCALE_DEFAULT_VALUE
)
else:
scale_var_shape = 1
scale_var_type = var_node.type()
init_scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)
init_scale_value = np.array(
[_SCALE_DEFAULT_VALUE], dtype=data_type
)
if (
self._scale_dict is not None
......@@ -2167,7 +2173,6 @@ class InsertQuantizeLinear:
init_scale_value = np.array(
[self._scale_dict[var_node.name()]], dtype=data_type
)
scale_var_node = graph.create_persistable_node(
name=scale_name,
var_type=scale_var_type,
......@@ -2510,6 +2515,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
def _transform_forward(self, graph, op):
op.op()._set_attr("quantization_type", "qat_with_weight")
weight_scale_node = None
inputs = op.inputs
for var_node in inputs:
if var_node.name() not in op.input_arg_names():
......@@ -2595,7 +2601,10 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
)
self.dequantized_vars[name] = dequant_var_node
if is_weight:
weight_scale_node = scale_var_node
graph.update_input_link(var_node, dequant_var_node, op)
return weight_scale_node
def _transform_backward(self, graph, op):
for var_node in op.inputs:
......@@ -2610,11 +2619,71 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
for var_node in op.inputs:
if var_node.name() not in op.input_arg_names():
continue
name = var_node.name()
if var_node.name() in self.persistable_vars:
has_weight = True
return has_weight
def _quant_conv1d(self, graph, op):
# conv1d in inference is a combination of unsqueeze2 + conv2d
if ("conv2d" not in op.name()) or (
"unsqueeze2" not in op.input("Filter")[0]
):
return
conv_weight_var_name = op.input("Filter")[0]
# unsqueeze2 and conv2d will share weight scale
weight_scale_node = None
# quant unsqueeze2
for _op in graph.all_op_nodes():
var_names = utils._get_op_output_var_names(_op)
if conv_weight_var_name in var_names and self._has_weight(_op):
weight_scale_node = self._transform_forward(graph, _op)
# insert qdq before conv2d
for var_node in op.inputs:
quant_bits = (
self._weight_bits
if var_node.name() == conv_weight_var_name
else self._activation_bits
)
quant_type = (
self._weight_quantize_type
if var_node.name() == conv_weight_var_name
else self._activation_quantize_type
)
quant_axis = -1
channel_wise = False
if quant_type == 'channel_wise_abs_max':
channel_wise = True
quant_axis = (
1 if op.name() in utils._channelwise_quant_axis1_ops else 0
)
insert_quant_pass = InsertQuantizeLinear(
self._place,
self._scope,
quant_bits=quant_bits,
quant_axis=quant_axis,
channel_wise=channel_wise,
moving_rate=self._moving_rate,
is_test=self._is_test,
)
scale_var_node = (
weight_scale_node
if var_node.name() == conv_weight_var_name
else None
)
(
quant_var_node,
scale_var_node,
) = insert_quant_pass.insert_quant_op(
graph,
var_node,
var_name=var_node.name(),
scale_var_node=scale_var_node,
)
dequant_var_node = insert_quant_pass.insert_dequant_op(
graph, quant_var_node, scale_var_node
)
graph.update_input_link(var_node, dequant_var_node, op)
def apply(self, graph):
"""
Quantize the graph for training process. According to weight and
......@@ -2664,6 +2733,9 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
op
):
self._transform_forward(graph, op)
else: # op is not persistable
# support conv1d quantization
self._quant_conv1d(graph, op)
t.update()
# The loop for renaming the inputs of backward op.
for op in ops:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册