diff --git a/demo/auto_compression/README.md b/demo/auto_compression/README.md index 7ce0bea0f3e2f7f61b801f5be8169c8733c3f11a..4b8c68766b2c69157fa41e5fe3473a6cda7bd3a7 100644 --- a/demo/auto_compression/README.md +++ b/demo/auto_compression/README.md @@ -65,8 +65,7 @@ ac = AutoCompression( model_filename="inference.pdmodel", params_filename="inference.pdiparams", save_dir="output", - strategy_config=None, - train_config=None, + config=None, train_dataloader=train_loader, eval_dataloader=train_loader) # eval_function to verify accuracy ac.compress() diff --git a/paddleslim/auto_compression/auto_strategy.py b/paddleslim/auto_compression/auto_strategy.py index 22477721877390db10aecd1d27f9f6b2419502d4..a1eb0b0f2df7282826d63c58311a571a18d5fdb8 100644 --- a/paddleslim/auto_compression/auto_strategy.py +++ b/paddleslim/auto_compression/auto_strategy.py @@ -239,8 +239,18 @@ def prepare_strategy(executor, return strategy_config -def get_final_quant_config(ptq_loss): +def get_final_quant_config(ptq_loss, model_type=None): """ transform quantization tester config to real quantization config """ + ### use ptq & hpo when model_type is transformer + if model_type == 'transformer': + quant_config = Quantization(**default_quant_config) + hpo_config = HyperParameterOptimization(**default_hpo_config) + configs = [{ + 'Quantization': quant_config, + 'HyperParameterOptimization': hpo_config + }] + return configs + ### if emd loss less than MAGIC_MIN_EMD_DISTANCE, final compress. if ptq_loss < MAGIC_MIN_EMD_DISTANCE: return None diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py index dd67782af84f74a9e2ee030104570137674261bb..fd6f0d05fd9746c9315b4385be656f26199c20d8 100644 --- a/paddleslim/auto_compression/compressor.py +++ b/paddleslim/auto_compression/compressor.py @@ -135,9 +135,6 @@ class AutoCompression: os.makedirs(self.final_dir) # load config - assert type(config) in [ - dict, str, set, list, tuple - ], f"The type of config should be in [dict, str, set, list, tuple] but got {type(config)}" if isinstance(config, str): config = load_config(config) self.strategy_config = extract_strategy_config(config) @@ -562,7 +559,8 @@ class AutoCompression: ).lower() == 'linux': ptq_loss = post_quant_hpo.g_min_emd_loss - final_quant_config = get_final_quant_config(ptq_loss) + final_quant_config = get_final_quant_config(ptq_loss, + self.model_type) if final_quant_config is not None: quant_strategy, quant_config = self._prepare_strategy( final_quant_config)