Unverified commit ba7e185c, authored by Guanghua Yu, committed by GitHub

fix quant export model (#3655)

Parent: a82faaa2
@@ -558,25 +558,27 @@ class Trainer(object):
                     shape=[None, 3, 192, 64], name='crops')
             })
-        # dy2st and save model
-        if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
-            static_model = paddle.jit.to_static(
-                self.model, input_spec=input_spec)
-            # NOTE: dy2st do not pruned program, but jit.save will prune program
-            # input spec, prune input spec here and save with pruned input spec
-            pruned_input_spec = self._prune_input_spec(
-                input_spec, static_model.forward.main_program,
-                static_model.forward.outputs)
+        static_model = paddle.jit.to_static(
+            self.model, input_spec=input_spec)
+        # NOTE: dy2st do not pruned program, but jit.save will prune program
+        # input spec, prune input spec here and save with pruned input spec
+        pruned_input_spec = self._prune_input_spec(
+            input_spec, static_model.forward.main_program,
+            static_model.forward.outputs)
+        # dy2st and save model
+        if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
             paddle.jit.save(
                 static_model,
                 os.path.join(save_dir, 'model'),
                 input_spec=pruned_input_spec)
-            logger.info("Export model and saved in {}".format(save_dir))
         else:
             self.cfg.slim.save_quantized_model(
                 self.model,
                 os.path.join(save_dir, 'model'),
-                input_spec=input_spec)
+                input_spec=pruned_input_spec)
+        logger.info("Export model and saved in {}".format(save_dir))

     def _prune_input_spec(self, input_spec, program, targets):
         # try to prune static program to figure out pruned input spec
...
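For readability, here is a minimal sketch of the export flow after this change, reconstructed from the new side of the diff. It assumes the surrounding Trainer.export() state (self.model, self.cfg, input_spec, save_dir, logger) and the module-level paddle and os imports, exactly as they appear in the repository code; it is not a verbatim copy of the file.

# Sketch of the export path after this commit (surrounding method state assumed).
# Convert the dynamic-graph model to static graph unconditionally, so the
# pruned input spec can be computed once and reused by both export branches.
static_model = paddle.jit.to_static(self.model, input_spec=input_spec)

# dy2st does not prune the program, but jit.save will; prune the input spec
# here and pass the same pruned spec to both save paths.
pruned_input_spec = self._prune_input_spec(
    input_spec, static_model.forward.main_program,
    static_model.forward.outputs)

if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
    # Regular export of the converted static model.
    paddle.jit.save(
        static_model,
        os.path.join(save_dir, 'model'),
        input_spec=pruned_input_spec)
else:
    # QAT export: now also receives the pruned spec instead of the raw
    # input_spec, which is the fix this commit makes.
    self.cfg.slim.save_quantized_model(
        self.model,
        os.path.join(save_dir, 'model'),
        input_spec=pruned_input_spec)
logger.info("Export model and saved in {}".format(save_dir))

Computing pruned_input_spec before the branch is what brings the QAT path in line with the regular path, and moving the logger.info call after the if/else means the success message is now also printed for quantized exports.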