diff --git a/dygraph/ppdet/engine/trainer.py b/dygraph/ppdet/engine/trainer.py index 131c66199adf4255a91b23d9216aaa31afff3b36..c96109a785920d4942c2de7408d891aa942d9210 100644 --- a/dygraph/ppdet/engine/trainer.py +++ b/dygraph/ppdet/engine/trainer.py @@ -50,6 +50,7 @@ class Trainer(object): "mode should be 'train', 'eval' or 'test'" self.mode = mode.lower() self.optimizer = None + self.slim = None # build model self.model = create(cfg.architecture) @@ -58,8 +59,8 @@ class Trainer(object): if 'slim' in cfg and cfg.slim: if self.mode == 'train': self.load_weights(cfg.pretrain_weights, cfg.weight_type) - slim = create(cfg.slim) - slim(self.model) + self.slim = create(cfg.slim) + self.slim(self.model) # build data loader self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())] @@ -385,17 +386,24 @@ class Trainer(object): }] # dy2st and save model - static_model = paddle.jit.to_static(self.model, input_spec=input_spec) - # NOTE: dy2st do not pruned program, but jit.save will prune program - # input spec, prune input spec here and save with pruned input spec - pruned_input_spec = self._prune_input_spec( - input_spec, static_model.forward.main_program, - static_model.forward.outputs) - paddle.jit.save( - static_model, - os.path.join(save_dir, 'model'), - input_spec=pruned_input_spec) - logger.info("Export model and saved in {}".format(save_dir)) + if self.slim is None or self.cfg['slim'] != 'quant': + static_model = paddle.jit.to_static( + self.model, input_spec=input_spec) + # NOTE: dy2st do not pruned program, but jit.save will prune program + # input spec, prune input spec here and save with pruned input spec + pruned_input_spec = self._prune_input_spec( + input_spec, static_model.forward.main_program, + static_model.forward.outputs) + paddle.jit.save( + static_model, + os.path.join(save_dir, 'model'), + input_spec=pruned_input_spec) + logger.info("Export model and saved in {}".format(save_dir)) + else: + self.slim.save_quantized_model( + self.model, + os.path.join(save_dir, 'model'), + input_spec=input_spec) def _prune_input_spec(self, input_spec, program, targets): # try to prune static program to figure out pruned input spec diff --git a/dygraph/ppdet/slim/quant.py b/dygraph/ppdet/slim/quant.py index abf123bd199eaf218569d9ea2e5293abe7e49de9..a1fe126ef2e3c8b614c2171c4f3a43ae709f480f 100644 --- a/dygraph/ppdet/slim/quant.py +++ b/dygraph/ppdet/slim/quant.py @@ -46,3 +46,7 @@ class QAT(object): logger.info(model) return model + + def save_quantized_model(self, layer, path, input_spec=None, **config): + self.quanter.save_quantized_model( + model=layer, path=path, input_spec=input_spec, **config)