Unverified · commit d2ecc862 authored by ceci3, committed by GitHub

fix ptq hpo (#1021)

Parent a00f830f
@@ -55,7 +55,7 @@ def quantize(args):
         save_model_filename='__model__',
         save_params_filename='__params__',
         quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
-        weight_quantize_type='channel_wise_abs_max',
+        weight_quantize_type=['channel_wise_abs_max'],
         runcount_limit=args.max_model_quant_count)

 def main():
...
@@ -189,10 +189,10 @@ def eval_quant_model():
     valid_data_num = 0
     max_eval_data_num = 200
     if g_quant_config.eval_sample_generator is not None:
-        feed_dict=False
+        feed_dict = False
         eval_dataloader = g_quant_config.eval_sample_generator
     else:
-        feed_dict=True
+        feed_dict = True
         eval_dataloader = g_quant_config.eval_dataloader
     for i, data in enumerate(eval_dataloader()):
         with paddle.static.scope_guard(float_scope):
@@ -236,12 +236,20 @@ def eval_quant_model():
 def quantize(cfg):
     """model quantize job"""
     algo = cfg["algo"] if 'algo' in cfg else g_quant_config.algo[0][0]
-    hist_percent = cfg["hist_percent"] if "hist_percent" in cfg else g_quant_config.hist_percent[0][0]
-    bias_correct = cfg["bias_correct"] if "bias_correct" in cfg else g_quant_config.bias_correct[0][0]
-    batch_size = cfg["batch_size"] if "batch_size" in cfg else g_quant_config.batch_size[0][0]
-    batch_num = cfg["batch_num"] if "batch_num" in cfg else g_quant_config.batch_num[0][0]
-    weight_quantize_type = cfg["weight_quantize_type"] if "weight_quantize_type" in cfg else g_quant_config.weight_quantize_type[0]
-    print(hist_percent, bias_correct, batch_size, batch_num, weight_quantize_type)
+    hist_percent = cfg[
+        "hist_percent"] if "hist_percent" in cfg else g_quant_config.hist_percent[
+            0][0]
+    bias_correct = cfg[
+        "bias_correct"] if "bias_correct" in cfg else g_quant_config.bias_correct[
+            0][0]
+    batch_size = cfg[
+        "batch_size"] if "batch_size" in cfg else g_quant_config.batch_size[0][
+            0]
+    batch_num = cfg[
+        "batch_num"] if "batch_num" in cfg else g_quant_config.batch_num[0][0]
+    weight_quantize_type = cfg[
+        "weight_quantize_type"] if "weight_quantize_type" in cfg else g_quant_config.weight_quantize_type[
+            0]

     quant_post( \
         executor=g_quant_config.executor, \
@@ -279,34 +287,35 @@ def quantize(cfg):
     return emd_loss


-def quant_post_hpo(executor,
-                   place,
-                   model_dir,
-                   quantize_model_path,
-                   train_sample_generator=None,
-                   eval_sample_generator=None,
-                   train_dataloader=None,
-                   eval_dataloader=None,
-                   eval_function=None,
-                   model_filename=None,
-                   params_filename=None,
-                   save_model_filename='__model__',
-                   save_params_filename='__params__',
-                   scope=None,
-                   quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
-                   is_full_quantize=False,
-                   weight_bits=8,
-                   activation_bits=8,
-                   weight_quantize_type=['channel_wise_abs_max'],
-                   algo=["KL", "hist", "avg", "mse"],
-                   bias_correct=[True, False],
-                   hist_percent=[0.98, 0.999],  ### uniform sample in list.
-                   batch_size=[10, 30],  ### uniform sample in list.
-                   batch_num=[10, 30],  ### uniform sample in list.
-                   optimize_model=False,
-                   is_use_cache_file=False,
-                   cache_dir="./temp_post_training",
-                   runcount_limit=30):
+def quant_post_hpo(
+        executor,
+        place,
+        model_dir,
+        quantize_model_path,
+        train_sample_generator=None,
+        eval_sample_generator=None,
+        train_dataloader=None,
+        eval_dataloader=None,
+        eval_function=None,
+        model_filename=None,
+        params_filename=None,
+        save_model_filename='__model__',
+        save_params_filename='__params__',
+        scope=None,
+        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
+        is_full_quantize=False,
+        weight_bits=8,
+        activation_bits=8,
+        weight_quantize_type=['channel_wise_abs_max'],
+        algo=["KL", "hist", "avg", "mse"],
+        bias_correct=[True, False],
+        hist_percent=[0.98, 0.999],  ### uniform sample in list.
+        batch_size=[10, 30],  ### uniform sample in list.
+        batch_num=[10, 30],  ### uniform sample in list.
+        optimize_model=False,
+        is_use_cache_file=False,
+        cache_dir="./temp_post_training",
+        runcount_limit=30):
     """
     The function utilizes static post training quantization method to
     quantize the fp32 model. It uses calibrate data to calculate the
@@ -360,25 +369,27 @@ def quant_post_hpo(executor,
     global g_quant_config
     g_quant_config = QuantConfig(
         executor, place, model_dir, quantize_model_path, algo, hist_percent,
         bias_correct, batch_size, batch_num, train_sample_generator,
         eval_sample_generator, train_dataloader, eval_dataloader, eval_function,
-        model_filename, params_filename,
-        save_model_filename, save_params_filename, scope, quantizable_op_type,
-        is_full_quantize, weight_bits, activation_bits, weight_quantize_type,
-        optimize_model, is_use_cache_file, cache_dir)
+        model_filename, params_filename, save_model_filename,
+        save_params_filename, scope, quantizable_op_type, is_full_quantize,
+        weight_bits, activation_bits, weight_quantize_type, optimize_model,
+        is_use_cache_file, cache_dir)
     cs = ConfigurationSpace()
     hyper_params = []

     if 'hist' in algo:
         hist_percent = UniformFloatHyperparameter(
-            "hist_percent", hist_percent[0], hist_percent[1], default_value=hist_percent[0])
+            "hist_percent",
+            hist_percent[0],
+            hist_percent[1],
+            default_value=hist_percent[0])
         hyper_params.append(hist_percent)
     if len(algo) > 1:
-        algo = CategoricalHyperparameter(
-            "algo", algo, default_value=algo[0])
+        algo = CategoricalHyperparameter("algo", algo, default_value=algo[0])
         hyper_params.append(algo)
     else:
         algo = algo[0]
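Side note: the hunk above only reflows how the hyperparameter search space is built; no behavior changes. For readers unfamiliar with the pattern, here is a minimal standalone sketch of the same construction, assuming the ConfigSpace package (the search-space library used with SMAC, which this module drives) is installed; the ranges mirror the list-valued defaults in the signature above, and everything else is illustrative, not part of the commit.

# Sketch: building a quant_post_hpo-style search space with ConfigSpace.
# Assumes `pip install ConfigSpace`; imports follow the 0.x API style.
from ConfigSpace import ConfigurationSpace
from ConfigSpace.hyperparameters import (
    CategoricalHyperparameter, UniformFloatHyperparameter,
    UniformIntegerHyperparameter)

cs = ConfigurationSpace()
cs.add_hyperparameters([
    # [low, high] pairs mirror the list-valued defaults in the signature.
    UniformFloatHyperparameter(
        "hist_percent", 0.98, 0.999, default_value=0.98),
    CategoricalHyperparameter(
        "algo", ["KL", "hist", "avg", "mse"], default_value="KL"),
    UniformIntegerHyperparameter("batch_size", 10, 30, default_value=10),
    UniformIntegerHyperparameter("batch_num", 10, 30, default_value=10),
])

# A sampled configuration supports dict-style membership and indexing,
# which is why quantize(cfg) can do: cfg["algo"] if 'algo' in cfg else <default>.
print(cs.sample_configuration())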
@@ -397,7 +408,10 @@ def quant_post_hpo(executor,
         weight_quantize_type = weight_quantize_type[0]
     if len(batch_size) > 1:
         batch_size = UniformIntegerHyperparameter(
-            "batch_size", batch_size[0], batch_size[1], default_value=batch_size[0])
+            "batch_size",
+            batch_size[0],
+            batch_size[1],
+            default_value=batch_size[0])
         hyper_params.append(batch_size)
     else:
         batch_size = batch_size[0]
@@ -407,7 +421,7 @@ def quant_post_hpo(executor,
             "batch_num", batch_num[0], batch_num[1], default_value=batch_num[0])
         hyper_params.append(batch_num)
     else:
         batch_num = batch_num[0]

     if len(hyper_params) == 0:
         quant_post( \
...
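Taken together, the user-visible effect of this fix is that weight_quantize_type is now list-valued end to end: the demo passes ['channel_wise_abs_max'], and quantize(cfg) unwraps element 0 when the option is not being searched. A hedged call sketch follows, assuming quant_post_hpo is importable from paddleslim.quant as in the demo; the paths, input shape, and reader stub are placeholders, not part of the commit.

# Sketch: calling quant_post_hpo with the fixed list-valued argument.
import numpy as np
import paddle
from paddleslim.quant import quant_post_hpo

paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)

def sample_reader():
    # Placeholder calibration/eval reader: yields one feed per sample;
    # the shape must match the real model's input.
    for _ in range(32):
        yield [np.random.random((1, 3, 224, 224)).astype('float32')]

quant_post_hpo(
    executor=exe,
    place=place,
    model_dir='./fp32_model',              # placeholder path
    quantize_model_path='./quant_model',   # placeholder path
    train_sample_generator=sample_reader,
    eval_sample_generator=sample_reader,
    save_model_filename='__model__',
    save_params_filename='__params__',
    weight_quantize_type=['channel_wise_abs_max'],  # list, per this fix
    runcount_limit=30)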