未验证 提交 d0d43915 编写于 作者: C ceci3 提交者: GitHub

fix final quant strategy (#1207)

上级 6a438d9c
......@@ -65,8 +65,7 @@ ac = AutoCompression(
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="output",
strategy_config=None,
train_config=None,
config=None,
train_dataloader=train_loader,
eval_dataloader=train_loader) # eval_function to verify accuracy
ac.compress()
......
......@@ -239,8 +239,18 @@ def prepare_strategy(executor,
return strategy_config
def get_final_quant_config(ptq_loss):
def get_final_quant_config(ptq_loss, model_type=None):
""" transform quantization tester config to real quantization config """
### use ptq & hpo when model_type is transformer
if model_type == 'transformer':
quant_config = Quantization(**default_quant_config)
hpo_config = HyperParameterOptimization(**default_hpo_config)
configs = [{
'Quantization': quant_config,
'HyperParameterOptimization': hpo_config
}]
return configs
### if the EMD loss is less than MAGIC_MIN_EMD_DISTANCE, no final quantization config is needed (return None).
if ptq_loss < MAGIC_MIN_EMD_DISTANCE:
return None
......
......@@ -135,9 +135,6 @@ class AutoCompression:
os.makedirs(self.final_dir)
# load config
assert type(config) in [
dict, str, set, list, tuple
], f"The type of config should be in [dict, str, set, list, tuple] but got {type(config)}"
if isinstance(config, str):
config = load_config(config)
self.strategy_config = extract_strategy_config(config)
......@@ -562,7 +559,8 @@ class AutoCompression:
).lower() == 'linux':
ptq_loss = post_quant_hpo.g_min_emd_loss
final_quant_config = get_final_quant_config(ptq_loss)
final_quant_config = get_final_quant_config(ptq_loss,
self.model_type)
if final_quant_config is not None:
quant_strategy, quant_config = self._prepare_strategy(
final_quant_config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册