diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index d43bc488059681881ea0f28333668d430f0a2f6c..f2c2e412c757c269d1475734e7f378312f4e7775 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -26,7 +26,7 @@ from .utils import * from ppcls.arch.backbone.base.theseus_layer import TheseusLayer from ppcls.utils import logger from ppcls.utils.save_load import load_dygraph_pretrain -from ppcls.engine.slim import prune_model, quantize_model +from ppcls.arch.slim import prune_model, quantize_model __all__ = ["build_model", "RecModel", "DistillationModel"] diff --git a/ppcls/engine/slim/__init__.py b/ppcls/arch/slim/__init__.py similarity index 86% rename from ppcls/engine/slim/__init__.py rename to ppcls/arch/slim/__init__.py index de3d857b213dcb6734916ba79c2d08bd64d76fa7..e2842472244a719619ffaf45d3759166fac63893 100644 --- a/ppcls/engine/slim/__init__.py +++ b/ppcls/arch/slim/__init__.py @@ -12,5 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ppcls.engine.slim.prune import prune_model -from ppcls.engine.slim.quant import quantize_model +from ppcls.arch.slim.prune import prune_model diff --git a/ppcls/engine/slim/prune.py b/ppcls/arch/slim/prune.py similarity index 100% rename from ppcls/engine/slim/prune.py rename to ppcls/arch/slim/prune.py diff --git a/ppcls/engine/slim/quant.py b/ppcls/arch/slim/quant.py similarity index 100% rename from ppcls/engine/slim/quant.py rename to ppcls/arch/slim/quant.py diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 22b3e058972b75d271ca3ccd8e2f07155a0754f7..c53594889ad2468dca4f823c99599f9f2cd34317 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -186,8 +186,6 @@ class Engine(object): # build model self.model = build_model(self.config) - self.quanted = self.config.get("Slim", {}).get("quant", False) - self.pruned = self.config.get("Slim", {}).get("prune", False) # set @to_static for benchmark, skip this by default. apply_to_static(self.config, self.model) @@ -368,7 +366,7 @@ class Engine(object): model.eval() save_path = os.path.join(self.config["Global"]["save_inference_dir"], "inference") - if self.quanted: + if model.quanter: model.quanter.save_quantized_model( model.base_model, save_path,