From 46f2957bfd165cddd6c6da98d554a8ede274b0f3 Mon Sep 17 00:00:00 2001
From: ceci3
Date: Tue, 22 Nov 2022 11:23:08 +0800
Subject: [PATCH] save model with clip_extra=False (#1532)

---
 paddleslim/auto_compression/compressor.py        | 10 ++++++++--
 paddleslim/auto_compression/utils/fake_ptq.py    |  3 ++-
 paddleslim/auto_compression/utils/prune_model.py |  6 ++++--
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py
index a74105a2..916074ed 100644
--- a/paddleslim/auto_compression/compressor.py
+++ b/paddleslim/auto_compression/compressor.py
@@ -263,7 +263,12 @@ class AutoCompression:
         save_path = os.path.join(save_path, "infered_shape")
         os.makedirs(save_path)
         paddle.static.save_inference_model(
-            save_path, feed_vars, fetch_targets, exe, program=inference_program)
+            save_path,
+            feed_vars,
+            fetch_targets,
+            exe,
+            program=inference_program,
+            clip_extra=False)
         _logger.info(f"Saved model infered shape to {save_path}")
 
     @property
@@ -843,7 +848,8 @@ class AutoCompression:
             feed_vars=feed_vars,
             fetch_vars=test_program_info.fetch_targets,
             executor=self._exe,
-            program=test_program)
+            program=test_program,
+            clip_extra=False)
 
     def export_onnx(self,
                     model_name='quant_model.onnx',
diff --git a/paddleslim/auto_compression/utils/fake_ptq.py b/paddleslim/auto_compression/utils/fake_ptq.py
index 91cccfc2..bce49b4f 100644
--- a/paddleslim/auto_compression/utils/fake_ptq.py
+++ b/paddleslim/auto_compression/utils/fake_ptq.py
@@ -169,5 +169,6 @@ def post_quant_fake(executor,
         feed_vars=feed_vars,
         fetch_vars=_fetch_list,
         executor=executor,
-        program=_program)
+        program=_program,
+        clip_extra=False)
     print("The quantized model is saved in: " + save_model_path)
diff --git a/paddleslim/auto_compression/utils/prune_model.py b/paddleslim/auto_compression/utils/prune_model.py
index c0da14ca..a784aa11 100644
--- a/paddleslim/auto_compression/utils/prune_model.py
+++ b/paddleslim/auto_compression/utils/prune_model.py
@@ -95,7 +95,8 @@ def get_sparse_model(executor, places, model_file, param_file, ratio,
         feed_vars=feed_vars,
         fetch_vars=fetch_targets,
         executor=executor,
-        program=inference_program)
+        program=inference_program,
+        clip_extra=False)
     print("The pruned model is saved in: ", save_path)
 
 
@@ -170,4 +171,5 @@ def get_prune_model(executor, places, model_file, param_file, ratio, save_path):
         feed_vars=feed_vars,
         fetch_vars=fetch_targets,
         executor=executor,
-        program=main_program)
+        program=main_program,
+        clip_extra=False)
--
GitLab
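
For reference, a minimal standalone sketch of the call pattern this patch
applies throughout: passing clip_extra=False to
paddle.static.save_inference_model. The toy network, variable names, and the
/tmp/infer_model path prefix are hypothetical, chosen only to illustrate the
API; the assumption here is that clip_extra=False keeps the extra (non-proto)
op attributes, such as metadata attached by quantization passes, which would
otherwise be clipped from the saved inference program.

    import paddle

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        # Hypothetical one-layer network, just to have feed/fetch vars.
        x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
        out = paddle.static.nn.fc(x, size=4)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_prog)

    # clip_extra=False keeps extra op attributes in the saved program
    # instead of clipping them (recent Paddle releases clip by default).
    paddle.static.save_inference_model(
        '/tmp/infer_model/model',  # hypothetical path prefix
        [x],
        [out],
        exe,
        program=main_prog,
        clip_extra=False)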