未验证 提交 29b63f0a 编写于 作者: J juncaipeng 提交者: GitHub

support set model_filename and params_filename in post_training_quantization, test=develop (#21213)

* support set model_filename and params_filename in post_training_quantization, test=develop
上级 ccbdd7aa
......@@ -34,8 +34,10 @@ _logger = get_logger(
class PostTrainingQuantization(object):
def __init__(self,
executor,
model_path,
data_reader,
sample_generator,
model_dir,
model_filename=None,
params_filename=None,
batch_size=10,
batch_nums=None,
scope=None,
......@@ -51,13 +53,22 @@ class PostTrainingQuantization(object):
Args:
executor(fluid.Executor): The executor to load, run and save the
quantized model.
model_path(str): The path of fp32 model that will be quantized.
data_reader(Reader): The data reader generates a sample every time,
and it provides calibrate data for DataLoader.
batch_size(int, optional): The batch size of DataLoader, default is 10.
batch_nums(int, optional): If set batch_nums, the number of calibrate
data is batch_size*batch_nums. If batch_nums=None, use all data
provided by data_reader as calibrate data.
sample_generator(Python Generator): The sample generator provides
calibrate data for DataLoader, and it only returns a sample every
time.
model_dir(str): The path of the fp32 model that will be quantized,
and the model and params files are under the path.
model_filename(str, optional): The name of file to load the inference
program. If it is None, the default filename '__model__' will
be used. Default is 'None'.
params_filename(str, optional): The name of file to load all parameters.
When all parameters were saved in a single binary file, set it
as the real filename. If parameters were saved in separate files,
set it as 'None'. Default is 'None'.
batch_size(int, optional): The batch size of DataLoader. Default is 10.
batch_nums(int, optional): If batch_nums is not None, the number of
calibrate data is batch_size*batch_nums. If batch_nums is None, use
all data provided by sample_generator as calibrate data.
scope(fluid.Scope, optional): The scope of the program, use it to load
and save variables. If scope=None, get scope by global_scope().
algo(str, optional): If algo=KL, use KL-divergenc method to
......@@ -79,18 +90,29 @@ class PostTrainingQuantization(object):
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
exe = fluid.Executor(fluid.CPUPlace())
model_path = load_fp32_model_path
save_model_path = save_int8_path
data_reader = your_data_reader
model_dir = path/to/fp32_model_params
# set model_filename as None when the filename is __model__,
# otherwise set it as the real filename
model_filename = None
# set params_filename as None when all parameters were saved in
# separate files, otherwise set it as the real filename
params_filename = None
save_model_path = path/to/save_model_path
# prepare the sample generator according to the model, and the
# sample generator must return a simple every time. The reference
# document: https://www.paddlepaddle.org.cn/documentation/docs/zh
# /user_guides/howto/prepare_data/use_py_reader.html
sample_generator = your_sample_generator
batch_size = 10
batch_nums = 10
algo = "KL"
quantizable_op_type = ["conv2d", \
"depthwise_conv2d", "mul", "pool2d", "elementwise_add"]
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
ptq = PostTrainingQuantization(
executor=exe,
model_path=model_path,
data_reader=data_reader,
sample_generator=sample_generator,
model_dir=model_dir,
model_filename=model_filename,
params_filename=params_filename,
batch_size=batch_size,
batch_nums=batch_nums,
algo=algo,
......@@ -99,8 +121,10 @@ class PostTrainingQuantization(object):
ptq.save_quantized_model(save_model_path)
'''
self._executor = executor
self._model_path = model_path
self._data_reader = data_reader
self._sample_generator = sample_generator
self._model_dir = model_dir
self._model_filename = model_filename
self._params_filename = params_filename
self._batch_size = batch_size
self._batch_nums = batch_nums
self._scope = global_scope() if scope == None else scope
......@@ -148,7 +172,8 @@ class PostTrainingQuantization(object):
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
fetch_list=self._fetch_list)
fetch_list=self._fetch_list,
return_numpy=False)
self._sample_data()
if batch_id % 5 == 0:
_logger.info("run batch: " + str(batch_id))
......@@ -189,13 +214,16 @@ class PostTrainingQuantization(object):
'''
# load model and set data loader
[self._program, self._feed_list, self._fetch_list] = \
io.load_inference_model(self._model_path, self._executor)
io.load_inference_model(dirname=self._model_dir,
executor=self._executor,
model_filename=self._model_filename,
params_filename=self._params_filename)
feed_vars = [framework._get_var(str(var_name), self._program) \
for var_name in self._feed_list]
self._data_loader = io.DataLoader.from_generator(
feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
self._data_loader.set_sample_generator(
self._data_reader,
self._sample_generator,
batch_size=self._batch_size,
drop_last=True,
places=self._place)
......@@ -348,11 +376,11 @@ class PostTrainingQuantization(object):
if op.type in self._quantizable_op_type:
output_name_list = self._op_real_in_out_name[op.type][1]
for output_name in output_name_list:
output_var_name = op.output(output_name)[0]
if output_var_name in self._quantized_var_scale_factor:
op._set_attr(
output_scale_name,
self._quantized_var_scale_factor[output_var_name])
for output_var_name in op.output(output_name):
if output_var_name in self._quantized_var_scale_factor:
op._set_attr(output_scale_name,
self._quantized_var_scale_factor[
output_var_name])
def _load_var_value(self, var_name):
'''
......
......@@ -256,9 +256,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
ptq = PostTrainingQuantization(
executor=exe,
scope=scope,
model_path=model_path,
data_reader=val_reader,
sample_generator=val_reader,
model_dir=model_path,
algo=algo,
quantizable_op_type=quantizable_op_type,
is_full_quantize=is_full_quantize)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册