From 230c8c1efa92989a2a17b9f46c1129862ffeac5c Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Mon, 18 Jan 2021 19:06:39 +0800 Subject: [PATCH] fix feed var need prune when export (#2054) * fix feed var need prune when export * fix format * fix format --- dygraph/ppdet/engine/trainer.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/dygraph/ppdet/engine/trainer.py b/dygraph/ppdet/engine/trainer.py index b32d74a8c..8eb5d0e39 100644 --- a/dygraph/ppdet/engine/trainer.py +++ b/dygraph/ppdet/engine/trainer.py @@ -327,5 +327,30 @@ class Trainer(object): # dy2st and save model static_model = paddle.jit.to_static(self.model, input_spec=input_spec) - paddle.jit.save(static_model, os.path.join(save_dir, 'model')) + # NOTE: dy2st do not pruned program, but jit.save will prune program + # input spec, prune input spec here and save with pruned input spec + pruned_input_spec = self._prune_input_spec( + input_spec, static_model.forward.main_program, + static_model.forward.outputs) + paddle.jit.save( + static_model, + os.path.join(save_dir, 'model'), + input_spec=pruned_input_spec) logger.info("Export model and saved in {}".format(save_dir)) + + def _prune_input_spec(self, input_spec, program, targets): + # try to prune static program to figure out pruned input spec + # so we perform following operations in static mode + paddle.enable_static() + pruned_input_spec = [{}] + program = program.clone() + program = program._prune(targets=targets) + global_block = program.global_block() + for name, spec in input_spec[0].items(): + try: + v = global_block.var(name) + pruned_input_spec[0][name] = spec + except Exception: + pass + paddle.disable_static() + return pruned_input_spec -- GitLab