未验证 提交 d5090d5a 编写于 作者: G guofei 提交者: GitHub

Fix the bug of export quantized inference model (#2300)

* Fix the bug of export quantized inference model

* Fix the bug of export quantized inference model
上级 b8f70f95
...@@ -50,6 +50,7 @@ class Trainer(object): ...@@ -50,6 +50,7 @@ class Trainer(object):
"mode should be 'train', 'eval' or 'test'" "mode should be 'train', 'eval' or 'test'"
self.mode = mode.lower() self.mode = mode.lower()
self.optimizer = None self.optimizer = None
self.slim = None
# build model # build model
self.model = create(cfg.architecture) self.model = create(cfg.architecture)
...@@ -58,8 +59,8 @@ class Trainer(object): ...@@ -58,8 +59,8 @@ class Trainer(object):
if 'slim' in cfg and cfg.slim: if 'slim' in cfg and cfg.slim:
if self.mode == 'train': if self.mode == 'train':
self.load_weights(cfg.pretrain_weights, cfg.weight_type) self.load_weights(cfg.pretrain_weights, cfg.weight_type)
slim = create(cfg.slim) self.slim = create(cfg.slim)
slim(self.model) self.slim(self.model)
# build data loader # build data loader
self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())] self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]
...@@ -385,7 +386,9 @@ class Trainer(object): ...@@ -385,7 +386,9 @@ class Trainer(object):
}] }]
# dy2st and save model # dy2st and save model
static_model = paddle.jit.to_static(self.model, input_spec=input_spec) if self.slim is None or self.cfg['slim'] != 'quant':
static_model = paddle.jit.to_static(
self.model, input_spec=input_spec)
# NOTE: dy2st do not pruned program, but jit.save will prune program # NOTE: dy2st do not pruned program, but jit.save will prune program
# input spec, prune input spec here and save with pruned input spec # input spec, prune input spec here and save with pruned input spec
pruned_input_spec = self._prune_input_spec( pruned_input_spec = self._prune_input_spec(
...@@ -396,6 +399,11 @@ class Trainer(object): ...@@ -396,6 +399,11 @@ class Trainer(object):
os.path.join(save_dir, 'model'), os.path.join(save_dir, 'model'),
input_spec=pruned_input_spec) input_spec=pruned_input_spec)
logger.info("Export model and saved in {}".format(save_dir)) logger.info("Export model and saved in {}".format(save_dir))
else:
self.slim.save_quantized_model(
self.model,
os.path.join(save_dir, 'model'),
input_spec=input_spec)
def _prune_input_spec(self, input_spec, program, targets): def _prune_input_spec(self, input_spec, program, targets):
# try to prune static program to figure out pruned input spec # try to prune static program to figure out pruned input spec
......
...@@ -46,3 +46,7 @@ class QAT(object): ...@@ -46,3 +46,7 @@ class QAT(object):
logger.info(model) logger.info(model)
return model return model
def save_quantized_model(self, layer, path, input_spec=None, **config):
self.quanter.save_quantized_model(
model=layer, path=path, input_spec=input_spec, **config)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册