diff --git a/demo/prune/train.py b/demo/prune/train.py index 23379a326f6c6a93ca4e45e58de6e30263ba7a34..1a25f3f1e4dcce30f31797b3f1acb1276a744581 100644 --- a/demo/prune/train.py +++ b/demo/prune/train.py @@ -8,7 +8,7 @@ import math import time import numpy as np import paddle.fluid as fluid -from paddleslim.prune import Pruner +from paddleslim.prune import Pruner, save_model from paddleslim.common import get_logger from paddleslim.analysis import flops sys.path.append(sys.path[0] + "/../") @@ -223,7 +223,7 @@ def compress(args): train(i, pruned_program) if i % args.test_period == 0: test(i, pruned_val_program) - save_model(pruned_val_program, + save_model(exe, pruned_val_program, os.path.join(args.model_path, str(i))) diff --git a/paddleslim/prune/__init__.py b/paddleslim/prune/__init__.py index 1520ba0b04d98aa44d76820bf74c6491785dd569..c46fd75dd3220abffcaabcadc78b271e48cb5489 100644 --- a/paddleslim/prune/__init__.py +++ b/paddleslim/prune/__init__.py @@ -23,8 +23,8 @@ from .sensitive import * from ..prune import sensitive from .prune_walker import * from ..prune import prune_walker -from io import * -from ..prune import io +from .prune_io import * +from ..prune import prune_io __all__ = [] @@ -33,4 +33,4 @@ __all__ += auto_pruner.__all__ __all__ += sensitive_pruner.__all__ __all__ += sensitive.__all__ __all__ += prune_walker.__all__ -__all__ += io.__all__ +__all__ += prune_io.__all__ diff --git a/paddleslim/prune/io.py b/paddleslim/prune/prune_io.py similarity index 100% rename from paddleslim/prune/io.py rename to paddleslim/prune/prune_io.py