From 8ae194583868e78369ec395c7bc28cf8fecac100 Mon Sep 17 00:00:00 2001
From: ceci3
Date: Fri, 21 Oct 2022 19:41:46 +0800
Subject: [PATCH] fix ptq hpo in ac (#1423)

---
 paddleslim/auto_compression/compressor.py     | 24 ++++++++++++------------
 .../create_compressed_program.py              |  4 +++-
 2 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py
index 16537fe8..ad07fa44 100644
--- a/paddleslim/auto_compression/compressor.py
+++ b/paddleslim/auto_compression/compressor.py
@@ -597,17 +597,23 @@ class AutoCompression:
                               train_config):
         # start compress, including train/eval model
         # TODO: add the emd loss of evaluation model.
-        if self.updated_model_dir != self.model_dir:
+        if strategy_idx == 0:
+            model_dir = self.model_dir
+        else:
+            model_dir = os.path.join(self.tmp_dir,
+                                     'strategy_{}'.format(str(strategy_idx)))
+
+        if self.updated_model_dir != model_dir:
             # If model is ONNX, convert it to inference model firstly.
             load_inference_model(
-                self.model_dir,
+                model_dir,
                 model_filename=self.model_filename,
                 params_filename=self.params_filename,
                 executor=self._exe)
         if strategy == 'quant_post':
             quant_post(
                 self._exe,
-                model_dir=self.updated_model_dir,
+                model_dir=model_dir,
                 quantize_model_path=os.path.join(
                     self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))),
                 data_loader=self.train_dataloader,
@@ -633,10 +639,10 @@ class AutoCompression:
             if platform.system().lower() != 'linux':
                 raise NotImplementedError(
                     "post-quant-hpo is not support in system other than linux")
-            if self.updated_model_dir != self.model_dir:
+            if self.updated_model_dir != model_dir:
                 # If model is ONNX, convert it to inference model firstly.
                 load_inference_model(
-                    self.model_dir,
+                    model_dir,
                     model_filename=self.model_filename,
                     params_filename=self.params_filename,
                     executor=self._exe)
@@ -648,7 +654,7 @@ class AutoCompression:
             post_quant_hpo.quant_post_hpo(
                 self._exe,
                 self._places,
-                model_dir=self.updated_model_dir,
+                model_dir=model_dir,
                 quantize_model_path=os.path.join(
                     self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))),
                 train_dataloader=self.train_dataloader,
@@ -673,12 +679,6 @@ class AutoCompression:
         else:
             assert 'dis' in strategy, "Only support optimizer compressed model by distillation loss."
 
-        if strategy_idx == 0:
-            model_dir = self.model_dir
-        else:
-            model_dir = os.path.join(
-                self.tmp_dir, 'strategy_{}'.format(str(strategy_idx)))
-
         [inference_program, feed_target_names, fetch_targets]= load_inference_model( \
             model_dir, \
             model_filename=self.model_filename, params_filename=self.params_filename,
diff --git a/paddleslim/auto_compression/create_compressed_program.py b/paddleslim/auto_compression/create_compressed_program.py
index 8a6c7db2..011af243 100644
--- a/paddleslim/auto_compression/create_compressed_program.py
+++ b/paddleslim/auto_compression/create_compressed_program.py
@@ -385,7 +385,9 @@ def build_quant_program(executor, place, config, train_program_info,
 def _get_label_info(dataloader, feed_target_names):
     label_info = {}
     for data in dataloader():
-        for key, value in data[0].items():
+        if isinstance(data, list) or isinstance(data, tuple):
+            data = data[0]
+        for key, value in data.items():
             if key in feed_target_names:
                 continue
             label_info['name'] = key
--
GitLab
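
The patch does two things: it hoists the per-strategy choice of input model directory to the top of single_strategy_compress, so that the 'quant_post' and 'quant_post_hpo' branches read the previous strategy's output ('strategy_{idx}') instead of always reading the original model, and it makes _get_label_info tolerate dataloaders that yield either a dict directly or a list/tuple wrapping one. The sketch below is a minimal, self-contained illustration of those two patterns under assumed names; select_model_dir and first_feed_dict are hypothetical helpers, not PaddleSlim API.

    import os


    def select_model_dir(base_model_dir, tmp_dir, strategy_idx):
        """Strategy 0 reads the original model; strategy k reads strategy k-1's output."""
        if strategy_idx == 0:
            return base_model_dir
        return os.path.join(tmp_dir, 'strategy_{}'.format(str(strategy_idx)))


    def first_feed_dict(batch):
        """Batches may arrive as a dict or as a list/tuple wrapping one; normalize to dict."""
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        return batch


    # Each strategy writes to 'strategy_{idx + 1}', which the next strategy reads,
    # so compression strategies chain instead of each starting from the raw model.
    for idx in range(3):
        src = select_model_dir('./infer_model', './tmp', idx)
        dst = os.path.join('./tmp', 'strategy_{}'.format(str(idx + 1)))
        print('strategy {}: {} -> {}'.format(idx, src, dst))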