未验证 提交 8ae19458 编写于 作者: C ceci3 提交者: GitHub

fix ptq hpo in ac (#1423)

上级 f481cb92
......@@ -597,17 +597,23 @@ class AutoCompression:
train_config):
# start compress, including train/eval model
# TODO: add the emd loss of evaluation model.
if self.updated_model_dir != self.model_dir:
if strategy_idx == 0:
model_dir = self.model_dir
else:
model_dir = os.path.join(self.tmp_dir,
'strategy_{}'.format(str(strategy_idx)))
if self.updated_model_dir != model_dir:
# If model is ONNX, convert it to inference model firstly.
load_inference_model(
self.model_dir,
model_dir,
model_filename=self.model_filename,
params_filename=self.params_filename,
executor=self._exe)
if strategy == 'quant_post':
quant_post(
self._exe,
model_dir=self.updated_model_dir,
model_dir=model_dir,
quantize_model_path=os.path.join(
self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))),
data_loader=self.train_dataloader,
......@@ -633,10 +639,10 @@ class AutoCompression:
if platform.system().lower() != 'linux':
raise NotImplementedError(
"post-quant-hpo is not support in system other than linux")
if self.updated_model_dir != self.model_dir:
if self.updated_model_dir != model_dir:
# If model is ONNX, convert it to inference model firstly.
load_inference_model(
self.model_dir,
model_dir,
model_filename=self.model_filename,
params_filename=self.params_filename,
executor=self._exe)
......@@ -648,7 +654,7 @@ class AutoCompression:
post_quant_hpo.quant_post_hpo(
self._exe,
self._places,
model_dir=self.updated_model_dir,
model_dir=model_dir,
quantize_model_path=os.path.join(
self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))),
train_dataloader=self.train_dataloader,
......@@ -673,12 +679,6 @@ class AutoCompression:
else:
assert 'dis' in strategy, "Only support optimizer compressed model by distillation loss."
if strategy_idx == 0:
model_dir = self.model_dir
else:
model_dir = os.path.join(
self.tmp_dir, 'strategy_{}'.format(str(strategy_idx)))
[inference_program, feed_target_names, fetch_targets]= load_inference_model( \
model_dir, \
model_filename=self.model_filename, params_filename=self.params_filename,
......
......@@ -385,7 +385,9 @@ def build_quant_program(executor, place, config, train_program_info,
def _get_label_info(dataloader, feed_target_names):
label_info = {}
for data in dataloader():
for key, value in data[0].items():
if isinstance(data, list) or isinstance(data, tuple):
data = data[0]
for key, value in data.items():
if key in feed_target_names:
continue
label_info['name'] = key
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册