Unverified · Commit d2ecc862 authored by ceci3, committed by GitHub

fix ptq hpo (#1021)

Parent a00f830f
@@ -55,7 +55,7 @@ def quantize(args):
         save_model_filename='__model__',
         save_params_filename='__params__',
         quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
-        weight_quantize_type='channel_wise_abs_max',
+        weight_quantize_type=['channel_wise_abs_max'],
         runcount_limit=args.max_model_quant_count)

 def main():
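This one-line change in the demo is the core of the fix: quant_post_hpo treats each tunable option as a list of candidate values and indexes into it (see weight_quantize_type[0] in the hunks below), so a bare string silently yields a single character. A minimal illustration:

    # Indexing a string yields a character; indexing a list yields the element.
    wqt_str = 'channel_wise_abs_max'
    wqt_list = ['channel_wise_abs_max']
    print(wqt_str[0])   # 'c' -- not a valid weight_quantize_type
    print(wqt_list[0])  # 'channel_wise_abs_max' -- the intended value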
@@ -189,10 +189,10 @@ def eval_quant_model():
     valid_data_num = 0
     max_eval_data_num = 200
     if g_quant_config.eval_sample_generator is not None:
-        feed_dict=False
+        feed_dict = False
         eval_dataloader = g_quant_config.eval_sample_generator
     else:
-        feed_dict=True
+        feed_dict = True
         eval_dataloader = g_quant_config.eval_dataloader
     for i, data in enumerate(eval_dataloader()):
         with paddle.static.scope_guard(float_scope):
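Context for the hunk above: eval_quant_model runs the float baseline and the quantized program in separate scopes so their variables do not collide. A minimal sketch of the paddle.static.scope_guard pattern (float_scope/quant_scope mirror the surrounding code; the executor runs are elided):

    import paddle

    paddle.enable_static()
    float_scope = paddle.static.Scope()
    quant_scope = paddle.static.Scope()

    with paddle.static.scope_guard(float_scope):
        # executor.run(...) here would read/write variables in float_scope only
        pass
    with paddle.static.scope_guard(quant_scope):
        # the quantized program is evaluated in its own, isolated scope
        pass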
@@ -236,12 +236,20 @@ def eval_quant_model():
 def quantize(cfg):
     """model quantize job"""
     algo = cfg["algo"] if 'algo' in cfg else g_quant_config.algo[0][0]
-    hist_percent = cfg["hist_percent"] if "hist_percent" in cfg else g_quant_config.hist_percent[0][0]
-    bias_correct = cfg["bias_correct"] if "bias_correct" in cfg else g_quant_config.bias_correct[0][0]
-    batch_size = cfg["batch_size"] if "batch_size" in cfg else g_quant_config.batch_size[0][0]
-    batch_num = cfg["batch_num"] if "batch_num" in cfg else g_quant_config.batch_num[0][0]
-    weight_quantize_type = cfg["weight_quantize_type"] if "weight_quantize_type" in cfg else g_quant_config.weight_quantize_type[0]
-    print(hist_percent, bias_correct, batch_size, batch_num, weight_quantize_type)
+    hist_percent = cfg[
+        "hist_percent"] if "hist_percent" in cfg else g_quant_config.hist_percent[
+            0][0]
+    bias_correct = cfg[
+        "bias_correct"] if "bias_correct" in cfg else g_quant_config.bias_correct[
+            0][0]
+    batch_size = cfg[
+        "batch_size"] if "batch_size" in cfg else g_quant_config.batch_size[0][
+            0]
+    batch_num = cfg[
+        "batch_num"] if "batch_num" in cfg else g_quant_config.batch_num[0][0]
+    weight_quantize_type = cfg[
+        "weight_quantize_type"] if "weight_quantize_type" in cfg else g_quant_config.weight_quantize_type[
+            0]

     quant_post( \
         executor=g_quant_config.executor, \
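The repeated `cfg[key] if key in cfg else default` pattern reads each hyperparameter from the tuner-proposed configuration, falling back to the first configured candidate. An equivalent, more compact sketch (pick is a hypothetical helper, not part of this patch):

    def pick(cfg, key, default):
        # Return the tuner-proposed value when present, else the default candidate.
        return cfg[key] if key in cfg else default

    hist_percent = pick(cfg, "hist_percent", g_quant_config.hist_percent[0][0])
    bias_correct = pick(cfg, "bias_correct", g_quant_config.bias_correct[0][0])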
@@ -279,7 +287,8 @@ def quantize(cfg):
     return emd_loss


-def quant_post_hpo(executor,
+def quant_post_hpo(
+        executor,
         place,
         model_dir,
         quantize_model_path,
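For reference, a minimal call of the reflowed entry point, using only parameters that appear in this diff; paths are illustrative and calib_reader stands in for a user-supplied calibration data generator:

    import paddle
    from paddleslim.quant import quant_post_hpo

    paddle.enable_static()
    place = paddle.CPUPlace()
    exe = paddle.static.Executor(place)

    quant_post_hpo(
        exe,
        place,
        model_dir='./fp32_model',             # illustrative path
        quantize_model_path='./quant_model',  # illustrative path
        train_sample_generator=calib_reader,
        eval_sample_generator=calib_reader,
        weight_quantize_type=['channel_wise_abs_max'],  # a list, per this fix
        runcount_limit=30)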
@@ -363,22 +372,24 @@ def quant_post_hpo(executor,
         executor, place, model_dir, quantize_model_path, algo, hist_percent,
         bias_correct, batch_size, batch_num, train_sample_generator,
         eval_sample_generator, train_dataloader, eval_dataloader, eval_function,
-        model_filename, params_filename,
-        save_model_filename, save_params_filename, scope, quantizable_op_type,
-        is_full_quantize, weight_bits, activation_bits, weight_quantize_type,
-        optimize_model, is_use_cache_file, cache_dir)
+        model_filename, params_filename, save_model_filename,
+        save_params_filename, scope, quantizable_op_type, is_full_quantize,
+        weight_bits, activation_bits, weight_quantize_type, optimize_model,
+        is_use_cache_file, cache_dir)

     cs = ConfigurationSpace()
     hyper_params = []

     if 'hist' in algo:
         hist_percent = UniformFloatHyperparameter(
-            "hist_percent", hist_percent[0], hist_percent[1], default_value=hist_percent[0])
+            "hist_percent",
+            hist_percent[0],
+            hist_percent[1],
+            default_value=hist_percent[0])
         hyper_params.append(hist_percent)
     if len(algo) > 1:
-        algo = CategoricalHyperparameter(
-            "algo", algo, default_value=algo[0])
+        algo = CategoricalHyperparameter("algo", algo, default_value=algo[0])
         hyper_params.append(algo)
     else:
         algo = algo[0]
@@ -397,7 +408,10 @@ def quant_post_hpo(executor,
         weight_quantize_type = weight_quantize_type[0]
     if len(batch_size) > 1:
         batch_size = UniformIntegerHyperparameter(
-            "batch_size", batch_size[0], batch_size[1], default_value=batch_size[0])
+            "batch_size",
+            batch_size[0],
+            batch_size[1],
+            default_value=batch_size[0])
         hyper_params.append(batch_size)
     else:
         batch_size = batch_size[0]
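The search space in the two hunks above is assembled from ConfigSpace/SMAC hyperparameter classes. A self-contained sketch of the same pattern (the candidate values are illustrative, not the repo defaults):

    from ConfigSpace import ConfigurationSpace
    from ConfigSpace.hyperparameters import (
        CategoricalHyperparameter, UniformFloatHyperparameter,
        UniformIntegerHyperparameter)

    cs = ConfigurationSpace()
    cs.add_hyperparameters([
        CategoricalHyperparameter("algo", ["KL", "hist", "mse"], default_value="KL"),
        UniformFloatHyperparameter("hist_percent", 0.98, 0.999, default_value=0.98),
        UniformIntegerHyperparameter("batch_size", 1, 32, default_value=1),
    ])
    print(cs.sample_configuration())  # one random draw from the space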