提交 d143f70a 编写于 作者: C cc 提交者: GitHub

Post_training_quantization support set quant 8/16 bits (#22492)

* post_training_quantization support set bits, test=develop

* up, test=develop
上级 de009152
......@@ -54,17 +54,19 @@ def _set_variable_data(scope, place, var_name, np_value):
class PostTrainingQuantization(object):
def __init__(self,
executor,
sample_generator,
model_dir,
executor=None,
scope=None,
model_dir=None,
model_filename=None,
params_filename=None,
sample_generator=None,
batch_size=10,
batch_nums=None,
scope=None,
algo="KL",
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
weight_bits=8,
activation_bits=8,
is_use_cache_file=False,
cache_dir="./temp_post_training"):
'''
......@@ -76,9 +78,8 @@ class PostTrainingQuantization(object):
Args:
executor(fluid.Executor): The executor to load, run and save the
quantized model.
sample_generator(Python Generator): The sample generator provides
calibrate data for DataLoader, and it only returns a sample every
time.
scope(fluid.Scope, optional): The scope of the program, use it to load
and save variables. If scope=None, get scope by global_scope().
model_dir(str): The path of the fp32 model that will be quantized,
and the model and params files are under the path.
model_filename(str, optional): The name of file to load the inference
......@@ -88,12 +89,13 @@ class PostTrainingQuantization(object):
When all parameters were saved in a single binary file, set it
as the real filename. If parameters were saved in separate files,
set it as 'None'. Default is 'None'.
sample_generator(Python Generator): The sample generator provides
calibrate data for DataLoader, and it only returns a sample every
time.
batch_size(int, optional): The batch size of DataLoader. Default is 10.
batch_nums(int, optional): If batch_nums is not None, the number of
calibrate data is batch_size*batch_nums. If batch_nums is None, use
all data provided by sample_generator as calibrate data.
scope(fluid.Scope, optional): The scope of the program, use it to load
and save variables. If scope=None, get scope by global_scope().
algo(str, optional): If algo=KL, use KL-divergenc method to
get the more precise scale factor. If algo='direct', use
abs_max methon to get the scale factor. Default is KL.
......@@ -104,6 +106,8 @@ class PostTrainingQuantization(object):
apply quantization to all supported quantizable op type. If set
is_full_quantized as False, only apply quantization to the op type
according to the input quantizable_op_type.
weight_bits(int, optional): quantization bit number for weights.
activation_bits(int): quantization bit number for activation.
is_use_cache_file(bool, optional): If set is_use_cache_file as False,
all temp data will be saved in memory. If set is_use_cache_file as True,
it will save temp data to disk. When the fp32 model is complex or
......@@ -150,14 +154,20 @@ class PostTrainingQuantization(object):
ptq.quantize()
ptq.save_quantized_model(save_model_path)
'''
assert executor is not None, "The executor cannot be None."
assert model_dir is not None, "The model_dir cannot be None."
assert sample_generator is not None, \
"The sample_generator cannot be None."
self._executor = executor
self._sample_generator = sample_generator
self._scope = global_scope() if scope == None else scope
self._model_dir = model_dir
self._model_filename = model_filename
self._params_filename = params_filename
self._sample_generator = sample_generator
self._batch_size = batch_size
self._batch_nums = batch_nums
self._scope = global_scope() if scope == None else scope
self._algo = algo
self._is_use_cache_file = is_use_cache_file
self._cache_dir = cache_dir
......@@ -604,7 +614,7 @@ class WeightQuantization(object):
save_model_filename=None,
save_params_filename=None,
quantizable_op_type=["conv2d", "mul"],
quantize_weight_bits=8,
weight_bits=8,
threshold_rate=0.0):
'''
In order to reduce the size of model, this api quantizes the weight
......@@ -624,8 +634,8 @@ class WeightQuantization(object):
that will be quantized, and the quantized ops should be
contained in ["conv2d", "depthwise_conv2d", "mul"].
Default is ["conv2d","mul"].
quantize_weight_bits(int, optional): The bits for the quantized
weight, and it should be 8 or 16. Default is 8.
weight_bits(int, optional): The bits for the quantized weight,
and it should be 8 or 16. Default is 8.
threshold_rate(float, optional): This api uses abs_max methd to
quantize the weight from float32 to int8/16, and the abs max
value is important for quantization diff. When the abs_max
......@@ -637,10 +647,10 @@ class WeightQuantization(object):
assert op_type in self._supported_quantizable_op_type, \
"input error:" + op_type + \
" is not supported for weight quantization."
assert quantize_weight_bits in [8, 16], \
"input error: quantize_weight_bits should be 8 or 16."
quantize_range = (1 << (quantize_weight_bits - 1)) - 1
save_weight_dtype = np.int8 if quantize_weight_bits == 8 else np.int16
assert weight_bits in [8, 16], \
"input error: weight_bits should be 8 or 16."
quantize_range = (1 << (weight_bits - 1)) - 1
save_weight_dtype = np.int8 if weight_bits == 8 else np.int16
place = core.CPUPlace()
exe = Executor(place)
......@@ -677,8 +687,7 @@ class WeightQuantization(object):
_set_variable_data(scope, place, var_name,
quantized_var_tensor_data)
op._set_attr(var_name + "_quant_scale", [scale])
op._set_attr('quantize_weight_bits',
quantize_weight_bits)
op._set_attr('quantize_weight_bits', weight_bits)
io.save_inference_model(
dirname=save_model_dir,
......
......@@ -42,8 +42,8 @@ class TestWeightQuantization(unittest.TestCase):
zip_path)
os.system(cmd)
def run_test(self, model_name, model_data_url, model_data_md5,
quantize_weight_bits, quantizable_op_type, threshold_rate):
def run_test(self, model_name, model_data_url, model_data_md5, weight_bits,
quantizable_op_type, threshold_rate):
model_dir = self.download_model(model_name, model_data_url,
model_data_md5)
......@@ -51,11 +51,11 @@ class TestWeightQuantization(unittest.TestCase):
timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
save_model_dir = os.path.join(
os.getcwd(),
model_name + "_wq_" + str(quantize_weight_bits) + "_" + timestamp)
model_name + "_wq_" + str(weight_bits) + "_" + timestamp)
weight_quant = WeightQuantization(model_dir=model_dir + "/model")
weight_quant.quantize_weight_to_int(
save_model_dir=save_model_dir,
quantize_weight_bits=quantize_weight_bits,
weight_bits=weight_bits,
quantizable_op_type=quantizable_op_type,
threshold_rate=threshold_rate)
print("finish weight quantization for " + model_name + "\n")
......@@ -73,18 +73,18 @@ class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"
def test_weight_quantization_mobilenetv1_8bit(self):
quantize_weight_bits = 8
weight_bits = 8
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
threshold_rate = 0.0
self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
quantize_weight_bits, quantizable_op_type, threshold_rate)
weight_bits, quantizable_op_type, threshold_rate)
def test_weight_quantization_mobilenetv1_16bit(self):
quantize_weight_bits = 16
weight_bits = 16
quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
threshold_rate = 1e-9
self.run_test(self.model_name, self.model_data_url, self.model_data_md5,
quantize_weight_bits, quantizable_op_type, threshold_rate)
weight_bits, quantizable_op_type, threshold_rate)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册