From d0d439154e2b26554b9c640cf66f6fab7d41848f Mon Sep 17 00:00:00 2001 From: ceci3 Date: Thu, 30 Jun 2022 11:35:31 +0800 Subject: [PATCH] fix final quant strategy (#1207) --- demo/auto_compression/README.md | 3 +-- paddleslim/auto_compression/auto_strategy.py | 12 +++++++++++- paddleslim/auto_compression/compressor.py | 6 ++---- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/demo/auto_compression/README.md b/demo/auto_compression/README.md index 7ce0bea0..4b8c6876 100644 --- a/demo/auto_compression/README.md +++ b/demo/auto_compression/README.md @@ -65,8 +65,7 @@ ac = AutoCompression( model_filename="inference.pdmodel", params_filename="inference.pdiparams", save_dir="output", - strategy_config=None, - train_config=None, + config=None, train_dataloader=train_loader, eval_dataloader=train_loader) # eval_function to verify accuracy ac.compress() diff --git a/paddleslim/auto_compression/auto_strategy.py b/paddleslim/auto_compression/auto_strategy.py index 22477721..a1eb0b0f 100644 --- a/paddleslim/auto_compression/auto_strategy.py +++ b/paddleslim/auto_compression/auto_strategy.py @@ -239,8 +239,18 @@ def prepare_strategy(executor, return strategy_config -def get_final_quant_config(ptq_loss): +def get_final_quant_config(ptq_loss, model_type=None): """ transform quantization tester config to real quantization config """ + ### use ptq & hpo when model_type is transformer + if model_type == 'transformer': + quant_config = Quantization(**default_quant_config) + hpo_config = HyperParameterOptimization(**default_hpo_config) + configs = [{ + 'Quantization': quant_config, + 'HyperParameterOptimization': hpo_config + }] + return configs + ### if emd loss less than MAGIC_MIN_EMD_DISTANCE, final compress. if ptq_loss < MAGIC_MIN_EMD_DISTANCE: return None diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py index dd67782a..fd6f0d05 100644 --- a/paddleslim/auto_compression/compressor.py +++ b/paddleslim/auto_compression/compressor.py @@ -135,9 +135,6 @@ class AutoCompression: os.makedirs(self.final_dir) # load config - assert type(config) in [ - dict, str, set, list, tuple - ], f"The type of config should be in [dict, str, set, list, tuple] but got {type(config)}" if isinstance(config, str): config = load_config(config) self.strategy_config = extract_strategy_config(config) @@ -562,7 +559,8 @@ class AutoCompression: ).lower() == 'linux': ptq_loss = post_quant_hpo.g_min_emd_loss - final_quant_config = get_final_quant_config(ptq_loss) + final_quant_config = get_final_quant_config(ptq_loss, + self.model_type) if final_quant_config is not None: quant_strategy, quant_config = self._prepare_strategy( final_quant_config) -- GitLab