未验证 提交 ba7e185c 编写于 作者: G Guanghua Yu 提交者: GitHub

fix quant export model (#3655)

上级 a82faaa2
......@@ -558,8 +558,7 @@ class Trainer(object):
shape=[None, 3, 192, 64], name='crops')
})
# dy2st and save model
if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
static_model = paddle.jit.to_static(
self.model, input_spec=input_spec)
# NOTE: dy2st do not pruned program, but jit.save will prune program
......@@ -567,16 +566,19 @@ class Trainer(object):
pruned_input_spec = self._prune_input_spec(
input_spec, static_model.forward.main_program,
static_model.forward.outputs)
# dy2st and save model
if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT':
paddle.jit.save(
static_model,
os.path.join(save_dir, 'model'),
input_spec=pruned_input_spec)
logger.info("Export model and saved in {}".format(save_dir))
else:
self.cfg.slim.save_quantized_model(
self.model,
os.path.join(save_dir, 'model'),
input_spec=input_spec)
input_spec=pruned_input_spec)
logger.info("Export model and saved in {}".format(save_dir))
def _prune_input_spec(self, input_spec, program, targets):
# try to prune static program to figure out pruned input spec
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册