Unverified commit d2ecc862, authored by ceci3 and committed by GitHub

fix ptq hpo (#1021)

Parent a00f830f
......@@ -55,7 +55,7 @@ def quantize(args):
save_model_filename='__model__',
save_params_filename='__params__',
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
weight_quantize_type='channel_wise_abs_max',
weight_quantize_type=['channel_wise_abs_max'],
runcount_limit=args.max_model_quant_count)
def main():
......
......@@ -189,10 +189,10 @@ def eval_quant_model():
valid_data_num = 0
max_eval_data_num = 200
if g_quant_config.eval_sample_generator is not None:
feed_dict=False
feed_dict = False
eval_dataloader = g_quant_config.eval_sample_generator
else:
feed_dict=True
feed_dict = True
eval_dataloader = g_quant_config.eval_dataloader
for i, data in enumerate(eval_dataloader()):
with paddle.static.scope_guard(float_scope):
......@@ -236,12 +236,20 @@ def eval_quant_model():
def quantize(cfg):
"""model quantize job"""
algo = cfg["algo"] if 'algo' in cfg else g_quant_config.algo[0][0]
hist_percent = cfg["hist_percent"] if "hist_percent" in cfg else g_quant_config.hist_percent[0][0]
bias_correct = cfg["bias_correct"] if "bias_correct" in cfg else g_quant_config.bias_correct[0][0]
batch_size = cfg["batch_size"] if "batch_size" in cfg else g_quant_config.batch_size[0][0]
batch_num = cfg["batch_num"] if "batch_num" in cfg else g_quant_config.batch_num[0][0]
weight_quantize_type = cfg["weight_quantize_type"] if "weight_quantize_type" in cfg else g_quant_config.weight_quantize_type[0]
print(hist_percent, bias_correct, batch_size, batch_num, weight_quantize_type)
hist_percent = cfg[
"hist_percent"] if "hist_percent" in cfg else g_quant_config.hist_percent[
0][0]
bias_correct = cfg[
"bias_correct"] if "bias_correct" in cfg else g_quant_config.bias_correct[
0][0]
batch_size = cfg[
"batch_size"] if "batch_size" in cfg else g_quant_config.batch_size[0][
0]
batch_num = cfg[
"batch_num"] if "batch_num" in cfg else g_quant_config.batch_num[0][0]
weight_quantize_type = cfg[
"weight_quantize_type"] if "weight_quantize_type" in cfg else g_quant_config.weight_quantize_type[
0]
quant_post( \
executor=g_quant_config.executor, \
......@@ -279,34 +287,35 @@ def quantize(cfg):
return emd_loss
def quant_post_hpo(executor,
place,
model_dir,
quantize_model_path,
train_sample_generator=None,
eval_sample_generator=None,
train_dataloader=None,
eval_dataloader=None,
eval_function=None,
model_filename=None,
params_filename=None,
save_model_filename='__model__',
save_params_filename='__params__',
scope=None,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
weight_bits=8,
activation_bits=8,
weight_quantize_type=['channel_wise_abs_max'],
algo=["KL", "hist", "avg", "mse"],
bias_correct=[True, False],
hist_percent=[0.98, 0.999], ### uniform sample in list.
batch_size=[10, 30], ### uniform sample in list.
batch_num=[10, 30], ### uniform sample in list.
optimize_model=False,
is_use_cache_file=False,
cache_dir="./temp_post_training",
runcount_limit=30):
def quant_post_hpo(
executor,
place,
model_dir,
quantize_model_path,
train_sample_generator=None,
eval_sample_generator=None,
train_dataloader=None,
eval_dataloader=None,
eval_function=None,
model_filename=None,
params_filename=None,
save_model_filename='__model__',
save_params_filename='__params__',
scope=None,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
weight_bits=8,
activation_bits=8,
weight_quantize_type=['channel_wise_abs_max'],
algo=["KL", "hist", "avg", "mse"],
bias_correct=[True, False],
hist_percent=[0.98, 0.999], ### uniform sample in list.
batch_size=[10, 30], ### uniform sample in list.
batch_num=[10, 30], ### uniform sample in list.
optimize_model=False,
is_use_cache_file=False,
cache_dir="./temp_post_training",
runcount_limit=30):
"""
The function utilizes static post training quantization method to
quantize the fp32 model. It uses calibrate data to calculate the
......@@ -360,25 +369,27 @@ def quant_post_hpo(executor,
global g_quant_config
g_quant_config = QuantConfig(
executor, place, model_dir, quantize_model_path, algo, hist_percent,
executor, place, model_dir, quantize_model_path, algo, hist_percent,
bias_correct, batch_size, batch_num, train_sample_generator,
eval_sample_generator, train_dataloader, eval_dataloader, eval_function,
model_filename, params_filename,
save_model_filename, save_params_filename, scope, quantizable_op_type,
is_full_quantize, weight_bits, activation_bits, weight_quantize_type,
optimize_model, is_use_cache_file, cache_dir)
model_filename, params_filename, save_model_filename,
save_params_filename, scope, quantizable_op_type, is_full_quantize,
weight_bits, activation_bits, weight_quantize_type, optimize_model,
is_use_cache_file, cache_dir)
cs = ConfigurationSpace()
hyper_params = []
if 'hist' in algo:
hist_percent = UniformFloatHyperparameter(
"hist_percent", hist_percent[0], hist_percent[1], default_value=hist_percent[0])
"hist_percent",
hist_percent[0],
hist_percent[1],
default_value=hist_percent[0])
hyper_params.append(hist_percent)
if len(algo) > 1:
algo = CategoricalHyperparameter(
"algo", algo, default_value=algo[0])
algo = CategoricalHyperparameter("algo", algo, default_value=algo[0])
hyper_params.append(algo)
else:
algo = algo[0]
......@@ -397,7 +408,10 @@ def quant_post_hpo(executor,
weight_quantize_type = weight_quantize_type[0]
if len(batch_size) > 1:
batch_size = UniformIntegerHyperparameter(
"batch_size", batch_size[0], batch_size[1], default_value=batch_size[0])
"batch_size",
batch_size[0],
batch_size[1],
default_value=batch_size[0])
hyper_params.append(batch_size)
else:
batch_size = batch_size[0]
......@@ -407,7 +421,7 @@ def quant_post_hpo(executor,
"batch_num", batch_num[0], batch_num[1], default_value=batch_num[0])
hyper_params.append(batch_num)
else:
batch_num = batch_num[0]
batch_num = batch_num[0]
if len(hyper_params) == 0:
quant_post( \
......
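For reference, a minimal sketch of how the patched entry point might be called after this change, with weight_quantize_type passed as a list rather than a bare string. The import path is assumed from the demo touched by this commit; the model directory, output path, and calibration reader are hypothetical placeholders, and only keyword arguments that appear in the diff are used:

import numpy as np
import paddle
from paddleslim.quant import quant_post_hpo  # assumed public path used by the demo

paddle.enable_static()
place = paddle.CPUPlace()  # or paddle.CUDAPlace(0) when a GPU is available
exe = paddle.static.Executor(place)

def sample_generator():
    # Hypothetical calibration reader: random NCHW tensors standing in for real inputs.
    for _ in range(32):
        yield [np.random.random((1, 3, 224, 224)).astype("float32")]

quant_post_hpo(
    executor=exe,
    place=place,
    model_dir="./mobilenet_fp32",            # placeholder FP32 model directory
    quantize_model_path="./mobilenet_int8",  # placeholder output directory
    train_sample_generator=sample_generator,
    eval_sample_generator=sample_generator,
    save_model_filename='__model__',
    save_params_filename='__params__',
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    weight_quantize_type=['channel_wise_abs_max'],  # now a list, matching the fix above
    runcount_limit=30)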
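The search-space construction reformatted in the last hunks follows the ConfigSpace pattern used with SMAC: every knob with more than one candidate value becomes a hyperparameter, while single-valued knobs are used directly. A standalone sketch of that pattern, assuming the ConfigSpace package is installed (the classes below are the ones the diff already uses):

from ConfigSpace import ConfigurationSpace
from ConfigSpace.hyperparameters import (
    CategoricalHyperparameter, UniformFloatHyperparameter,
    UniformIntegerHyperparameter)

# Candidate values mirroring the quant_post_hpo defaults shown in the diff.
algo = ["KL", "hist", "avg", "mse"]
hist_percent = [0.98, 0.999]  # uniform sample in this range
batch_size = [10, 30]         # uniform sample in this range

cs = ConfigurationSpace()
hyper_params = []
if 'hist' in algo:
    hyper_params.append(
        UniformFloatHyperparameter(
            "hist_percent", hist_percent[0], hist_percent[1],
            default_value=hist_percent[0]))
if len(algo) > 1:
    hyper_params.append(
        CategoricalHyperparameter("algo", algo, default_value=algo[0]))
hyper_params.append(
    UniformIntegerHyperparameter(
        "batch_size", batch_size[0], batch_size[1],
        default_value=batch_size[0]))
cs.add_hyperparameters(hyper_params)

# Each HPO trial samples one configuration; quantize(cfg) then reads the sampled values.
print(cs.sample_configuration())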