Commit 0b9a4c4c authored by sunyanfang01

fix the post quantization

Parent f7d2a00e
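
This commit reworks PaddleXPostTrainingQuantization to follow the PostTrainingQuantization internals of paddlepaddle>=1.8.0: weight and activation thresholds are tracked separately (abs_max vs. KL), and the sampling and program-update steps now branch on the chosen algo. As a rough orientation, a minimal sketch of how the class in this diff is driven; the constructor keywords, the import path, and the dataset/program objects are assumptions inferred from the attributes visible below, not a verified API:

import paddle.fluid as fluid
# assumed import path, based on the relative import used in BaseAPI
from paddlex.cv.models.slim.post_quantization import PaddleXPostTrainingQuantization

place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
exe = fluid.Executor(place)
post_quant = PaddleXPostTrainingQuantization(
    executor=exe,
    dataset=eval_dataset,    # assumed: a PaddleX dataset exposing generator(batch_size, drop_last)
    program=test_program,    # assumed: the inference program to be quantized
    inputs=test_inputs,      # assumed: dict of feed variables
    outputs=test_outputs,    # assumed: dict of fetch variables
    batch_size=8,
    batch_nums=10,
    algo='KL',               # one of 'KL', 'abs_max', 'min_max' (validated in __init__)
    is_use_cache_file=True,
    cache_dir='./temp')
post_quant.quantize()
post_quant.save_quantized_model('./quant_model')
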
......@@ -15,7 +15,6 @@
from __future__ import absolute_import
import paddle.fluid as fluid
import os
import sys
import numpy as np
import time
import math
......@@ -139,9 +138,10 @@ class BaseAPI:
dataset.num_samples = batch_size * batch_num
try:
from .slim.post_quantization import PaddleXPostTrainingQuantization
PaddleXPostTrainingQuantization._collect_target_varnames
except:
raise Exception(
"Model Quantization is not available, try to upgrade your paddlepaddle>=1.7.0"
"Model Quantization is not available, try to upgrade your paddlepaddle>=1.8.0"
)
is_use_cache_file = True
if cache_dir is None:
......@@ -252,9 +252,6 @@ class BaseAPI:
del self.init_params['self']
if '__class__' in self.init_params:
del self.init_params['__class__']
if 'model_name' in self.init_params:
del self.init_params['model_name']
info['_init_params'] = self.init_params
info['_Attributes']['num_classes'] = self.num_classes
......@@ -375,8 +372,6 @@ class BaseAPI:
use_vdl=False,
early_stop=False,
early_stop_patience=5):
if train_dataset.num_samples < train_batch_size:
raise Exception('The amount of training dataset must be larger than the batch size.')
if not osp.isdir(save_dir):
if osp.exists(save_dir):
os.remove(save_dir)
......@@ -434,7 +429,9 @@ class BaseAPI:
if use_vdl:
# VisualDL component
log_writer = LogWriter(vdl_logdir)
log_writer = LogWriter(vdl_logdir, sync_cycle=20)
train_step_component = OrderedDict()
eval_component = OrderedDict()
thresh = 0.0001
if early_stop:
......@@ -472,7 +469,13 @@ class BaseAPI:
if use_vdl:
for k, v in step_metrics.items():
log_writer.add_scalar('Metrics/Training(Step): {}'.format(k), v, num_steps)
if k not in train_step_component.keys():
with log_writer.mode('Each_Step_while_Training') as step_logger:
    train_step_component[k] = step_logger.scalar('Training: {}'.format(k))
train_step_component[k].add_record(num_steps, v)
# Estimate the remaining time
avg_step_time = np.mean(time_stat)
......@@ -533,7 +536,12 @@ class BaseAPI:
if isinstance(v, np.ndarray):
if v.size > 1:
continue
log_writer.add_scalar("Metrics/Eval(Epoch): {}".format(k), v, i+1)
if k not in eval_component:
with log_writer.mode('Each_Epoch_on_Eval_Data') as eval_logger:
    eval_component[k] = eval_logger.scalar('Evaluation: {}'.format(k))
eval_component[k].add_record(i + 1, v)
self.save_model(save_dir=current_save_dir)
time_eval_one_epoch = time.time() - eval_epoch_start_time
eval_epoch_start_time = time.time()
......
......@@ -14,7 +14,7 @@
from paddle.fluid.contrib.slim.quantization.quantization_pass import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization.quantization_pass import AddQuantDequantPass
from paddle.fluid.contrib.slim.quantization.quantization_pass import _op_real_in_out_name
from paddle.fluid.contrib.slim.quantization.quantization_pass import _out_scale_op_list
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
import paddlex.utils.logging as logging
import paddle.fluid as fluid
......@@ -44,7 +44,6 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
fp32 model. It uses calibration data to calculate the scale factors of
quantized variables, and inserts fake quant/dequant ops to obtain the
quantized model.
Args:
executor(fluid.Executor): The executor to load, run and save the
quantized model.
......@@ -78,6 +77,21 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
Returns:
None
'''
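# Supported quantize types and algorithms; the constructor arguments are
# validated against these lists before the quantization pass is configured.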
self._support_activation_quantize_type = [
'range_abs_max', 'moving_average_abs_max', 'abs_max'
]
self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
self._support_algo_type = ['KL', 'abs_max', 'min_max']
self._support_quantize_op_type = \
list(set(QuantizationTransformPass._supported_quantizable_op_type +
AddQuantDequantPass._supported_quantizable_op_type))
# Check inputs
assert executor is not None, "The executor cannot be None."
assert batch_size > 0, "The batch_size should be greater than 0."
assert algo in self._support_algo_type, \
"The algo should be KL, abs_max or min_max."
self._executor = executor
self._dataset = dataset
self._batch_size = batch_size
......@@ -86,18 +100,19 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
self._algo = algo
self._is_use_cache_file = is_use_cache_file
self._cache_dir = cache_dir
self._activation_bits = 8
self._weight_bits = 8
self._activation_quantize_type = 'range_abs_max'
self._weight_quantize_type = 'channel_wise_abs_max'
if self._is_use_cache_file and not os.path.exists(self._cache_dir):
os.mkdir(self._cache_dir)
supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type + \
AddQuantDequantPass._supported_quantizable_op_type
if is_full_quantize:
self._quantizable_op_type = supported_quantizable_op_type
self._quantizable_op_type = self._support_quantize_op_type
else:
self._quantizable_op_type = quantizable_op_type
for op_type in self._quantizable_op_type:
assert op_type in supported_quantizable_op_type + \
assert op_type in self._support_quantize_op_type + \
AddQuantDequantPass._activation_type, \
op_type + " is not supported for quantization."
......@@ -107,25 +122,29 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
self._fetch_list = list(outputs.values())
self._data_loader = None
self._op_real_in_out_name = _op_real_in_out_name
self._out_scale_op_list = _out_scale_op_list
self._bit_length = 8
self._quantized_weight_var_name = set()
self._quantized_act_var_name = set()
self._sampling_data = {}
self._quantized_var_scale_factor = {}
self._quantized_var_kl_threshold = {}
self._quantized_var_min = {}
self._quantized_var_max = {}
self._quantized_var_abs_max = {}
def quantize(self):
'''
Quantize the fp32 model. Use calibration data to calculate the scale factors of
quantized variables, and insert fake quant/dequant ops to obtain the
quantized model.
Args:
None
Returns:
the program of quantized model.
'''
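# Overall flow: load the calibration data loader, collect the variables to
# sample, make activations persistable, run calibration batches, then
# compute thresholds and rewrite the program with fake quant/dequant ops.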
self._preprocess()
self._load_model_data()
self._collect_target_varnames()
self._set_activation_persistable()
batch_ct = 0
for data in self._data_loader():
batch_ct += 1
......@@ -140,7 +159,10 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
feed=data,
fetch_list=self._fetch_list,
return_numpy=False)
self._sample_data(batch_id)
if self._algo == "KL":
self._sample_data(batch_id)
else:
self._sample_threshold()
end = time.time()
logging.debug('[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
str(batch_id + 1),
......@@ -150,19 +172,23 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
if self._batch_nums and batch_id >= self._batch_nums:
break
logging.info("All run batch: ".format(batch_id))
self._reset_activation_persistable()
logging.info("Calculate scale factor ...")
self._calculate_scale_factor()
if self._algo == "KL":
self._calculate_kl_threshold()
logging.info("Update the program ...")
self._update_program()
if self._algo in ["KL", "abs_max"]:
self._update_program()
else:
self._save_input_threhold()
logging.info("Save ...")
self._save_output_scale()
self._save_output_threshold()
logging.info("Finish quant!")
return self._program
def save_quantized_model(self, save_model_path):
'''
Save the quantized model to the disk.
Args:
save_model_path(str): The path to save the quantized model
Returns:
......@@ -176,88 +202,47 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
executor=self._executor,
params_filename='__params__',
main_program=self._program)
def _preprocess(self):
def _load_model_data(self):
'''
Load model and set data loader, collect the variable names for sampling,
and set activation variables to be persistable.
Set data loader.
'''
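# Build a DataLoader over the calibration dataset and bind it to the
# program's feed variables.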
feed_vars = [fluid.framework._get_var(var.name, self._program) \
for var in self._feed_list]
self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
self._data_loader.set_sample_list_generator(
self._dataset.generator(self._batch_size, drop_last=True),
places=self._place)
# collect the variable names for sampling
persistable_var_names = []
for var in self._program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
for op in self._program.global_block().ops:
op_type = op.type
if op_type in self._quantizable_op_type:
if op_type in ("conv2d", "depthwise_conv2d"):
self._quantized_act_var_name.add(op.input("Input")[0])
self._quantized_weight_var_name.add(op.input("Filter")[0])
self._quantized_act_var_name.add(op.output("Output")[0])
elif op_type == "mul":
if self._is_input_all_not_persistable(
op, persistable_var_names):
op._set_attr("skip_quant", True)
logging.warning(
"Skip quant a mul op for two input variables are not persistable"
)
else:
self._quantized_act_var_name.add(op.input("X")[0])
self._quantized_weight_var_name.add(op.input("Y")[0])
self._quantized_act_var_name.add(op.output("Out")[0])
else:
# process other quantizable op type, the input must all not persistable
if self._is_input_all_not_persistable(
op, persistable_var_names):
input_output_name_list = self._op_real_in_out_name[
op_type]
for input_name in input_output_name_list[0]:
for var_name in op.input(input_name):
self._quantized_act_var_name.add(var_name)
for output_name in input_output_name_list[1]:
for var_name in op.output(output_name):
self._quantized_act_var_name.add(var_name)
# set activation variables to be persistable, so can obtain
# the tensor data in sample_data
for var in self._program.list_vars():
if var.name in self._quantized_act_var_name:
var.persistable = True
def _calculate_scale_factor(self):
def _calculate_kl_threshold(self):
'''
Calculate the scale factor of quantized variables.
Calculate the KL threshold of quantized variables.
'''
# apply channel_wise_abs_max quantization for weights
assert self._algo == "KL", "The algo should be KL to calculate kl threshold."
ct = 1
# Abs_max threshold for weights
for var_name in self._quantized_weight_var_name:
start = time.time()
data = self._sampling_data[var_name]
scale_factor_per_channel = []
for i in range(data.shape[0]):
abs_max_value = np.max(np.abs(data[i]))
scale_factor_per_channel.append(abs_max_value)
self._quantized_var_scale_factor[
var_name] = scale_factor_per_channel
weight_data = self._sampling_data[var_name]
weight_threshold = None
if self._weight_quantize_type == "abs_max":
weight_threshold = np.max(np.abs(weight_data))
elif self._weight_quantize_type == "channel_wise_abs_max":
weight_threshold = []
for i in range(weight_data.shape[0]):
abs_max_value = np.max(np.abs(weight_data[i]))
weight_threshold.append(abs_max_value)
self._quantized_var_kl_threshold[var_name] = weight_threshold
end = time.time()
logging.debug('[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.format(
str(ct),
str(len(self._quantized_weight_var_name)),
str(end-start)))
ct += 1
ct = 1
# apply kl quantization for activation
# KL threshold for activations
if self._is_use_cache_file:
for var_name in self._quantized_act_var_name:
start = time.time()
......@@ -269,13 +254,8 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
sampling_data.append(np.load(file_path))
os.remove(file_path)
sampling_data = np.concatenate(sampling_data)
if self._algo == "KL":
self._quantized_var_scale_factor[var_name] = \
self._get_kl_scaling_factor(np.abs(sampling_data))
else:
self._quantized_var_scale_factor[var_name] = \
np.max(np.abs(sampling_data))
self._quantized_var_kl_threshold[var_name] = \
self._get_kl_scaling_factor(np.abs(sampling_data))
end = time.time()
logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
str(ct),
......@@ -287,15 +267,13 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
start = time.time()
self._sampling_data[var_name] = np.concatenate(
self._sampling_data[var_name])
if self._algo == "KL":
self._quantized_var_scale_factor[var_name] = \
self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
else:
self._quantized_var_scale_factor[var_name] = \
np.max(np.abs(self._sampling_data[var_name]))
self._quantized_var_kl_threshold[var_name] = \
self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
end = time.time()
logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
str(ct),
str(len(self._quantized_act_var_name)),
str(end-start)))
ct += 1
\ No newline at end of file
ct += 1
\ No newline at end of file
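
A side note on the bookkeeping introduced above: _calculate_kl_threshold now stores an abs_max threshold for each weight (one value, or one value per output channel under channel_wise_abs_max) and a KL-derived threshold for each activation. A small numpy-only sketch of the weight part, illustrative rather than copied from the source:

import numpy as np

def weight_threshold(weight_data, quantize_type='channel_wise_abs_max'):
    # abs_max: a single threshold for the whole tensor
    if quantize_type == 'abs_max':
        return np.max(np.abs(weight_data))
    # channel_wise_abs_max: one threshold per output channel (axis 0)
    return [np.max(np.abs(weight_data[i])) for i in range(weight_data.shape[0])]

w = np.random.randn(4, 3, 3, 3).astype('float32')  # e.g. a conv filter with 4 output channels
print(weight_threshold(w, 'abs_max'))
print(weight_threshold(w, 'channel_wise_abs_max'))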