diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 53e82792ac2dc6e4558dcda7ba827cc675152283..d8a2803d903feb018dbac9938e2fe3262543e583 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -558,25 +558,27 @@ class Trainer(object): shape=[None, 3, 192, 64], name='crops') }) - # dy2st and save model - if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT': - static_model = paddle.jit.to_static( + + static_model = paddle.jit.to_static( self.model, input_spec=input_spec) - # NOTE: dy2st do not pruned program, but jit.save will prune program - # input spec, prune input spec here and save with pruned input spec - pruned_input_spec = self._prune_input_spec( + # NOTE: dy2st do not pruned program, but jit.save will prune program + # input spec, prune input spec here and save with pruned input spec + pruned_input_spec = self._prune_input_spec( input_spec, static_model.forward.main_program, static_model.forward.outputs) + + # dy2st and save model + if 'slim' not in self.cfg or self.cfg['slim_type'] != 'QAT': paddle.jit.save( static_model, os.path.join(save_dir, 'model'), input_spec=pruned_input_spec) - logger.info("Export model and saved in {}".format(save_dir)) else: self.cfg.slim.save_quantized_model( self.model, os.path.join(save_dir, 'model'), - input_spec=input_spec) + input_spec=pruned_input_spec) + logger.info("Export model and saved in {}".format(save_dir)) def _prune_input_spec(self, input_spec, program, targets): # try to prune static program to figure out pruned input spec