未验证 提交 52679889 编写于 作者: J Jiaqi Liu 提交者: GitHub

cherry pick from #38686 and solve conflicts (#38729)

上级 b533090e
......@@ -17,6 +17,7 @@ import re
import logging
import numpy as np
import shutil
from inspect import isgeneratorfunction
from .... import io
from .... import core
from .... import framework
......@@ -136,6 +137,7 @@ class PostTrainingQuantization(object):
params_filename=None,
batch_generator=None,
sample_generator=None,
data_loader=None,
batch_size=10,
batch_nums=None,
algo="KL",
......@@ -175,6 +177,9 @@ class PostTrainingQuantization(object):
calibrate data for DataLoader, and it only returns a sample every
time. Note that, sample_generator and batch_generator, only one
should be set. Beisdes, sample_generator dose not support lod tensor.
data_loader(Python Generator, Paddle.io.DataLoader, optional): The
Generator or Dataloader provides calibrate data, and it could
return a batch every time.
batch_size(int, optional): The batch size of DataLoader. Default is 10.
batch_nums(int, optional): If batch_nums is not None, the number of
calibrate data is batch_size*batch_nums. If batch_nums is None, use
......@@ -279,8 +284,11 @@ class PostTrainingQuantization(object):
assert executor is not None, "The executor cannot be None."
assert model_dir is not None, "The model_dir cannot be None."
assert any([gen is not None] for gen in [sample_generator,
batch_generator]), "The sample_generator and batch_generator " \
"cannot be None in the same time."
batch_generator, data_loader]), "The sample_generator, batch_generator " \
"and data_loader cannot be None in the same time."
if data_loader is not None:
assert isinstance(data_loader, (io.DataLoader, type(isgeneratorfunction))), \
"data_loader only accepts `paddle.io.DataLoader` or Generator instance."
assert batch_size > 0, "The batch_size should be greater than 0."
assert algo in self._support_algo_type, \
"The algo should be KL, hist, mse, avg, abs_max or min_max."
......@@ -323,7 +331,7 @@ class PostTrainingQuantization(object):
self._program = None
self._feed_list = None
self._fetch_list = None
self._data_loader = None
self._data_loader = data_loader
self._out_scale_op_list = _out_scale_op_list
self._quantized_weight_var_name = set()
......@@ -460,6 +468,9 @@ class PostTrainingQuantization(object):
feed_vars = [framework._get_var(str(var_name), self._program) \
for var_name in self._feed_list]
if self._data_loader is not None:
return
self._data_loader = io.DataLoader.from_generator(
feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
if self._sample_generator is not None:
......
......@@ -115,17 +115,28 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_use_cache_file=False,
is_optimize_model=False,
batch_size=10,
batch_nums=10):
batch_nums=10,
is_data_loader=False):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.global_scope()
val_reader = paddle.dataset.mnist.train()
def val_data_generator():
batches = []
for data in val_reader():
batches.append(data[0].reshape(1, 28, 28))
if len(batches) == batch_size:
batches = np.asarray(batches)
yield {"img": batches}
batches = []
ptq = PostTrainingQuantization(
executor=exe,
model_dir=model_path,
sample_generator=val_reader,
sample_generator=val_reader if not is_data_loader else None,
data_loader=val_data_generator if is_data_loader else None,
batch_size=batch_size,
batch_nums=batch_nums,
algo=algo,
......@@ -148,7 +159,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
diff_threshold,
batch_size=10,
infer_iterations=10,
quant_iterations=5):
quant_iterations=5,
is_data_loader=False):
origin_model_path = self.download_model(data_url, data_md5, model_name)
origin_model_path = os.path.join(origin_model_path, model_name)
......@@ -161,8 +173,15 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start INT8 post training quantization for {0} on {1} images ...".
format(model_name, quant_iterations * batch_size))
self.generate_quantized_model(
origin_model_path, algo, quantizable_op_type, is_full_quantize,
is_use_cache_file, is_optimize_model, batch_size, quant_iterations)
origin_model_path,
algo,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
batch_size,
quant_iterations,
is_data_loader=is_data_loader)
print("Start INT8 inference for {0} on {1} images ...".format(
model_name, infer_iterations * batch_size))
......@@ -283,6 +302,21 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
diff_threshold, batch_size, infer_iterations,
quant_iterations)
self.run_test(
model_name,
data_url,
data_md5,
algo,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
batch_size,
infer_iterations,
quant_iterations,
is_data_loader=True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册