未验证 提交 29b63f0a 编写于 作者: J juncaipeng 提交者: GitHub

support set model_filename and params_filename in post_training_quantization, test=develop (#21213)

* support set model_filename and params_filename in post_training_quantization, test=develop
上级 ccbdd7aa
...@@ -34,8 +34,10 @@ _logger = get_logger( ...@@ -34,8 +34,10 @@ _logger = get_logger(
class PostTrainingQuantization(object): class PostTrainingQuantization(object):
def __init__(self, def __init__(self,
executor, executor,
model_path, sample_generator,
data_reader, model_dir,
model_filename=None,
params_filename=None,
batch_size=10, batch_size=10,
batch_nums=None, batch_nums=None,
scope=None, scope=None,
...@@ -51,13 +53,22 @@ class PostTrainingQuantization(object): ...@@ -51,13 +53,22 @@ class PostTrainingQuantization(object):
Args: Args:
executor(fluid.Executor): The executor to load, run and save the executor(fluid.Executor): The executor to load, run and save the
quantized model. quantized model.
model_path(str): The path of fp32 model that will be quantized. sample_generator(Python Generator): The sample generator provides
data_reader(Reader): The data reader generates a sample every time, calibrate data for DataLoader, and it only returns a sample every
and it provides calibrate data for DataLoader. time.
batch_size(int, optional): The batch size of DataLoader, default is 10. model_dir(str): The path of the fp32 model that will be quantized,
batch_nums(int, optional): If set batch_nums, the number of calibrate and the model and params files are under the path.
data is batch_size*batch_nums. If batch_nums=None, use all data model_filename(str, optional): The name of file to load the inference
provided by data_reader as calibrate data. program. If it is None, the default filename '__model__' will
be used. Default is 'None'.
params_filename(str, optional): The name of file to load all parameters.
When all parameters were saved in a single binary file, set it
as the real filename. If parameters were saved in separate files,
set it as 'None'. Default is 'None'.
batch_size(int, optional): The batch size of DataLoader. Default is 10.
batch_nums(int, optional): If batch_nums is not None, the number of
calibrate data is batch_size*batch_nums. If batch_nums is None, use
all data provided by sample_generator as calibrate data.
scope(fluid.Scope, optional): The scope of the program, use it to load scope(fluid.Scope, optional): The scope of the program, use it to load
and save variables. If scope=None, get scope by global_scope(). and save variables. If scope=None, get scope by global_scope().
algo(str, optional): If algo=KL, use KL-divergenc method to algo(str, optional): If algo=KL, use KL-divergenc method to
...@@ -79,18 +90,29 @@ class PostTrainingQuantization(object): ...@@ -79,18 +90,29 @@ class PostTrainingQuantization(object):
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
model_path = load_fp32_model_path model_dir = path/to/fp32_model_params
save_model_path = save_int8_path # set model_filename as None when the filename is __model__,
data_reader = your_data_reader # otherwise set it as the real filename
model_filename = None
# set params_filename as None when all parameters were saved in
# separate files, otherwise set it as the real filename
params_filename = None
save_model_path = path/to/save_model_path
# prepare the sample generator according to the model, and the
# sample generator must return a simple every time. The reference
# document: https://www.paddlepaddle.org.cn/documentation/docs/zh
# /user_guides/howto/prepare_data/use_py_reader.html
sample_generator = your_sample_generator
batch_size = 10 batch_size = 10
batch_nums = 10 batch_nums = 10
algo = "KL" algo = "KL"
quantizable_op_type = ["conv2d", \ quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
"depthwise_conv2d", "mul", "pool2d", "elementwise_add"]
ptq = PostTrainingQuantization( ptq = PostTrainingQuantization(
executor=exe, executor=exe,
model_path=model_path, sample_generator=sample_generator,
data_reader=data_reader, model_dir=model_dir,
model_filename=model_filename,
params_filename=params_filename,
batch_size=batch_size, batch_size=batch_size,
batch_nums=batch_nums, batch_nums=batch_nums,
algo=algo, algo=algo,
...@@ -99,8 +121,10 @@ class PostTrainingQuantization(object): ...@@ -99,8 +121,10 @@ class PostTrainingQuantization(object):
ptq.save_quantized_model(save_model_path) ptq.save_quantized_model(save_model_path)
''' '''
self._executor = executor self._executor = executor
self._model_path = model_path self._sample_generator = sample_generator
self._data_reader = data_reader self._model_dir = model_dir
self._model_filename = model_filename
self._params_filename = params_filename
self._batch_size = batch_size self._batch_size = batch_size
self._batch_nums = batch_nums self._batch_nums = batch_nums
self._scope = global_scope() if scope == None else scope self._scope = global_scope() if scope == None else scope
...@@ -148,7 +172,8 @@ class PostTrainingQuantization(object): ...@@ -148,7 +172,8 @@ class PostTrainingQuantization(object):
for data in self._data_loader(): for data in self._data_loader():
self._executor.run(program=self._program, self._executor.run(program=self._program,
feed=data, feed=data,
fetch_list=self._fetch_list) fetch_list=self._fetch_list,
return_numpy=False)
self._sample_data() self._sample_data()
if batch_id % 5 == 0: if batch_id % 5 == 0:
_logger.info("run batch: " + str(batch_id)) _logger.info("run batch: " + str(batch_id))
...@@ -189,13 +214,16 @@ class PostTrainingQuantization(object): ...@@ -189,13 +214,16 @@ class PostTrainingQuantization(object):
''' '''
# load model and set data loader # load model and set data loader
[self._program, self._feed_list, self._fetch_list] = \ [self._program, self._feed_list, self._fetch_list] = \
io.load_inference_model(self._model_path, self._executor) io.load_inference_model(dirname=self._model_dir,
executor=self._executor,
model_filename=self._model_filename,
params_filename=self._params_filename)
feed_vars = [framework._get_var(str(var_name), self._program) \ feed_vars = [framework._get_var(str(var_name), self._program) \
for var_name in self._feed_list] for var_name in self._feed_list]
self._data_loader = io.DataLoader.from_generator( self._data_loader = io.DataLoader.from_generator(
feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True) feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
self._data_loader.set_sample_generator( self._data_loader.set_sample_generator(
self._data_reader, self._sample_generator,
batch_size=self._batch_size, batch_size=self._batch_size,
drop_last=True, drop_last=True,
places=self._place) places=self._place)
...@@ -348,11 +376,11 @@ class PostTrainingQuantization(object): ...@@ -348,11 +376,11 @@ class PostTrainingQuantization(object):
if op.type in self._quantizable_op_type: if op.type in self._quantizable_op_type:
output_name_list = self._op_real_in_out_name[op.type][1] output_name_list = self._op_real_in_out_name[op.type][1]
for output_name in output_name_list: for output_name in output_name_list:
output_var_name = op.output(output_name)[0] for output_var_name in op.output(output_name):
if output_var_name in self._quantized_var_scale_factor: if output_var_name in self._quantized_var_scale_factor:
op._set_attr( op._set_attr(output_scale_name,
output_scale_name, self._quantized_var_scale_factor[
self._quantized_var_scale_factor[output_var_name]) output_var_name])
def _load_var_value(self, var_name): def _load_var_value(self, var_name):
''' '''
......
...@@ -256,9 +256,8 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -256,9 +256,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
ptq = PostTrainingQuantization( ptq = PostTrainingQuantization(
executor=exe, executor=exe,
scope=scope, sample_generator=val_reader,
model_path=model_path, model_dir=model_path,
data_reader=val_reader,
algo=algo, algo=algo,
quantizable_op_type=quantizable_op_type, quantizable_op_type=quantizable_op_type,
is_full_quantize=is_full_quantize) is_full_quantize=is_full_quantize)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册