未验证 提交 8ae19458 编写于 作者: C ceci3 提交者: GitHub

fix ptq hpo in ac (#1423)

上级 f481cb92
...@@ -597,17 +597,23 @@ class AutoCompression: ...@@ -597,17 +597,23 @@ class AutoCompression:
train_config): train_config):
# start compress, including train/eval model # start compress, including train/eval model
# TODO: add the emd loss of evaluation model. # TODO: add the emd loss of evaluation model.
if self.updated_model_dir != self.model_dir: if strategy_idx == 0:
model_dir = self.model_dir
else:
model_dir = os.path.join(self.tmp_dir,
'strategy_{}'.format(str(strategy_idx)))
if self.updated_model_dir != model_dir:
# If model is ONNX, convert it to inference model firstly. # If model is ONNX, convert it to inference model firstly.
load_inference_model( load_inference_model(
self.model_dir, model_dir,
model_filename=self.model_filename, model_filename=self.model_filename,
params_filename=self.params_filename, params_filename=self.params_filename,
executor=self._exe) executor=self._exe)
if strategy == 'quant_post': if strategy == 'quant_post':
quant_post( quant_post(
self._exe, self._exe,
model_dir=self.updated_model_dir, model_dir=model_dir,
quantize_model_path=os.path.join( quantize_model_path=os.path.join(
self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))), self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))),
data_loader=self.train_dataloader, data_loader=self.train_dataloader,
...@@ -633,10 +639,10 @@ class AutoCompression: ...@@ -633,10 +639,10 @@ class AutoCompression:
if platform.system().lower() != 'linux': if platform.system().lower() != 'linux':
raise NotImplementedError( raise NotImplementedError(
"post-quant-hpo is not support in system other than linux") "post-quant-hpo is not support in system other than linux")
if self.updated_model_dir != self.model_dir: if self.updated_model_dir != model_dir:
# If model is ONNX, convert it to inference model firstly. # If model is ONNX, convert it to inference model firstly.
load_inference_model( load_inference_model(
self.model_dir, model_dir,
model_filename=self.model_filename, model_filename=self.model_filename,
params_filename=self.params_filename, params_filename=self.params_filename,
executor=self._exe) executor=self._exe)
...@@ -648,7 +654,7 @@ class AutoCompression: ...@@ -648,7 +654,7 @@ class AutoCompression:
post_quant_hpo.quant_post_hpo( post_quant_hpo.quant_post_hpo(
self._exe, self._exe,
self._places, self._places,
model_dir=self.updated_model_dir, model_dir=model_dir,
quantize_model_path=os.path.join( quantize_model_path=os.path.join(
self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))), self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))),
train_dataloader=self.train_dataloader, train_dataloader=self.train_dataloader,
...@@ -673,12 +679,6 @@ class AutoCompression: ...@@ -673,12 +679,6 @@ class AutoCompression:
else: else:
assert 'dis' in strategy, "Only support optimizer compressed model by distillation loss." assert 'dis' in strategy, "Only support optimizer compressed model by distillation loss."
if strategy_idx == 0:
model_dir = self.model_dir
else:
model_dir = os.path.join(
self.tmp_dir, 'strategy_{}'.format(str(strategy_idx)))
[inference_program, feed_target_names, fetch_targets]= load_inference_model( \ [inference_program, feed_target_names, fetch_targets]= load_inference_model( \
model_dir, \ model_dir, \
model_filename=self.model_filename, params_filename=self.params_filename, model_filename=self.model_filename, params_filename=self.params_filename,
......
...@@ -385,7 +385,9 @@ def build_quant_program(executor, place, config, train_program_info, ...@@ -385,7 +385,9 @@ def build_quant_program(executor, place, config, train_program_info,
def _get_label_info(dataloader, feed_target_names): def _get_label_info(dataloader, feed_target_names):
label_info = {} label_info = {}
for data in dataloader(): for data in dataloader():
for key, value in data[0].items(): if isinstance(data, list) or isinstance(data, tuple):
data = data[0]
for key, value in data.items():
if key in feed_target_names: if key in feed_target_names:
continue continue
label_info['name'] = key label_info['name'] = key
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册