未验证 提交 1785605f 编写于 作者: J Jason 提交者: GitHub

Merge pull request #53 from SunAhong1993/syf_slim

add quant log
...@@ -19,6 +19,9 @@ from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization ...@@ -19,6 +19,9 @@ from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
import paddlex.utils.logging as logging import paddlex.utils.logging as logging
import paddle.fluid as fluid import paddle.fluid as fluid
import os import os
import re
import numpy as np
import time
class PaddleXPostTrainingQuantization(PostTrainingQuantization): class PaddleXPostTrainingQuantization(PostTrainingQuantization):
...@@ -123,28 +126,37 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization): ...@@ -123,28 +126,37 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
the program of quantized model. the program of quantized model.
''' '''
self._preprocess() self._preprocess()
batch_ct = 0
for data in self._data_loader():
batch_ct += 1
if self._batch_nums and batch_ct >= self._batch_nums:
break
batch_id = 0 batch_id = 0
logging.info("Start to run batch!")
for data in self._data_loader(): for data in self._data_loader():
start = time.time()
self._executor.run( self._executor.run(
program=self._program, program=self._program,
feed=data, feed=data,
fetch_list=self._fetch_list, fetch_list=self._fetch_list,
return_numpy=False) return_numpy=False)
self._sample_data(batch_id) self._sample_data(batch_id)
end = time.time()
if batch_id % 5 == 0: logging.debug('[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
logging.info("run batch: {}".format(batch_id)) str(batch_id + 1),
str(batch_ct),
str(end-start)))
batch_id += 1 batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums: if self._batch_nums and batch_id >= self._batch_nums:
break break
logging.info("all run batch: ".format(batch_id)) logging.info("All run batch: ".format(batch_id))
logging.info("calculate scale factor ...") logging.info("Calculate scale factor ...")
self._calculate_scale_factor() self._calculate_scale_factor()
logging.info("update the program ...") logging.info("Update the program ...")
self._update_program() self._update_program()
logging.info("Save ...")
self._save_output_scale() self._save_output_scale()
logging.info("Finish quant!")
return self._program return self._program
def save_quantized_model(self, save_model_path): def save_quantized_model(self, save_model_path):
...@@ -221,3 +233,69 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization): ...@@ -221,3 +233,69 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
for var in self._program.list_vars(): for var in self._program.list_vars():
if var.name in self._quantized_act_var_name: if var.name in self._quantized_act_var_name:
var.persistable = True var.persistable = True
def _calculate_scale_factor(self):
'''
Calculate the scale factor of quantized variables.
'''
# apply channel_wise_abs_max quantization for weights
ct = 1
for var_name in self._quantized_weight_var_name:
start = time.time()
data = self._sampling_data[var_name]
scale_factor_per_channel = []
for i in range(data.shape[0]):
abs_max_value = np.max(np.abs(data[i]))
scale_factor_per_channel.append(abs_max_value)
self._quantized_var_scale_factor[
var_name] = scale_factor_per_channel
end = time.time()
logging.debug('[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.format(
str(ct),
str(len(self._quantized_weight_var_name)),
str(end-start)))
ct += 1
ct = 1
# apply kl quantization for activation
if self._is_use_cache_file:
for var_name in self._quantized_act_var_name:
start = time.time()
sampling_data = []
filenames = [f for f in os.listdir(self._cache_dir) \
if re.match(var_name + '_[0-9]+.npy', f)]
for filename in filenames:
file_path = os.path.join(self._cache_dir, filename)
sampling_data.append(np.load(file_path))
os.remove(file_path)
sampling_data = np.concatenate(sampling_data)
if self._algo == "KL":
self._quantized_var_scale_factor[var_name] = \
self._get_kl_scaling_factor(np.abs(sampling_data))
else:
self._quantized_var_scale_factor[var_name] = \
np.max(np.abs(sampling_data))
end = time.time()
logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
str(ct),
str(len(self._quantized_act_var_name)),
str(end-start)))
ct += 1
else:
for var_name in self._quantized_act_var_name:
start = time.time()
self._sampling_data[var_name] = np.concatenate(
self._sampling_data[var_name])
if self._algo == "KL":
self._quantized_var_scale_factor[var_name] = \
self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
else:
self._quantized_var_scale_factor[var_name] = \
np.max(np.abs(self._sampling_data[var_name]))
end = time.time()
logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
str(ct),
str(len(self._quantized_act_var_name)),
str(end-start)))
ct += 1
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册