未验证 提交 d0d43915 编写于 作者: C ceci3 提交者: GitHub

fix final quant strategy (#1207)

上级 6a438d9c
...@@ -65,8 +65,7 @@ ac = AutoCompression( ...@@ -65,8 +65,7 @@ ac = AutoCompression(
model_filename="inference.pdmodel", model_filename="inference.pdmodel",
params_filename="inference.pdiparams", params_filename="inference.pdiparams",
save_dir="output", save_dir="output",
strategy_config=None, config=None,
train_config=None,
train_dataloader=train_loader, train_dataloader=train_loader,
eval_dataloader=train_loader) # eval_function to verify accuracy eval_dataloader=train_loader) # eval_function to verify accuracy
ac.compress() ac.compress()
......
...@@ -239,8 +239,18 @@ def prepare_strategy(executor, ...@@ -239,8 +239,18 @@ def prepare_strategy(executor,
return strategy_config return strategy_config
def get_final_quant_config(ptq_loss): def get_final_quant_config(ptq_loss, model_type=None):
""" transform quantization tester config to real quantization config """ """ transform quantization tester config to real quantization config """
### use ptq & hpo when model_type is transformer
if model_type == 'transformer':
quant_config = Quantization(**default_quant_config)
hpo_config = HyperParameterOptimization(**default_hpo_config)
configs = [{
'Quantization': quant_config,
'HyperParameterOptimization': hpo_config
}]
return configs
### if emd loss less than MAGIC_MIN_EMD_DISTANCE, final compress. ### if emd loss less than MAGIC_MIN_EMD_DISTANCE, final compress.
if ptq_loss < MAGIC_MIN_EMD_DISTANCE: if ptq_loss < MAGIC_MIN_EMD_DISTANCE:
return None return None
......
...@@ -135,9 +135,6 @@ class AutoCompression: ...@@ -135,9 +135,6 @@ class AutoCompression:
os.makedirs(self.final_dir) os.makedirs(self.final_dir)
# load config # load config
assert type(config) in [
dict, str, set, list, tuple
], f"The type of config should be in [dict, str, set, list, tuple] but got {type(config)}"
if isinstance(config, str): if isinstance(config, str):
config = load_config(config) config = load_config(config)
self.strategy_config = extract_strategy_config(config) self.strategy_config = extract_strategy_config(config)
...@@ -562,7 +559,8 @@ class AutoCompression: ...@@ -562,7 +559,8 @@ class AutoCompression:
).lower() == 'linux': ).lower() == 'linux':
ptq_loss = post_quant_hpo.g_min_emd_loss ptq_loss = post_quant_hpo.g_min_emd_loss
final_quant_config = get_final_quant_config(ptq_loss) final_quant_config = get_final_quant_config(ptq_loss,
self.model_type)
if final_quant_config is not None: if final_quant_config is not None:
quant_strategy, quant_config = self._prepare_strategy( quant_strategy, quant_config = self._prepare_strategy(
final_quant_config) final_quant_config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册