未验证 提交 230c8c1e 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix feed var need prune when export (#2054)

* fix feed var need prune when export

* fix format

* fix format
上级 d4b30ff2
......@@ -327,5 +327,30 @@ class Trainer(object):
# dy2st and save model
static_model = paddle.jit.to_static(self.model, input_spec=input_spec)
paddle.jit.save(static_model, os.path.join(save_dir, 'model'))
# NOTE: dy2st do not pruned program, but jit.save will prune program
# input spec, prune input spec here and save with pruned input spec
pruned_input_spec = self._prune_input_spec(
input_spec, static_model.forward.main_program,
static_model.forward.outputs)
paddle.jit.save(
static_model,
os.path.join(save_dir, 'model'),
input_spec=pruned_input_spec)
logger.info("Export model and saved in {}".format(save_dir))
def _prune_input_spec(self, input_spec, program, targets):
# try to prune static program to figure out pruned input spec
# so we perform following operations in static mode
paddle.enable_static()
pruned_input_spec = [{}]
program = program.clone()
program = program._prune(targets=targets)
global_block = program.global_block()
for name, spec in input_spec[0].items():
try:
v = global_block.var(name)
pruned_input_spec[0][name] = spec
except Exception:
pass
paddle.disable_static()
return pruned_input_spec
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册