Commit 32c842bb authored by: S sunyanfang01

add quant log

Parent 2484756a
@@ -19,6 +19,9 @@ from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
import paddlex.utils.logging as logging
import paddle.fluid as fluid
import os
import re
import numpy as np
import datetime
class PaddleXPostTrainingQuantization(PostTrainingQuantization):
@@ -123,28 +126,37 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
the program of quantized model.
'''
self._preprocess()
batch_ct = 0
for data in self._data_loader():
batch_ct += 1
if self._batch_nums and batch_ct >= self._batch_nums:
break
batch_id = 0
logging.info("Start to run batch!")
for data in self._data_loader():
start = datetime.datetime.now()
self._executor.run(
program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False)
self._sample_data(batch_id)
if batch_id % 5 == 0:
logging.info("run batch: {}".format(batch_id))
end = datetime.datetime.now()
# total_seconds() gives the full duration; timedelta.microseconds alone
# would drop everything above one second and is microseconds, not ms.
logging.debug('[Run batch data] Batch={}/{}, time_each_batch={} ms.'.format(
    str(batch_id + 1),
    str(batch_ct),
    str((end - start).total_seconds() * 1000)))
batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums:
break
logging.info("all run batch: ".format(batch_id))
logging.info("calculate scale factor ...")
logging.info("All run batch: ".format(batch_id))
logging.info("Calculate scale factor ...")
self._calculate_scale_factor()
logging.info("update the program ...")
logging.info("Update the program ...")
self._update_program()
logging.info("Save ...")
self._save_output_scale()
logging.info("Finish quant!")
return self._program
def save_quantized_model(self, save_model_path):
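For reference, a minimal standalone sketch of the per-batch timing pattern introduced above (the `run_batches` helper and the toy loader are illustrative, not part of this commit); `total_seconds()` is used so durations longer than one second are reported correctly:

```python
import datetime
import logging

logging.basicConfig(level=logging.DEBUG)

def run_batches(data_loader, batch_nums=None):
    # Pre-count batches so progress can be reported as Batch=i/total,
    # mirroring the batch_ct loop added in this commit.
    batch_ct = sum(1 for _ in data_loader())
    if batch_nums:
        batch_ct = min(batch_ct, batch_nums)
    batch_id = 0
    for data in data_loader():
        start = datetime.datetime.now()
        _ = [x * 2 for x in data]  # stand-in for executor.run(...)
        end = datetime.datetime.now()
        elapsed_ms = (end - start).total_seconds() * 1000
        logging.debug('[Run batch data] Batch=%d/%d, time_each_batch=%.3f ms.',
                      batch_id + 1, batch_ct, elapsed_ms)
        batch_id += 1
        if batch_nums and batch_id >= batch_nums:
            break

# toy loader yielding three "batches"
run_batches(lambda: iter([[1, 2], [3, 4], [5, 6]]), batch_nums=2)
```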
@@ -221,3 +233,69 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
for var in self._program.list_vars():
if var.name in self._quantized_act_var_name:
var.persistable = True
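Marking activation variables persistable, as done just above, is what lets calibration read their tensors back out of the scope after each `executor.run`. A minimal sketch of that lookup, assuming the Paddle 1.x `fluid.global_scope()` API already used by this codebase (`sample_var` is a hypothetical helper, not part of the commit):

```python
import numpy as np
import paddle.fluid as fluid

def sample_var(var_name):
    # Persistable variables survive in the global scope after executor.run,
    # so calibration can pull their tensors out as numpy arrays.
    var = fluid.global_scope().find_var(var_name)
    if var is None:
        raise KeyError("variable {} not found in scope".format(var_name))
    return np.array(var.get_tensor())
```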
def _calculate_scale_factor(self):
'''
Calculate the scale factor of quantized variables.
'''
# apply channel_wise_abs_max quantization for weights
ct = 1
for var_name in self._quantized_weight_var_name:
start = datetime.datetime.now()
data = self._sampling_data[var_name]
scale_factor_per_channel = []
for i in range(data.shape[0]):
abs_max_value = np.max(np.abs(data[i]))
scale_factor_per_channel.append(abs_max_value)
self._quantized_var_scale_factor[
var_name] = scale_factor_per_channel
end = datetime.datetime.now()
logging.debug('[Calculate weight] Weight_id={}/{}, time_each_weight={} ms.'.format(
    str(ct),
    str(len(self._quantized_weight_var_name)),
    str((end - start).total_seconds() * 1000)))
ct += 1
ct = 1
# apply KL or abs_max quantization for activations
if self._is_use_cache_file:
for var_name in self._quantized_act_var_name:
start = datetime.datetime.now()
sampling_data = []
filenames = [f for f in os.listdir(self._cache_dir) \
    if re.match(re.escape(var_name) + '_[0-9]+\.npy$', f)]
for filename in filenames:
file_path = os.path.join(self._cache_dir, filename)
sampling_data.append(np.load(file_path))
os.remove(file_path)
sampling_data = np.concatenate(sampling_data)
if self._algo == "KL":
self._quantized_var_scale_factor[var_name] = \
self._get_kl_scaling_factor(np.abs(sampling_data))
else:
self._quantized_var_scale_factor[var_name] = \
np.max(np.abs(sampling_data))
end = datetime.datetime.now()
logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} ms.'.format(
    str(ct),
    str(len(self._quantized_act_var_name)),
    str((end - start).total_seconds() * 1000)))
ct += 1
else:
for var_name in self._quantized_act_var_name:
start = datetime.datetime.now()
self._sampling_data[var_name] = np.concatenate(
self._sampling_data[var_name])
if self._algo == "KL":
self._quantized_var_scale_factor[var_name] = \
self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
else:
self._quantized_var_scale_factor[var_name] = \
np.max(np.abs(self._sampling_data[var_name]))
end = datetime.datetime.now()
logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} ms.'.format(
    str(ct),
    str(len(self._quantized_act_var_name)),
    str((end - start).total_seconds() * 1000)))
ct += 1
\ No newline at end of file
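To make the two scale computations in `_calculate_scale_factor` concrete, here is a self-contained sketch on toy tensors (abs-max paths only; the KL branch would replace the plain maximum with an entropy-minimizing threshold search):

```python
import numpy as np

def channel_wise_abs_max(weights):
    # One scale per output channel (axis 0), as done for weights above.
    return [float(np.max(np.abs(weights[i]))) for i in range(weights.shape[0])]

def tensor_abs_max(activations):
    # A single scale for the whole activation tensor (the "abs_max" algo).
    return float(np.max(np.abs(activations)))

w = np.random.randn(4, 3, 3, 3)    # e.g. a conv weight with 4 output channels
a = np.random.randn(8, 4, 16, 16)  # concatenated sampled activation batches
print(channel_wise_abs_max(w))     # 4 per-channel scales
print(tensor_abs_max(a))           # 1 tensor-wide scale
```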