Unverified · Commit 5eff6f01 · authored by Guanghua Yu · committed by GitHub

support conv1d quant & skip calibrate zero-size tensor (#48912)

Parent 5d49e3e9
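Note on the zero-size guard used throughout this patch: the calibration loops decide to skip a tensor with `not var_tensor.any()`. A small check in plain NumPy (independent of Paddle) shows why this catches empty tensors and, as a side effect, all-zero tensors too:

```python
import numpy as np

t_empty = np.zeros((0, 3), dtype=np.float32)  # genuinely zero-size
t_zeros = np.zeros((2, 3), dtype=np.float32)  # zero-valued, but has elements

print(t_empty.any())  # False -> var_name is added to _zero_size_var_names
print(t_zeros.any())  # False -> an all-zeros batch is skipped the same way
```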
@@ -398,6 +398,9 @@ class PostTrainingQuantization:
         self._best_calibration_loss = {}
         # The threshold for algo = abs_max, mse or avg
         self._quantized_threshold = {}
+        # If the tensor is zero-size during any calibration step,
+        # it will be stored in self._zero_size_var_names
+        self._zero_size_var_names = set()
         self._same_scale_tensor_list = same_scale_tensor_list
         self._freeze_model = freeze_model
         self._scale_dict = scale_dict
@@ -465,9 +468,12 @@ class PostTrainingQuantization:
         if self._algo == 'avg':
             for var_name in self._quantized_act_var_name:
+                if var_name not in self._quantized_var_avg:
+                    continue
                 self._quantized_threshold[var_name] = np.array(
                     self._quantized_var_avg[var_name]
                 ).mean()
         if self._algo in ["KL", "hist"]:
             self._calculate_kl_hist_threshold()
@@ -741,6 +747,9 @@ class PostTrainingQuantization:
         _logger.info("MSE searching stage ...")
         for var_name in self._quantized_act_var_name:
             var_tensor = utils.load_variable_data(self._scope, var_name)
+            if not var_tensor.any():
+                self._zero_size_var_names.add(var_name)
+                continue
             var_tensor = var_tensor.flatten()
             abs_max_value = float(np.max(np.abs(var_tensor)))
             abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
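For context, the MSE searching stage looks for the threshold whose quantize-dequantize round trip has the lowest mean squared error. A minimal self-contained sketch of that idea (not Paddle's exact loop; the candidate grid and bit width are assumptions, and the input is assumed to have already passed the zero-size guard above):

```python
import numpy as np

def mse_search_threshold(var_tensor, bits=8, steps=100):
    var_tensor = var_tensor.flatten()
    abs_max_value = float(np.max(np.abs(var_tensor)))
    abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
    q_max = 2 ** (bits - 1) - 1
    best_threshold, best_mse = abs_max_value, float("inf")
    for i in range(steps):
        # Scan thresholds from abs_max down toward abs_max / 2.
        threshold = abs_max_value * (1.0 - i / (2.0 * steps))
        scale = threshold / q_max
        quant_dequant = (
            np.clip(np.round(var_tensor / scale), -q_max - 1, q_max) * scale
        )
        mse = float(np.mean((var_tensor - quant_dequant) ** 2))
        if mse < best_mse:
            best_threshold, best_mse = threshold, mse
    return best_threshold

print(mse_search_threshold(np.random.randn(1024).astype(np.float32)))
```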
@@ -792,6 +801,9 @@ class PostTrainingQuantization:
         _logger.info("EMD searching stage ...")
         for var_name in self._quantized_act_var_name:
             var_tensor = utils.load_variable_data(self._scope, var_name)
+            if not var_tensor.any():
+                self._zero_size_var_names.add(var_name)
+                continue
             var_tensor = var_tensor.flatten()
             abs_max_value = float(np.max(np.abs(var_tensor)))
             abs_max_value = 1e-8 if abs_max_value == 0.0 else abs_max_value
@@ -845,6 +857,9 @@ class PostTrainingQuantization:
         for var_name in self._quantized_act_var_name:
             var_tensor = utils.load_variable_data(self._scope, var_name)
+            if not var_tensor.any():
+                self._zero_size_var_names.add(var_name)
+                continue
             abs_max_value = float(np.max(np.abs(var_tensor)))
             if var_name not in self._quantized_var_avg:
                 self._quantized_var_avg[var_name] = []
@@ -857,7 +872,6 @@ class PostTrainingQuantization:
                 )
             )
             self._quantized_var_avg[var_name].append(abs_avg_value)
-            continue

     def _sample_abs_max(self):
         if self._quantized_threshold == {}:
@@ -884,6 +898,9 @@ class PostTrainingQuantization:
         for var_name in self._quantized_act_var_name:
             var_tensor = utils.load_variable_data(self._scope, var_name)
+            if not var_tensor.any():
+                self._zero_size_var_names.add(var_name)
+                continue
             abs_max_value = float(np.max(np.abs(var_tensor)))
             if (var_name not in self._quantized_threshold) or (
                 abs_max_value > self._quantized_threshold[var_name]
@@ -916,6 +933,9 @@ class PostTrainingQuantization:
         for var_name in self._quantized_act_var_name:
             var_tensor = utils.load_variable_data(self._scope, var_name)
+            if not var_tensor.any():
+                self._zero_size_var_names.add(var_name)
+                continue
             min_value = float(np.min(var_tensor))
             max_value = float(np.max(var_tensor))
             if (var_name not in self._quantized_var_min) or (
@@ -930,6 +950,11 @@ class PostTrainingQuantization:
     def _sample_histogram(self):
         for var_name in self._quantized_act_var_name:
             var_tensor = utils.load_variable_data(self._scope, var_name)
+            if (not var_tensor.any()) or (
+                var_name not in self._sampling_act_histogram
+            ):
+                self._zero_size_var_names.add(var_name)
+                continue
             var_tensor_abs = np.abs(var_tensor)
             bins = self._sampling_act_histogram[var_name][1]
             hist, _ = np.histogram(var_tensor_abs, bins=bins)
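The histogram sampler accumulates one histogram per activation across calibration batches, against bin edges fixed earlier by `_init_sampling_act_histogram`. A standalone sketch of that accumulation pattern (the bin count and value range here are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
bins = np.linspace(0.0, 1.0, num=2048 + 1)   # assumed fixed bin edges
running_hist = np.zeros(2048, dtype=np.int64)

for _ in range(4):  # pretend calibration batches
    batch = rng.random((32, 16), dtype=np.float32)
    hist, _ = np.histogram(np.abs(batch), bins=bins)
    running_hist += hist  # same accumulation as _sample_histogram

print(running_hist.sum())  # 4 * 32 * 16 = 2048 samples counted
```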
@@ -964,6 +989,9 @@ class PostTrainingQuantization:
         for var_name in self._quantized_act_var_name:
             var_tensor = utils.load_variable_data(self._scope, var_name)
+            if not var_tensor.any():
+                self._zero_size_var_names.add(var_name)
+                continue
             abs_max_value = float(np.max(np.abs(var_tensor)))
             q_max = 2 ** (self._activation_bits - 1) - 1
             scale8 = abs_max_value / q_max
@@ -1020,6 +1048,9 @@ class PostTrainingQuantization:
         '''
         for var_name in self._quantized_act_var_name:
             var_tensor = utils.load_variable_data(self._scope, var_name)
+            if not var_tensor.any():
+                self._zero_size_var_names.add(var_name)
+                continue
             var_tensor = np.abs(var_tensor)
             min_value = float(np.min(var_tensor))
             max_value = float(np.max(var_tensor))
@@ -1039,6 +1070,10 @@ class PostTrainingQuantization:
         Based on the min/max value, init the sampling_act_histogram.
         '''
         for var_name in self._quantized_act_var_name:
+            if (var_name in self._zero_size_var_names) and (
+                var_name not in self._sampling_act_abs_min_max
+            ):
+                continue
             if var_name not in self._sampling_act_histogram:
                 min_val = self._sampling_act_abs_min_max[var_name][0]
                 max_val = self._sampling_act_abs_min_max[var_name][1]
@@ -1077,6 +1112,10 @@ class PostTrainingQuantization:
             self._quantized_var_threshold[var_name] = weight_threshold

         for var_name in self._quantized_act_var_name:
+            if (var_name in self._zero_size_var_names) and (
+                var_name not in self._sampling_act_histogram
+            ):
+                continue
             hist, hist_edeges = self._sampling_act_histogram[var_name]
             if self._algo == "KL":
                 bin_width = hist_edeges[1] - hist_edeges[0]
@@ -1162,7 +1201,6 @@ class PostTrainingQuantization:
         if self._same_scale_tensor_list is not None:
             for tensor_list in self._same_scale_tensor_list:
                 max_scale = None
-                tmp_tensor_list = []
                 for tensor_name in tensor_list:
                     if '#' in tensor_name:
                         real_tensor_name, opera, scalar = tensor_name.split(
@@ -1261,21 +1299,40 @@ class PostTrainingQuantization:
         self._calibration_scales = {}

         def save_info(
-            op_node, out_var_name, threshold_map, out_info_name, quantized_type
+            op_node,
+            out_var_name,
+            threshold_map,
+            out_info_name,
+            argname_index,
+            quantized_type,
         ):
-            assert (
-                out_var_name in threshold_map
-            ), "The output ({}) of {} node does not have threshold.".format(
-                out_var_name, op_node.type
-            )
+            if (out_var_name in self._zero_size_var_names) and (
+                out_var_name not in threshold_map
+            ):
+                _logger.warning(
+                    "{} is zero-size tensor and unable to calibrate, so skip quant it.".format(
+                        out_var_name
+                    )
+                )
+                return
+            else:
+                assert (
+                    out_var_name in threshold_map
+                ), "The output ({}) of {} node does not have threshold.".format(
+                    out_var_name, op_node.type
+                )
             if self._onnx_format:
                 # For easy extension, every var_node set a dict to save parameters of quant.
-                self._calibration_scales[var_name] = {}
-                self._calibration_scales[var_name]['scale'] = threshold_map[
-                    var_name
+                self._calibration_scales[out_var_name] = {}
+                self._calibration_scales[out_var_name]['scale'] = threshold_map[
+                    out_var_name
                 ]
             else:
-                op_node._set_attr(out_info_name, threshold_map[var_name])
+                op_node._set_attr(out_info_name, threshold_map[out_var_name])
+                op_node._set_attr(
+                    argname_index[0] + str(argname_index[1]) + "_threshold",
+                    threshold_map[out_var_name],
+                )
                 op_node._set_attr("with_quant_attr", True)
                 if op_node.type in self._quantizable_op_type:
                     op._set_attr("quantization_type", quantized_type)
@@ -1285,52 +1342,23 @@ class PostTrainingQuantization:
             assert argname_index is not None, (
                 out_var_name + " is not the output of the op"
             )
-            if self._algo == "KL":
-                # For compatibility, we save output threshold by two methods.
-                save_info(
-                    op_node,
-                    out_var_name,
-                    self._quantized_var_threshold,
-                    "out_threshold",
-                    "post_kl",
-                )
-                save_info(
-                    op_node,
-                    out_var_name,
-                    self._quantized_var_threshold,
-                    argname_index[0] + str(argname_index[1]) + "_threshold",
-                    "post_kl",
-                )
-            elif self._algo == "hist":
+            if self._algo in ["KL", "hist"]:
                 # For compatibility, we save output threshold by two methods.
                 save_info(
                     op_node,
                     out_var_name,
                     self._quantized_var_threshold,
                     "out_threshold",
-                    "post_hist",
-                )
-                save_info(
-                    op_node,
-                    out_var_name,
-                    self._quantized_var_threshold,
-                    argname_index[0] + str(argname_index[1]) + "_threshold",
-                    "post_hist",
+                    argname_index,
+                    "post_" + str(self._algo).lower(),
                 )
             elif self._algo in ["avg", "abs_max", "mse", "emd", "ptf"]:
                 save_info(
                     op_node,
                     out_var_name,
                     self._quantized_threshold,
                     "out_threshold",
-                    "post_" + str(self._algo),
-                )
-                save_info(
-                    op_node,
-                    out_var_name,
-                    self._quantized_threshold,
-                    argname_index[0] + str(argname_index[1]) + "_threshold",
+                    argname_index,
                     "post_" + str(self._algo),
                 )
             elif self._algo == "min_max":
@@ -1339,6 +1367,7 @@ class PostTrainingQuantization:
                     out_var_name,
                     self._quantized_var_min,
                     "out_min",
+                    argname_index,
                     "post_min_max",
                 )
                 save_info(
@@ -1346,6 +1375,7 @@ class PostTrainingQuantization:
                     out_var_name,
                     self._quantized_var_max,
                     "out_max",
+                    argname_index,
                     "post_min_max",
                 )
(The hunks below are from the quantization-pass module that defines InsertQuantizeLinear and QuantizationTransformPassV2.)
@@ -2134,7 +2134,9 @@ class InsertQuantizeLinear:
         self._moving_rate = moving_rate
         self._scale_dict = scale_dict

-    def insert_quant_op(self, graph, var_node, var_name=None):
+    def insert_quant_op(
+        self, graph, var_node, var_name=None, scale_var_node=None
+    ):
         assert var_node.is_var(), '{} is not a var'.format(var_node.name())
         var_name = var_node.name() if not var_name else var_name
         quant_var_node = graph.create_var_node(
@@ -2143,40 +2145,43 @@ class InsertQuantizeLinear:
             shape=var_node.shape(),
             var_dtype=var_node.dtype(),
         )
-        data_type = (
-            'float64'
-            if var_node.dtype() == core.VarDesc.VarType.FP64
-            else 'float32'
-        )
-        scale_name = self._quantized_scale_name(var_name)
-        if self.channel_wise:
-            scale_var_shape = var_node.shape()[self.quant_axis]
-            scale_var_type = core.VarDesc.VarType.LOD_TENSOR
-            init_scale_value = (
-                np.ones(scale_var_shape, dtype=data_type) * _SCALE_DEFAULT_VALUE
-            )
-        else:
-            scale_var_shape = 1
-            scale_var_type = var_node.type()
-            init_scale_value = np.array([_SCALE_DEFAULT_VALUE], dtype=data_type)
-        if (
-            self._scale_dict is not None
-            and var_node.name() in self._scale_dict.keys()
-        ):
-            init_scale_value = np.array(
-                [self._scale_dict[var_node.name()]], dtype=data_type
-            )
-        scale_var_node = graph.create_persistable_node(
-            name=scale_name,
-            var_type=scale_var_type,
-            shape=[scale_var_shape],
-            var_dtype=var_node.dtype(),
-        )
-        _init_var_node(
-            scale_var_node, init_scale_value, self._scope, self._place
-        )
+        if not scale_var_node:
+            data_type = (
+                'float64'
+                if var_node.dtype() == core.VarDesc.VarType.FP64
+                else 'float32'
+            )
+            scale_name = self._quantized_scale_name(var_name)
+            if self.channel_wise:
+                scale_var_shape = var_node.shape()[self.quant_axis]
+                scale_var_type = core.VarDesc.VarType.LOD_TENSOR
+                init_scale_value = (
+                    np.ones(scale_var_shape, dtype=data_type)
+                    * _SCALE_DEFAULT_VALUE
+                )
+            else:
+                scale_var_shape = 1
+                scale_var_type = var_node.type()
+                init_scale_value = np.array(
+                    [_SCALE_DEFAULT_VALUE], dtype=data_type
+                )
+            if (
+                self._scale_dict is not None
+                and var_node.name() in self._scale_dict.keys()
+            ):
+                init_scale_value = np.array(
+                    [self._scale_dict[var_node.name()]], dtype=data_type
+                )
+            scale_var_node = graph.create_persistable_node(
+                name=scale_name,
+                var_type=scale_var_type,
+                shape=[scale_var_shape],
+                var_dtype=var_node.dtype(),
+            )
+            _init_var_node(
+                scale_var_node, init_scale_value, self._scope, self._place
+            )

         zero_point_node = None
         if zero_point_node is None:
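The effect of the new `scale_var_node` parameter: when the caller already owns a scale node, `insert_quant_op` skips creating and initializing a fresh persistable scale and reuses the one passed in. A toy analogy of that reuse pattern (plain Python, not the Paddle API):

```python
def insert_quant(value, scale=None):
    # Mirrors the new control flow: create a scale only if none was passed.
    if scale is None:
        scale = max(abs(value), 1e-8)
    return round(value / scale), scale

q1, shared_scale = insert_quant(0.37)     # first site creates the scale
q2, _ = insert_quant(0.21, shared_scale)  # second site reuses the same scale
print(q1, q2, shared_scale)
```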
@@ -2510,6 +2515,7 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
     def _transform_forward(self, graph, op):
         op.op()._set_attr("quantization_type", "qat_with_weight")
+        weight_scale_node = None
         inputs = op.inputs
         for var_node in inputs:
             if var_node.name() not in op.input_arg_names():
@@ -2595,7 +2601,10 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
                 )
             self.dequantized_vars[name] = dequant_var_node
+            if is_weight:
+                weight_scale_node = scale_var_node
             graph.update_input_link(var_node, dequant_var_node, op)
+        return weight_scale_node

     def _transform_backward(self, graph, op):
         for var_node in op.inputs:
@@ -2610,11 +2619,71 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
         for var_node in op.inputs:
             if var_node.name() not in op.input_arg_names():
                 continue
-            name = var_node.name()
             if var_node.name() in self.persistable_vars:
                 has_weight = True
         return has_weight

+    def _quant_conv1d(self, graph, op):
+        # conv1d in inference is a combination of unsqueeze2 + conv2d
+        if ("conv2d" not in op.name()) or (
+            "unsqueeze2" not in op.input("Filter")[0]
+        ):
+            return
+        conv_weight_var_name = op.input("Filter")[0]
+        # unsqueeze2 and conv2d will share weight scale
+        weight_scale_node = None
+        # quant unsqueeze2
+        for _op in graph.all_op_nodes():
+            var_names = utils._get_op_output_var_names(_op)
+            if conv_weight_var_name in var_names and self._has_weight(_op):
+                weight_scale_node = self._transform_forward(graph, _op)
+        # insert qdq before conv2d
+        for var_node in op.inputs:
+            quant_bits = (
+                self._weight_bits
+                if var_node.name() == conv_weight_var_name
+                else self._activation_bits
+            )
+            quant_type = (
+                self._weight_quantize_type
+                if var_node.name() == conv_weight_var_name
+                else self._activation_quantize_type
+            )
+            quant_axis = -1
+            channel_wise = False
+            if quant_type == 'channel_wise_abs_max':
+                channel_wise = True
+                quant_axis = (
+                    1 if op.name() in utils._channelwise_quant_axis1_ops else 0
+                )
+            insert_quant_pass = InsertQuantizeLinear(
+                self._place,
+                self._scope,
+                quant_bits=quant_bits,
+                quant_axis=quant_axis,
+                channel_wise=channel_wise,
+                moving_rate=self._moving_rate,
+                is_test=self._is_test,
+            )
+            scale_var_node = (
+                weight_scale_node
+                if var_node.name() == conv_weight_var_name
+                else None
+            )
+            (
+                quant_var_node,
+                scale_var_node,
+            ) = insert_quant_pass.insert_quant_op(
+                graph,
+                var_node,
+                var_name=var_node.name(),
+                scale_var_node=scale_var_node,
+            )
+            dequant_var_node = insert_quant_pass.insert_dequant_op(
+                graph, quant_var_node, scale_var_node
+            )
+            graph.update_input_link(var_node, dequant_var_node, op)
+
     def apply(self, graph):
         """
         Quantize the graph for training process. According to weight and
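Background for `_quant_conv1d`: exported inference graphs lower `paddle.nn.Conv1D` to `unsqueeze2` + `conv2d` (with a trailing `squeeze2`), which is why the pass identifies conv1d by an `unsqueeze2`-produced filter and shares one weight scale across both ops. A rough NumPy analogy of that lowering:

```python
import numpy as np

x = np.random.rand(1, 4, 16).astype(np.float32)  # conv1d input, NCL layout
w = np.random.rand(8, 4, 3).astype(np.float32)   # conv1d filter (out, in, k)

x4d = x[:, :, np.newaxis, :]  # unsqueeze2: NCL -> NCHW with H == 1
w4d = w[:, :, np.newaxis, :]  # the filter gains a dummy height axis

print(x4d.shape, w4d.shape)   # (1, 4, 1, 16) (8, 4, 1, 3)
# conv2d then runs on the 4-D tensors and a trailing squeeze2 restores NCL,
# so quantizing conv1d means quantizing this unsqueeze2 + conv2d pair.
```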
@@ -2664,6 +2733,9 @@ class QuantizationTransformPassV2(QuantizationTransformPass):
                         op
                     ):
                         self._transform_forward(graph, op)
+                    else:  # op is not persistable
+                        # support conv1d quantization
+                        self._quant_conv1d(graph, op)
                 t.update()

         # The loop for renaming the inputs of backward op.
         for op in ops: