未验证 提交 dfe5bb3e 编写于 作者: C ceci3 提交者: GitHub

save model with clip_extra=False (#1530)

上级 717d7262
...@@ -265,7 +265,12 @@ class AutoCompression: ...@@ -265,7 +265,12 @@ class AutoCompression:
save_path = os.path.join(save_path, "infered_shape") save_path = os.path.join(save_path, "infered_shape")
os.makedirs(save_path) os.makedirs(save_path)
paddle.static.save_inference_model( paddle.static.save_inference_model(
save_path, feed_vars, fetch_targets, exe, program=inference_program) save_path,
feed_vars,
fetch_targets,
exe,
program=inference_program,
clip_extra=False)
_logger.info(f"Saved model infered shape to {save_path}") _logger.info(f"Saved model infered shape to {save_path}")
@property @property
...@@ -901,7 +906,8 @@ class AutoCompression: ...@@ -901,7 +906,8 @@ class AutoCompression:
feed_vars=feed_vars, feed_vars=feed_vars,
fetch_vars=test_program_info.fetch_targets, fetch_vars=test_program_info.fetch_targets,
executor=self._exe, executor=self._exe,
program=test_program) program=test_program,
clip_extra=False)
def export_onnx(self, def export_onnx(self,
model_name='quant_model.onnx', model_name='quant_model.onnx',
......
...@@ -169,5 +169,6 @@ def post_quant_fake(executor, ...@@ -169,5 +169,6 @@ def post_quant_fake(executor,
feed_vars=feed_vars, feed_vars=feed_vars,
fetch_vars=_fetch_list, fetch_vars=_fetch_list,
executor=executor, executor=executor,
program=_program) program=_program,
clip_extra=False)
print("The quantized model is saved in: " + save_model_path) print("The quantized model is saved in: " + save_model_path)
...@@ -95,7 +95,8 @@ def get_sparse_model(executor, places, model_file, param_file, ratio, ...@@ -95,7 +95,8 @@ def get_sparse_model(executor, places, model_file, param_file, ratio,
feed_vars=feed_vars, feed_vars=feed_vars,
fetch_vars=fetch_targets, fetch_vars=fetch_targets,
executor=executor, executor=executor,
program=inference_program) program=inference_program,
clip_extra=False)
print("The pruned model is saved in: ", save_path) print("The pruned model is saved in: ", save_path)
...@@ -170,4 +171,5 @@ def get_prune_model(executor, places, model_file, param_file, ratio, save_path): ...@@ -170,4 +171,5 @@ def get_prune_model(executor, places, model_file, param_file, ratio, save_path):
feed_vars=feed_vars, feed_vars=feed_vars,
fetch_vars=fetch_targets, fetch_vars=fetch_targets,
executor=executor, executor=executor,
program=main_program) program=main_program,
clip_extra=False)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册