未验证 提交 40aa14ec 编写于 作者: C cc 提交者: GitHub

Weight quantization support channel_wise_abs_max method to achieve higher accuracy (#23629)

* Weight quantization support channel_wise_abs_max method to achieve higher accuracy
上级 1747bbdb
......@@ -57,6 +57,14 @@ def _set_variable_data(scope, place, var_name, np_value):
tensor.set(np_value, place)
def _all_persistable_var_names(program):
    """Collect the names of every persistable variable in ``program``.

    Args:
        program: Object exposing ``list_vars()``; every variable it yields
            must provide the ``persistable`` and ``name`` attributes.

    Returns:
        list: Names of the variables whose ``persistable`` flag is truthy,
        in the order ``list_vars()`` yields them.
    """
    return [var.name for var in program.list_vars() if var.persistable]
class PostTrainingQuantization(object):
"""
Utilizing post training quantization method to quantize the FP32 model,
......@@ -365,11 +373,7 @@ class PostTrainingQuantization(object):
else:
self._quantized_act_var_name.add(var_name)
persistable_var_names = []
for var in self._program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
persistable_var_names = _all_persistable_var_names(self._program)
for op in self._program.global_block().ops:
op_type = op.type
# For quantized ops, sample inputs and outputs
......@@ -738,6 +742,7 @@ class PostTrainingQuantization(object):
class WeightQuantization(object):
_supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
_supported_weight_quantize_type = ['channel_wise_abs_max', 'abs_max']
def __init__(self, model_dir, model_filename=None, params_filename=None):
'''
......@@ -765,6 +770,8 @@ class WeightQuantization(object):
save_params_filename=None,
quantizable_op_type=["conv2d", "mul"],
weight_bits=8,
weight_quantize_type="channel_wise_abs_max",
generate_test_model=False,
threshold_rate=0.0):
'''
In order to reduce the size of model, this api quantizes the weight
......@@ -786,6 +793,13 @@ class WeightQuantization(object):
Default is ["conv2d","mul"].
weight_bits(int, optional): The bits for the quantized weight,
and it should be 8 or 16. Default is 8.
weight_quantize_type(str, optional): quantization type for weights,
support 'channel_wise_abs_max' and 'abs_max'. Set it as
'channel_wise_abs_max', the accuracy performs better.
generate_test_model(bool, optional): If set generate_test_model
as True, it saves a fake quantized model, in which the weights
are quantized and dequantized. We can use PaddlePaddle to load
the fake quantized model and test the accuracy on GPU or CPU.
threshold_rate(float, optional): This api uses abs_max method to
quantize the weight from float32 to int8/16, and the abs max
value is important for quantization diff. When the abs_max
......@@ -795,13 +809,35 @@ class WeightQuantization(object):
'''
for op_type in quantizable_op_type:
assert op_type in self._supported_quantizable_op_type, \
"input error:" + op_type + \
"Input error:" + op_type + \
" is not supported for weight quantization."
assert weight_bits in [8, 16], \
"input error: weight_bits should be 8 or 16."
quantize_range = (1 << (weight_bits - 1)) - 1
save_weight_dtype = np.int8 if weight_bits == 8 else np.int16
"Input error: weight_bits should be 8 or 16."
assert weight_quantize_type in self._supported_weight_quantize_type, \
"Input error: weight_quantize_type should in {}".format(
self._supported_weight_quantize_type)
quantized_model_dir = os.path.join(save_model_dir, "quantized_model")
self._quantize_weight_to_int(quantized_model_dir, save_model_filename,
save_params_filename, quantizable_op_type,
weight_bits, weight_quantize_type, False,
threshold_rate)
if generate_test_model:
test_model_dir = os.path.join(save_model_dir, "test_model")
self._quantize_weight_to_int(
test_model_dir, save_model_filename, save_params_filename,
quantizable_op_type, weight_bits, weight_quantize_type, True,
threshold_rate)
def _quantize_weight_to_int(self, save_model_dir, save_model_filename,
save_params_filename, quantizable_op_type,
weight_bits, weight_quantize_type, for_test,
threshold_rate):
"""
Generate quantized model or fake quantized model.
"""
# Load model
place = core.CPUPlace()
exe = Executor(place)
scope = global_scope()
......@@ -811,33 +847,25 @@ class WeightQuantization(object):
model_filename=self._model_filename,
params_filename=self._params_filename)
persistable_var_names = []
for var in program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
for op in program.global_block().ops:
if op.type in quantizable_op_type:
for var_name in op.input_arg_names:
if var_name in persistable_var_names:
var_tensor_data = _load_variable_data(scope, var_name)
if abs(threshold_rate) < 1e-10:
threshold_value = np.max(np.abs(var_tensor_data))
else:
threshold_value = self._calculate_threshold(\
var_tensor_data, threshold_rate)
var_tensor_data[var_tensor_data >
threshold_value] = threshold_value
var_tensor_data[var_tensor_data <
-threshold_value] = -threshold_value
scale = threshold_value / quantize_range
quantized_var_tensor_data = \
np.around(var_tensor_data / scale)
quantized_var_tensor_data = \
quantized_var_tensor_data.astype(save_weight_dtype)
_set_variable_data(scope, place, var_name,
quantized_var_tensor_data)
op._set_attr(var_name + "_quant_scale", [scale])
op._set_attr('quantize_weight_bits', weight_bits)
quantized_ops = []
for index in range(program.num_blocks):
block = program.block(index)
for op in block.ops:
if op.type in quantizable_op_type:
quantized_ops.append(op)
# Quantize weights
persistable_var_names = _all_persistable_var_names(program)
for op in quantized_ops:
for var_name in op.input_arg_names:
if var_name in persistable_var_names:
if weight_quantize_type == "abs_max":
self._weight_abs_max_quantization(
scope, place, weight_bits, threshold_rate, op,
var_name, for_test)
elif weight_quantize_type == "channel_wise_abs_max":
self._weight_channel_wise_abs_max_quantization(
scope, place, weight_bits, op, var_name, for_test)
io.save_inference_model(
dirname=save_model_dir,
......@@ -848,6 +876,137 @@ class WeightQuantization(object):
model_filename=save_model_filename,
params_filename=save_params_filename)
def _weight_abs_max_quantization(self, scope, place, weight_bits,
                                 threshold_rate, op, var_name, for_test):
    '''
    Use abs_max method to quantize weight.

    A single scale (threshold / quantize_range) is used for the whole
    weight tensor. The tensor stored back into the scope is either the
    integer tensor (for_test is False) or the quantized-then-dequantized
    float32 tensor (for_test is True).

    Args:
        scope: Passed through to _load_variable_data / _set_variable_data.
        place: Passed through to _set_variable_data.
        weight_bits(int): Bit width; 8 selects np.int8, any other value
            selects np.int16.
        threshold_rate(float): When its magnitude is below 1e-10 the
            abs max of the weight is the threshold; otherwise the
            threshold comes from self._calculate_threshold and the
            weight is clipped to [-threshold, threshold].
        op: Operator on which the quantization attributes are recorded.
        var_name(str): Name of the weight variable.
        for_test(bool): Whether to store the fake quantized weight.
    '''
    quantize_range = (1 << (weight_bits - 1)) - 1
    save_weight_dtype = np.int8 if weight_bits == 8 else np.int16

    # Get quantized scale and weight data
    weight_data = _load_variable_data(scope, var_name)
    if abs(threshold_rate) < 1e-10:
        # The threshold is the abs max itself, so no clipping is needed.
        threshold_value = np.max(np.abs(weight_data))
    else:
        threshold_value = self._calculate_threshold(weight_data,
                                                    threshold_rate)
        weight_data[weight_data > threshold_value] = threshold_value
        weight_data[weight_data < -threshold_value] = -threshold_value
    scale = threshold_value / quantize_range

    if scale == 0:
        # A zero threshold means every kept value is zero; dividing by the
        # zero scale would yield NaN before the integer cast.
        quantized_weight_data = np.zeros_like(
            weight_data, dtype=save_weight_dtype)
    else:
        quantized_weight_data = \
            np.around(weight_data / scale).astype(save_weight_dtype)

    # Set weight data
    if for_test:
        dequantized_weight_data = \
            (quantized_weight_data * scale).astype(np.float32)
        _set_variable_data(scope, place, var_name, dequantized_weight_data)
    else:
        _set_variable_data(scope, place, var_name, quantized_weight_data)

    # Save info
    op._set_attr('quantization_type', 'post_weight_abs_max')
    op._set_attr('quantize_weight_bits', weight_bits)
    op._set_attr(var_name + "_quant_scale", [scale])  # Save as list
def _weight_channel_wise_abs_max_quantization(
        self, scope, place, weight_bits, op, var_name, for_test):
    '''
    Use channel_wise_abs_max method to quantize weight.

    The per-channel helpers are selected from op.type: the mul helpers
    for "mul", the conv helpers for "conv2d" and "depthwise_conv2d".
    The tensor stored back into the scope is either the integer tensor
    (for_test is False) or the dequantized float32 tensor (for_test is
    True).

    Args:
        scope: Passed through to _load_variable_data / _set_variable_data.
        place: Passed through to _set_variable_data.
        weight_bits(int): Bit width; 8 selects np.int8, any other value
            selects np.int16.
        op: Operator on which the quantization attributes are recorded.
        var_name(str): Name of the weight variable.
        for_test(bool): Whether to store the fake quantized weight.

    Raises:
        ValueError: If op.type is not mul, conv2d or depthwise_conv2d.
    '''
    # Resolve the helpers first so an unsupported op fails with a clear
    # message instead of an unbound local variable further down.
    if op.type == "mul":
        quantize = self._mul_channel_wise_quantization
        dequantize = self._mul_channel_wise_dequantization
    elif op.type in ["conv2d", "depthwise_conv2d"]:
        quantize = self._conv_channel_wise_quantization
        dequantize = self._conv_channel_wise_dequantization
    else:
        message = op.type + " is not supported by weight quantization"
        _logger.error(message)
        raise ValueError(message)

    quantize_range = (1 << (weight_bits - 1)) - 1
    save_weight_dtype = np.int8 if weight_bits == 8 else np.int16

    # Get quantized scale and weight data
    weight_data = _load_variable_data(scope, var_name)
    scales, quantized_weight_data = quantize(weight_data, quantize_range,
                                             save_weight_dtype)

    # Set weight data
    if for_test:
        dequantized_weight_data = dequantize(quantized_weight_data, scales)
        _set_variable_data(scope, place, var_name, dequantized_weight_data)
    else:
        _set_variable_data(scope, place, var_name, quantized_weight_data)

    # Save info
    op._set_attr('quantization_type', 'post_weight_channel_wise_abs_max')
    op._set_attr('quantize_weight_bits', weight_bits)
    op._set_attr(var_name + "_quant_scale", scales)
def _conv_channel_wise_quantization(self, weight_data, quantize_range,
                                    save_weight_dtype):
    '''
    Get channel wise scale for the weights of conv2d and depthwise_conv2d,
    and quantize the weights.

    Channels are the slices along axis 0. The scale of a channel is its
    abs max divided by quantize_range.

    Args:
        weight_data(np.ndarray): Float weight, indexed by channel on axis 0.
        quantize_range(int): Largest magnitude of the quantized values.
        save_weight_dtype: Numpy integer dtype of the quantized weight.

    Returns:
        tuple: (scales, quantized_weight_data), where scales is a list
        with one scale per channel and quantized_weight_data has the
        shape of weight_data and dtype save_weight_dtype.
    '''
    scales = []
    quantized_weight_data = np.zeros_like(
        weight_data, dtype=save_weight_dtype)
    for i, channel_data in enumerate(weight_data):
        scale = np.max(np.abs(channel_data)) / quantize_range
        scales.append(scale)
        # An all-zero channel has a zero scale; dividing by it would give
        # NaN before the integer cast, so keep the pre-filled zeros.
        if scale != 0:
            quantized_weight_data[i] = \
                np.around(channel_data / scale).astype(save_weight_dtype)
    return scales, quantized_weight_data
def _conv_channel_wise_dequantization(self, quantized_weight_data, scales):
    '''
    For conv2d and depthwise_conv2d, dequantize the weights to fp32.

    Every slice along axis 0 is multiplied by the scale at the same index.

    Args:
        quantized_weight_data(np.ndarray): Quantized weight, indexed by
            channel on axis 0.
        scales(list): One scale per channel.

    Returns:
        np.ndarray: float32 array with the shape of quantized_weight_data.

    Raises:
        ValueError: If the number of scales differs from the number of
            channels.
    '''
    channel_num = quantized_weight_data.shape[0]
    scale_array = np.asarray(scales)
    if scale_array.shape != (channel_num, ):
        raise ValueError(
            "The number of scales should be equal to the number of "
            "channels of the quantized weight.")
    # Shape the scales as (channel_num, 1, ..., 1) so that they broadcast
    # over every channel slice at once.
    broadcast_shape = (channel_num, ) + \
        (1, ) * (quantized_weight_data.ndim - 1)
    dequantized_weight_data = \
        quantized_weight_data * scale_array.reshape(broadcast_shape)
    return dequantized_weight_data.astype(np.float32)
def _mul_channel_wise_quantization(self, weight_data, quantize_range,
                                   save_weight_dtype):
    '''
    Get channel wise scale for the weights of mul, and quantize the
    weights.

    Channels are the columns weight_data[:, i]; their number is taken
    from the last dimension. The scale of a channel is its abs max
    divided by quantize_range.

    Args:
        weight_data(np.ndarray): Float weight, indexed by channel as
            weight_data[:, i].
        quantize_range(int): Largest magnitude of the quantized values.
        save_weight_dtype: Numpy integer dtype of the quantized weight.

    Returns:
        tuple: (scales, quantized_weight_data), where scales is a list
        with one scale per channel and quantized_weight_data has the
        shape of weight_data and dtype save_weight_dtype.
    '''
    scales = []
    quantized_weight_data = np.zeros_like(
        weight_data, dtype=save_weight_dtype)
    channel_num = weight_data.shape[-1]
    for i in range(channel_num):
        channel_data = weight_data[:, i]
        scale = np.max(np.abs(channel_data)) / quantize_range
        scales.append(scale)
        # An all-zero channel has a zero scale; dividing by it would give
        # NaN before the integer cast, so keep the pre-filled zeros.
        if scale != 0:
            quantized_weight_data[:, i] = \
                np.around(channel_data / scale).astype(save_weight_dtype)
    return scales, quantized_weight_data
def _mul_channel_wise_dequantization(self, quantized_weight_data, scales):
    '''
    For mul, dequantize the weights to fp32.

    Every column quantized_weight_data[:, i] is multiplied by scales[i].

    Args:
        quantized_weight_data(np.ndarray): Quantized weight with at least
            two dimensions, indexed by channel on axis 1.
        scales(list): One scale per channel.

    Returns:
        np.ndarray: float32 array with the shape of quantized_weight_data.

    Raises:
        ValueError: If the weight has fewer than two dimensions or the
            number of scales differs from the number of channels.
    '''
    if quantized_weight_data.ndim < 2:
        raise ValueError(
            "The quantized weight of mul should have at least 2 dimensions.")
    channel_num = quantized_weight_data.shape[1]
    scale_array = np.asarray(scales)
    if scale_array.shape != (channel_num, ):
        raise ValueError(
            "The number of scales should be equal to the number of "
            "channels of the quantized weight.")
    # Shape the scales as (1, channel_num, 1, ..., 1) so that they
    # broadcast along axis 1, matching the [:, i] channel indexing.
    broadcast_shape = (1, channel_num) + \
        (1, ) * (quantized_weight_data.ndim - 2)
    dequantized_weight_data = \
        quantized_weight_data * scale_array.reshape(broadcast_shape)
    return dequantized_weight_data.astype(np.float32)
def _calculate_threshold(self, input, threshold_rate, histogram_bins=5000):
input_abs = np.abs(input)
hist, hist_edeges = np.histogram(
......
......@@ -43,7 +43,8 @@ class TestWeightQuantization(unittest.TestCase):
os.system(cmd)
def run_test(self, model_name, model_data_url, model_data_md5, weight_bits,
quantizable_op_type, threshold_rate):
quantizable_op_type, weight_quantize_type, generate_test_model,
threshold_rate):
model_dir = self.download_model(model_name, model_data_url,
model_data_md5)
......@@ -57,6 +58,8 @@ class TestWeightQuantization(unittest.TestCase):
save_model_dir=save_model_dir,
weight_bits=weight_bits,
quantizable_op_type=quantizable_op_type,
weight_quantize_type=weight_quantize_type,
generate_test_model=generate_test_model,
threshold_rate=threshold_rate)
print("finish weight quantization for " + model_name + "\n")
......@@ -72,19 +75,45 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
model_data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz"
model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"
def test_weight_quantization_mobilenetv1_8bit(self):
def test_weight_quantization_mobilenetv1_8bit_abs_max(self):
weight_bits = 8
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
weight_quantize_type = "abs_max"
generate_test_model = True
threshold_rate = 0.0
self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
weight_bits, quantizable_op_type, threshold_rate)
weight_bits, quantizable_op_type, weight_quantize_type,
generate_test_model, threshold_rate)
def test_weight_quantization_mobilenetv1_16bit(self):
def test_weight_quantization_mobilenetv1_8bit_channel_wise_abs_max(self):
    """Run the MobileNetV1 case: 8-bit channel_wise_abs_max weights, with
    generate_test_model enabled and a zero threshold_rate."""
    self.run_test(
        self.model_name,
        self.model_data_url,
        self.model_data_md5,
        8,  # weight_bits
        ['conv2d', 'depthwise_conv2d', 'mul'],  # quantizable_op_type
        "channel_wise_abs_max",  # weight_quantize_type
        True,  # generate_test_model
        0.0)  # threshold_rate
def test_weight_quantization_mobilenetv1_16bit_abs_max(self):
    """Run the MobileNetV1 case: 16-bit abs_max weights, with
    generate_test_model disabled and a threshold_rate of 1e-9."""
    self.run_test(
        self.model_name,
        self.model_data_url,
        self.model_data_md5,
        16,  # weight_bits
        ['conv2d', 'depthwise_conv2d', 'mul'],  # quantizable_op_type
        "abs_max",  # weight_quantize_type
        False,  # generate_test_model
        1e-9)  # threshold_rate
def test_weight_quantization_mobilenetv1_16bit_channel_wise_abs_max(self):
weight_bits = 16
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
weight_quantize_type = "channel_wise_abs_max"
generate_test_model = False
threshold_rate = 1e-9
self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
weight_bits, quantizable_op_type, threshold_rate)
weight_bits, quantizable_op_type, weight_quantize_type,
generate_test_model, threshold_rate)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册