未验证 提交 d5090d5a 编写于 作者: G guofei 提交者: GitHub

Fix the bug of export quantized inference model (#2300)

* Fix the bug of export quantized inference model

* Fix the bug of export quantized inference model
上级 b8f70f95
......@@ -50,6 +50,7 @@ class Trainer(object):
"mode should be 'train', 'eval' or 'test'"
self.mode = mode.lower()
self.optimizer = None
self.slim = None
# build model
self.model = create(cfg.architecture)
......@@ -58,8 +59,8 @@ class Trainer(object):
if 'slim' in cfg and cfg.slim:
if self.mode == 'train':
self.load_weights(cfg.pretrain_weights, cfg.weight_type)
slim = create(cfg.slim)
slim(self.model)
self.slim = create(cfg.slim)
self.slim(self.model)
# build data loader
self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]
......@@ -385,17 +386,24 @@ class Trainer(object):
}]
# dy2st and save model
static_model = paddle.jit.to_static(self.model, input_spec=input_spec)
# NOTE: dy2st do not pruned program, but jit.save will prune program
# input spec, prune input spec here and save with pruned input spec
pruned_input_spec = self._prune_input_spec(
input_spec, static_model.forward.main_program,
static_model.forward.outputs)
paddle.jit.save(
static_model,
os.path.join(save_dir, 'model'),
input_spec=pruned_input_spec)
logger.info("Export model and saved in {}".format(save_dir))
if self.slim is None or self.cfg['slim'] != 'quant':
static_model = paddle.jit.to_static(
self.model, input_spec=input_spec)
# NOTE: dy2st do not pruned program, but jit.save will prune program
# input spec, prune input spec here and save with pruned input spec
pruned_input_spec = self._prune_input_spec(
input_spec, static_model.forward.main_program,
static_model.forward.outputs)
paddle.jit.save(
static_model,
os.path.join(save_dir, 'model'),
input_spec=pruned_input_spec)
logger.info("Export model and saved in {}".format(save_dir))
else:
self.slim.save_quantized_model(
self.model,
os.path.join(save_dir, 'model'),
input_spec=input_spec)
def _prune_input_spec(self, input_spec, program, targets):
# try to prune static program to figure out pruned input spec
......
......@@ -46,3 +46,7 @@ class QAT(object):
logger.info(model)
return model
def save_quantized_model(self, layer, path, input_spec=None, **config):
self.quanter.save_quantized_model(
model=layer, path=path, input_spec=input_spec, **config)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册