diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index d0f2d64721f05a29089cab29db4252973ffe04e2..9ff807f7bfac68f9548a397a5705f7822de5572c 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -183,6 +183,11 @@ class Engine(object): self.model = build_model(self.config["Arch"]) # set @to_static for benchmark, skip this by default. apply_to_static(self.config, self.model) + + # for slim + self.pruner = get_pruner(self.config, self.model) + self.quanter = get_quaner(self.config, self.model) + # load_pretrain if self.config["Global"]["pretrained_model"] is not None: if self.config["Global"]["pretrained_model"].startswith("http"): @@ -192,10 +197,6 @@ class Engine(object): load_dygraph_pretrain( self.model, self.config["Global"]["pretrained_model"]) - # for slim - self.pruner = get_pruner(self.config, self.model) - self.quanter = get_quaner(self.config, self.model) - # build optimizer if self.mode == 'train': self.optimizer, self.lr_sch = build_optimizer(