提交 d143f70a 编写于 作者: C cc 提交者: GitHub

Post_training_quantization support set quant 8/16 bits (#22492)

* post_training_quantization support set bits, test=develop

* up, test=develop
上级 de009152
...@@ -54,17 +54,19 @@ def _set_variable_data(scope, place, var_name, np_value): ...@@ -54,17 +54,19 @@ def _set_variable_data(scope, place, var_name, np_value):
class PostTrainingQuantization(object): class PostTrainingQuantization(object):
def __init__(self, def __init__(self,
executor, executor=None,
sample_generator, scope=None,
model_dir, model_dir=None,
model_filename=None, model_filename=None,
params_filename=None, params_filename=None,
sample_generator=None,
batch_size=10, batch_size=10,
batch_nums=None, batch_nums=None,
scope=None,
algo="KL", algo="KL",
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False, is_full_quantize=False,
weight_bits=8,
activation_bits=8,
is_use_cache_file=False, is_use_cache_file=False,
cache_dir="./temp_post_training"): cache_dir="./temp_post_training"):
''' '''
...@@ -76,9 +78,8 @@ class PostTrainingQuantization(object): ...@@ -76,9 +78,8 @@ class PostTrainingQuantization(object):
Args: Args:
executor(fluid.Executor): The executor to load, run and save the executor(fluid.Executor): The executor to load, run and save the
quantized model. quantized model.
sample_generator(Python Generator): The sample generator provides scope(fluid.Scope, optional): The scope of the program, use it to load
calibrate data for DataLoader, and it only returns a sample every and save variables. If scope=None, get scope by global_scope().
time.
model_dir(str): The path of the fp32 model that will be quantized, model_dir(str): The path of the fp32 model that will be quantized,
and the model and params files are under the path. and the model and params files are under the path.
model_filename(str, optional): The name of file to load the inference model_filename(str, optional): The name of file to load the inference
...@@ -88,12 +89,13 @@ class PostTrainingQuantization(object): ...@@ -88,12 +89,13 @@ class PostTrainingQuantization(object):
When all parameters were saved in a single binary file, set it When all parameters were saved in a single binary file, set it
as the real filename. If parameters were saved in separate files, as the real filename. If parameters were saved in separate files,
set it as 'None'. Default is 'None'. set it as 'None'. Default is 'None'.
sample_generator(Python Generator): The sample generator provides
calibrate data for DataLoader, and it only returns a sample every
time.
batch_size(int, optional): The batch size of DataLoader. Default is 10. batch_size(int, optional): The batch size of DataLoader. Default is 10.
batch_nums(int, optional): If batch_nums is not None, the number of batch_nums(int, optional): If batch_nums is not None, the number of
calibrate data is batch_size*batch_nums. If batch_nums is None, use calibrate data is batch_size*batch_nums. If batch_nums is None, use
all data provided by sample_generator as calibrate data. all data provided by sample_generator as calibrate data.
scope(fluid.Scope, optional): The scope of the program, use it to load
and save variables. If scope=None, get scope by global_scope().
algo(str, optional): If algo=KL, use KL-divergenc method to algo(str, optional): If algo=KL, use KL-divergenc method to
get the more precise scale factor. If algo='direct', use get the more precise scale factor. If algo='direct', use
abs_max methon to get the scale factor. Default is KL. abs_max methon to get the scale factor. Default is KL.
...@@ -104,6 +106,8 @@ class PostTrainingQuantization(object): ...@@ -104,6 +106,8 @@ class PostTrainingQuantization(object):
apply quantization to all supported quantizable op type. If set apply quantization to all supported quantizable op type. If set
is_full_quantized as False, only apply quantization to the op type is_full_quantized as False, only apply quantization to the op type
according to the input quantizable_op_type. according to the input quantizable_op_type.
weight_bits(int, optional): quantization bit number for weights.
activation_bits(int): quantization bit number for activation.
is_use_cache_file(bool, optional): If set is_use_cache_file as False, is_use_cache_file(bool, optional): If set is_use_cache_file as False,
all temp data will be saved in memory. If set is_use_cache_file as True, all temp data will be saved in memory. If set is_use_cache_file as True,
it will save temp data to disk. When the fp32 model is complex or it will save temp data to disk. When the fp32 model is complex or
...@@ -150,14 +154,20 @@ class PostTrainingQuantization(object): ...@@ -150,14 +154,20 @@ class PostTrainingQuantization(object):
ptq.quantize() ptq.quantize()
ptq.save_quantized_model(save_model_path) ptq.save_quantized_model(save_model_path)
''' '''
assert executor is not None, "The executor cannot be None."
assert model_dir is not None, "The model_dir cannot be None."
assert sample_generator is not None, \
"The sample_generator cannot be None."
self._executor = executor self._executor = executor
self._sample_generator = sample_generator self._scope = global_scope() if scope == None else scope
self._model_dir = model_dir self._model_dir = model_dir
self._model_filename = model_filename self._model_filename = model_filename
self._params_filename = params_filename self._params_filename = params_filename
self._sample_generator = sample_generator
self._batch_size = batch_size self._batch_size = batch_size
self._batch_nums = batch_nums self._batch_nums = batch_nums
self._scope = global_scope() if scope == None else scope
self._algo = algo self._algo = algo
self._is_use_cache_file = is_use_cache_file self._is_use_cache_file = is_use_cache_file
self._cache_dir = cache_dir self._cache_dir = cache_dir
...@@ -604,7 +614,7 @@ class WeightQuantization(object): ...@@ -604,7 +614,7 @@ class WeightQuantization(object):
save_model_filename=None, save_model_filename=None,
save_params_filename=None, save_params_filename=None,
quantizable_op_type=["conv2d", "mul"], quantizable_op_type=["conv2d", "mul"],
quantize_weight_bits=8, weight_bits=8,
threshold_rate=0.0): threshold_rate=0.0):
''' '''
In order to reduce the size of model, this api quantizes the weight In order to reduce the size of model, this api quantizes the weight
...@@ -624,8 +634,8 @@ class WeightQuantization(object): ...@@ -624,8 +634,8 @@ class WeightQuantization(object):
that will be quantized, and the quantized ops should be that will be quantized, and the quantized ops should be
contained in ["conv2d", "depthwise_conv2d", "mul"]. contained in ["conv2d", "depthwise_conv2d", "mul"].
Default is ["conv2d","mul"]. Default is ["conv2d","mul"].
quantize_weight_bits(int, optional): The bits for the quantized weight_bits(int, optional): The bits for the quantized weight,
weight, and it should be 8 or 16. Default is 8. and it should be 8 or 16. Default is 8.
threshold_rate(float, optional): This api uses abs_max methd to threshold_rate(float, optional): This api uses abs_max methd to
quantize the weight from float32 to int8/16, and the abs max quantize the weight from float32 to int8/16, and the abs max
value is important for quantization diff. When the abs_max value is important for quantization diff. When the abs_max
...@@ -637,10 +647,10 @@ class WeightQuantization(object): ...@@ -637,10 +647,10 @@ class WeightQuantization(object):
assert op_type in self._supported_quantizable_op_type, \ assert op_type in self._supported_quantizable_op_type, \
"input error:" + op_type + \ "input error:" + op_type + \
" is not supported for weight quantization." " is not supported for weight quantization."
assert quantize_weight_bits in [8, 16], \ assert weight_bits in [8, 16], \
"input error: quantize_weight_bits should be 8 or 16." "input error: weight_bits should be 8 or 16."
quantize_range = (1 << (quantize_weight_bits - 1)) - 1 quantize_range = (1 << (weight_bits - 1)) - 1
save_weight_dtype = np.int8 if quantize_weight_bits == 8 else np.int16 save_weight_dtype = np.int8 if weight_bits == 8 else np.int16
place = core.CPUPlace() place = core.CPUPlace()
exe = Executor(place) exe = Executor(place)
...@@ -677,8 +687,7 @@ class WeightQuantization(object): ...@@ -677,8 +687,7 @@ class WeightQuantization(object):
_set_variable_data(scope, place, var_name, _set_variable_data(scope, place, var_name,
quantized_var_tensor_data) quantized_var_tensor_data)
op._set_attr(var_name + "_quant_scale", [scale]) op._set_attr(var_name + "_quant_scale", [scale])
op._set_attr('quantize_weight_bits', op._set_attr('quantize_weight_bits', weight_bits)
quantize_weight_bits)
io.save_inference_model( io.save_inference_model(
dirname=save_model_dir, dirname=save_model_dir,
......
...@@ -42,8 +42,8 @@ class TestWeightQuantization(unittest.TestCase): ...@@ -42,8 +42,8 @@ class TestWeightQuantization(unittest.TestCase):
zip_path) zip_path)
os.system(cmd) os.system(cmd)
def run_test(self, model_name, model_data_url, model_data_md5, def run_test(self, model_name, model_data_url, model_data_md5, weight_bits,
quantize_weight_bits, quantizable_op_type, threshold_rate): quantizable_op_type, threshold_rate):
model_dir = self.download_model(model_name, model_data_url, model_dir = self.download_model(model_name, model_data_url,
model_data_md5) model_data_md5)
...@@ -51,11 +51,11 @@ class TestWeightQuantization(unittest.TestCase): ...@@ -51,11 +51,11 @@ class TestWeightQuantization(unittest.TestCase):
timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
save_model_dir = os.path.join( save_model_dir = os.path.join(
os.getcwd(), os.getcwd(),
model_name + "_wq_" + str(quantize_weight_bits) + "_" + timestamp) model_name + "_wq_" + str(weight_bits) + "_" + timestamp)
weight_quant = WeightQuantization(model_dir=model_dir + "/model") weight_quant = WeightQuantization(model_dir=model_dir + "/model")
weight_quant.quantize_weight_to_int( weight_quant.quantize_weight_to_int(
save_model_dir=save_model_dir, save_model_dir=save_model_dir,
quantize_weight_bits=quantize_weight_bits, weight_bits=weight_bits,
quantizable_op_type=quantizable_op_type, quantizable_op_type=quantizable_op_type,
threshold_rate=threshold_rate) threshold_rate=threshold_rate)
print("finish weight quantization for " + model_name + "\n") print("finish weight quantization for " + model_name + "\n")
...@@ -73,18 +73,18 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization): ...@@ -73,18 +73,18 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
model_data_md5 = "13892b0716d26443a8cdea15b3c6438b" model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"
def test_weight_quantization_mobilenetv1_8bit(self): def test_weight_quantization_mobilenetv1_8bit(self):
quantize_weight_bits = 8 weight_bits = 8
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul'] quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
threshold_rate = 0.0 threshold_rate = 0.0
self.run_test(self.model_name, self.model_data_url, self.model_data_md5, self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
quantize_weight_bits, quantizable_op_type, threshold_rate) weight_bits, quantizable_op_type, threshold_rate)
def test_weight_quantization_mobilenetv1_16bit(self): def test_weight_quantization_mobilenetv1_16bit(self):
quantize_weight_bits = 16 weight_bits = 16
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul'] quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
threshold_rate = 1e-9 threshold_rate = 1e-9
self.run_test(self.model_name, self.model_data_url, self.model_data_md5, self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
quantize_weight_bits, quantizable_op_type, threshold_rate) weight_bits, quantizable_op_type, threshold_rate)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册