未验证 提交 52679889 编写于 作者: J Jiaqi Liu 提交者: GitHub

cherry pick from #38686 and solve conflicts (#38729)

上级 b533090e
...@@ -17,6 +17,7 @@ import re ...@@ -17,6 +17,7 @@ import re
import logging import logging
import numpy as np import numpy as np
import shutil import shutil
from inspect import isgeneratorfunction
from .... import io from .... import io
from .... import core from .... import core
from .... import framework from .... import framework
...@@ -136,6 +137,7 @@ class PostTrainingQuantization(object): ...@@ -136,6 +137,7 @@ class PostTrainingQuantization(object):
params_filename=None, params_filename=None,
batch_generator=None, batch_generator=None,
sample_generator=None, sample_generator=None,
data_loader=None,
batch_size=10, batch_size=10,
batch_nums=None, batch_nums=None,
algo="KL", algo="KL",
...@@ -175,6 +177,9 @@ class PostTrainingQuantization(object): ...@@ -175,6 +177,9 @@ class PostTrainingQuantization(object):
calibrate data for DataLoader, and it only returns a sample every calibrate data for DataLoader, and it only returns a sample every
time. Note that, sample_generator and batch_generator, only one time. Note that, sample_generator and batch_generator, only one
should be set. Beisdes, sample_generator dose not support lod tensor. should be set. Beisdes, sample_generator dose not support lod tensor.
data_loader(Python Generator, Paddle.io.DataLoader, optional): The
Generator or Dataloader provides calibrate data, and it could
return a batch every time.
batch_size(int, optional): The batch size of DataLoader. Default is 10. batch_size(int, optional): The batch size of DataLoader. Default is 10.
batch_nums(int, optional): If batch_nums is not None, the number of batch_nums(int, optional): If batch_nums is not None, the number of
calibrate data is batch_size*batch_nums. If batch_nums is None, use calibrate data is batch_size*batch_nums. If batch_nums is None, use
...@@ -279,8 +284,11 @@ class PostTrainingQuantization(object): ...@@ -279,8 +284,11 @@ class PostTrainingQuantization(object):
assert executor is not None, "The executor cannot be None." assert executor is not None, "The executor cannot be None."
assert model_dir is not None, "The model_dir cannot be None." assert model_dir is not None, "The model_dir cannot be None."
assert any([gen is not None] for gen in [sample_generator, assert any([gen is not None] for gen in [sample_generator,
batch_generator]), "The sample_generator and batch_generator " \ batch_generator, data_loader]), "The sample_generator, batch_generator " \
"cannot be None in the same time." "and data_loader cannot be None in the same time."
if data_loader is not None:
assert isinstance(data_loader, (io.DataLoader, type(isgeneratorfunction))), \
"data_loader only accepts `paddle.io.DataLoader` or Generator instance."
assert batch_size > 0, "The batch_size should be greater than 0." assert batch_size > 0, "The batch_size should be greater than 0."
assert algo in self._support_algo_type, \ assert algo in self._support_algo_type, \
"The algo should be KL, hist, mse, avg, abs_max or min_max." "The algo should be KL, hist, mse, avg, abs_max or min_max."
...@@ -323,7 +331,7 @@ class PostTrainingQuantization(object): ...@@ -323,7 +331,7 @@ class PostTrainingQuantization(object):
self._program = None self._program = None
self._feed_list = None self._feed_list = None
self._fetch_list = None self._fetch_list = None
self._data_loader = None self._data_loader = data_loader
self._out_scale_op_list = _out_scale_op_list self._out_scale_op_list = _out_scale_op_list
self._quantized_weight_var_name = set() self._quantized_weight_var_name = set()
...@@ -460,6 +468,9 @@ class PostTrainingQuantization(object): ...@@ -460,6 +468,9 @@ class PostTrainingQuantization(object):
feed_vars = [framework._get_var(str(var_name), self._program) \ feed_vars = [framework._get_var(str(var_name), self._program) \
for var_name in self._feed_list] for var_name in self._feed_list]
if self._data_loader is not None:
return
self._data_loader = io.DataLoader.from_generator( self._data_loader = io.DataLoader.from_generator(
feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True) feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
if self._sample_generator is not None: if self._sample_generator is not None:
......
...@@ -115,17 +115,28 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -115,17 +115,28 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_use_cache_file=False, is_use_cache_file=False,
is_optimize_model=False, is_optimize_model=False,
batch_size=10, batch_size=10,
batch_nums=10): batch_nums=10,
is_data_loader=False):
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
scope = fluid.global_scope() scope = fluid.global_scope()
val_reader = paddle.dataset.mnist.train() val_reader = paddle.dataset.mnist.train()
def val_data_generator():
batches = []
for data in val_reader():
batches.append(data[0].reshape(1, 28, 28))
if len(batches) == batch_size:
batches = np.asarray(batches)
yield {"img": batches}
batches = []
ptq = PostTrainingQuantization( ptq = PostTrainingQuantization(
executor=exe, executor=exe,
model_dir=model_path, model_dir=model_path,
sample_generator=val_reader, sample_generator=val_reader if not is_data_loader else None,
data_loader=val_data_generator if is_data_loader else None,
batch_size=batch_size, batch_size=batch_size,
batch_nums=batch_nums, batch_nums=batch_nums,
algo=algo, algo=algo,
...@@ -148,7 +159,8 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -148,7 +159,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
diff_threshold, diff_threshold,
batch_size=10, batch_size=10,
infer_iterations=10, infer_iterations=10,
quant_iterations=5): quant_iterations=5,
is_data_loader=False):
origin_model_path = self.download_model(data_url, data_md5, model_name) origin_model_path = self.download_model(data_url, data_md5, model_name)
origin_model_path = os.path.join(origin_model_path, model_name) origin_model_path = os.path.join(origin_model_path, model_name)
...@@ -161,8 +173,15 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -161,8 +173,15 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start INT8 post training quantization for {0} on {1} images ...". print("Start INT8 post training quantization for {0} on {1} images ...".
format(model_name, quant_iterations * batch_size)) format(model_name, quant_iterations * batch_size))
self.generate_quantized_model( self.generate_quantized_model(
origin_model_path, algo, quantizable_op_type, is_full_quantize, origin_model_path,
is_use_cache_file, is_optimize_model, batch_size, quant_iterations) algo,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
batch_size,
quant_iterations,
is_data_loader=is_data_loader)
print("Start INT8 inference for {0} on {1} images ...".format( print("Start INT8 inference for {0} on {1} images ...".format(
model_name, infer_iterations * batch_size)) model_name, infer_iterations * batch_size))
...@@ -283,6 +302,21 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization): ...@@ -283,6 +302,21 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
diff_threshold, batch_size, infer_iterations, diff_threshold, batch_size, infer_iterations,
quant_iterations) quant_iterations)
self.run_test(
model_name,
data_url,
data_md5,
algo,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
batch_size,
infer_iterations,
quant_iterations,
is_data_loader=True)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册