未验证 提交 d2ecc862 编写于 作者: C ceci3 提交者: GitHub

fix ptq hpo (#1021)

上级 a00f830f
...@@ -55,7 +55,7 @@ def quantize(args): ...@@ -55,7 +55,7 @@ def quantize(args):
save_model_filename='__model__', save_model_filename='__model__',
save_params_filename='__params__', save_params_filename='__params__',
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"], quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
weight_quantize_type='channel_wise_abs_max', weight_quantize_type=['channel_wise_abs_max'],
runcount_limit=args.max_model_quant_count) runcount_limit=args.max_model_quant_count)
def main(): def main():
......
...@@ -189,10 +189,10 @@ def eval_quant_model(): ...@@ -189,10 +189,10 @@ def eval_quant_model():
valid_data_num = 0 valid_data_num = 0
max_eval_data_num = 200 max_eval_data_num = 200
if g_quant_config.eval_sample_generator is not None: if g_quant_config.eval_sample_generator is not None:
feed_dict=False feed_dict = False
eval_dataloader = g_quant_config.eval_sample_generator eval_dataloader = g_quant_config.eval_sample_generator
else: else:
feed_dict=True feed_dict = True
eval_dataloader = g_quant_config.eval_dataloader eval_dataloader = g_quant_config.eval_dataloader
for i, data in enumerate(eval_dataloader()): for i, data in enumerate(eval_dataloader()):
with paddle.static.scope_guard(float_scope): with paddle.static.scope_guard(float_scope):
...@@ -236,12 +236,20 @@ def eval_quant_model(): ...@@ -236,12 +236,20 @@ def eval_quant_model():
def quantize(cfg): def quantize(cfg):
"""model quantize job""" """model quantize job"""
algo = cfg["algo"] if 'algo' in cfg else g_quant_config.algo[0][0] algo = cfg["algo"] if 'algo' in cfg else g_quant_config.algo[0][0]
hist_percent = cfg["hist_percent"] if "hist_percent" in cfg else g_quant_config.hist_percent[0][0] hist_percent = cfg[
bias_correct = cfg["bias_correct"] if "bias_correct" in cfg else g_quant_config.bias_correct[0][0] "hist_percent"] if "hist_percent" in cfg else g_quant_config.hist_percent[
batch_size = cfg["batch_size"] if "batch_size" in cfg else g_quant_config.batch_size[0][0] 0][0]
batch_num = cfg["batch_num"] if "batch_num" in cfg else g_quant_config.batch_num[0][0] bias_correct = cfg[
weight_quantize_type = cfg["weight_quantize_type"] if "weight_quantize_type" in cfg else g_quant_config.weight_quantize_type[0] "bias_correct"] if "bias_correct" in cfg else g_quant_config.bias_correct[
print(hist_percent, bias_correct, batch_size, batch_num, weight_quantize_type) 0][0]
batch_size = cfg[
"batch_size"] if "batch_size" in cfg else g_quant_config.batch_size[0][
0]
batch_num = cfg[
"batch_num"] if "batch_num" in cfg else g_quant_config.batch_num[0][0]
weight_quantize_type = cfg[
"weight_quantize_type"] if "weight_quantize_type" in cfg else g_quant_config.weight_quantize_type[
0]
quant_post( \ quant_post( \
executor=g_quant_config.executor, \ executor=g_quant_config.executor, \
...@@ -279,7 +287,8 @@ def quantize(cfg): ...@@ -279,7 +287,8 @@ def quantize(cfg):
return emd_loss return emd_loss
def quant_post_hpo(executor, def quant_post_hpo(
executor,
place, place,
model_dir, model_dir,
quantize_model_path, quantize_model_path,
...@@ -363,22 +372,24 @@ def quant_post_hpo(executor, ...@@ -363,22 +372,24 @@ def quant_post_hpo(executor,
executor, place, model_dir, quantize_model_path, algo, hist_percent, executor, place, model_dir, quantize_model_path, algo, hist_percent,
bias_correct, batch_size, batch_num, train_sample_generator, bias_correct, batch_size, batch_num, train_sample_generator,
eval_sample_generator, train_dataloader, eval_dataloader, eval_function, eval_sample_generator, train_dataloader, eval_dataloader, eval_function,
model_filename, params_filename, model_filename, params_filename, save_model_filename,
save_model_filename, save_params_filename, scope, quantizable_op_type, save_params_filename, scope, quantizable_op_type, is_full_quantize,
is_full_quantize, weight_bits, activation_bits, weight_quantize_type, weight_bits, activation_bits, weight_quantize_type, optimize_model,
optimize_model, is_use_cache_file, cache_dir) is_use_cache_file, cache_dir)
cs = ConfigurationSpace() cs = ConfigurationSpace()
hyper_params = [] hyper_params = []
if 'hist' in algo: if 'hist' in algo:
hist_percent = UniformFloatHyperparameter( hist_percent = UniformFloatHyperparameter(
"hist_percent", hist_percent[0], hist_percent[1], default_value=hist_percent[0]) "hist_percent",
hist_percent[0],
hist_percent[1],
default_value=hist_percent[0])
hyper_params.append(hist_percent) hyper_params.append(hist_percent)
if len(algo) > 1: if len(algo) > 1:
algo = CategoricalHyperparameter( algo = CategoricalHyperparameter("algo", algo, default_value=algo[0])
"algo", algo, default_value=algo[0])
hyper_params.append(algo) hyper_params.append(algo)
else: else:
algo = algo[0] algo = algo[0]
...@@ -397,7 +408,10 @@ def quant_post_hpo(executor, ...@@ -397,7 +408,10 @@ def quant_post_hpo(executor,
weight_quantize_type = weight_quantize_type[0] weight_quantize_type = weight_quantize_type[0]
if len(batch_size) > 1: if len(batch_size) > 1:
batch_size = UniformIntegerHyperparameter( batch_size = UniformIntegerHyperparameter(
"batch_size", batch_size[0], batch_size[1], default_value=batch_size[0]) "batch_size",
batch_size[0],
batch_size[1],
default_value=batch_size[0])
hyper_params.append(batch_size) hyper_params.append(batch_size)
else: else:
batch_size = batch_size[0] batch_size = batch_size[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册