From 6c5d1ebc2882edc6ddbeb219471be7a14ee8cbbe Mon Sep 17 00:00:00 2001 From: weishengyu Date: Thu, 9 Dec 2021 14:51:40 +0800 Subject: [PATCH] add pruner and quanter for theseus --- ppcls/arch/__init__.py | 13 +++++++++---- ppcls/arch/backbone/base/theseus_layer.py | 2 ++ ppcls/engine/engine.py | 13 +++++-------- ppcls/engine/slim/__init__.py | 4 ++-- ppcls/engine/slim/prune.py | 15 +++++++-------- ppcls/engine/slim/quant.py | 10 +++++----- ppcls/static/program.py | 2 +- 7 files changed, 31 insertions(+), 28 deletions(-) diff --git a/ppcls/arch/__init__.py b/ppcls/arch/__init__.py index 657fa823..d43bc488 100644 --- a/ppcls/arch/__init__.py +++ b/ppcls/arch/__init__.py @@ -26,15 +26,20 @@ from .utils import * from ppcls.arch.backbone.base.theseus_layer import TheseusLayer from ppcls.utils import logger from ppcls.utils.save_load import load_dygraph_pretrain +from ppcls.engine.slim import prune_model, quantize_model + __all__ = ["build_model", "RecModel", "DistillationModel"] def build_model(config): - config = copy.deepcopy(config) - model_type = config.pop("name") + arch_config = copy.deepcopy(config["Arch"]) + model_type = arch_config.pop("name") mod = importlib.import_module(__name__) - arch = getattr(mod, model_type)(**config) + arch = getattr(mod, model_type)(**arch_config) + if isinstance(arch, TheseusLayer): + prune_model(config, arch) + quantize_model(config, arch) return arch @@ -51,7 +56,7 @@ def apply_to_static(config, model): return model -class RecModel(nn.Layer): +class RecModel(TheseusLayer): def __init__(self, **config): super().__init__() backbone_config = config["Backbone"] diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index 64bfed0e..f5b815d1 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -16,6 +16,8 @@ class TheseusLayer(nn.Layer): super(TheseusLayer, self).__init__() self.res_dict = {} self.res_name = self.full_name() + self.pruner = 
None + self.quanter = None # stop doesn't work when stop layer has a parallel branch. def stop_after(self, stop_layer_name: str): diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 83a24a60..22b3e058 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -44,7 +44,6 @@ from ppcls.data import create_operators from ppcls.engine.train import train_epoch from ppcls.engine import evaluation from ppcls.arch.gears.identity_head import IdentityHead -from ppcls.engine.slim import get_pruner, get_quaner class Engine(object): @@ -186,14 +185,12 @@ class Engine(object): self.eval_metric_func = None # build model - self.model = build_model(self.config["Arch"]) + self.model = build_model(self.config) + self.quanted = self.config.get("Slim", {}).get("quant", False) + self.pruned = self.config.get("Slim", {}).get("prune", False) # set @to_static for benchmark, skip this by default. apply_to_static(self.config, self.model) - # for slim - self.pruner = get_pruner(self.config, self.model) - self.quanter = get_quaner(self.config, self.model) - # load_pretrain if self.config["Global"]["pretrained_model"] is not None: if self.config["Global"]["pretrained_model"].startswith("http"): @@ -371,8 +368,8 @@ class Engine(object): model.eval() save_path = os.path.join(self.config["Global"]["save_inference_dir"], "inference") - if self.quanter: - self.quanter.save_quantized_model( + if self.quanted: + model.quanter.save_quantized_model( model.base_model, save_path, input_spec=[ diff --git a/ppcls/engine/slim/__init__.py b/ppcls/engine/slim/__init__.py index bdf067ab..de3d857b 100644 --- a/ppcls/engine/slim/__init__.py +++ b/ppcls/engine/slim/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from ppcls.engine.slim.prune import get_pruner -from ppcls.engine.slim.quant import get_quaner +from ppcls.engine.slim.prune import prune_model +from ppcls.engine.slim.quant import quantize_model diff --git a/ppcls/engine/slim/prune.py b/ppcls/engine/slim/prune.py index fc28452c..c0c9d220 100644 --- a/ppcls/engine/slim/prune.py +++ b/ppcls/engine/slim/prune.py @@ -17,7 +17,7 @@ import paddle from ppcls.utils import logger -def get_pruner(config, model): +def prune_model(config, model): if config.get("Slim", False) and config["Slim"].get("prune", False): import paddleslim prune_method_name = config["Slim"]["prune"]["name"].lower() @@ -25,21 +25,20 @@ def get_pruner(config, model): "fpgm", "l1_norm" ], "The prune methods only support 'fpgm' and 'l1_norm'" if prune_method_name == "fpgm": - pruner = paddleslim.dygraph.FPGMFilterPruner( + model.pruner = paddleslim.dygraph.FPGMFilterPruner( model, [1] + config["Global"]["image_shape"]) else: - pruner = paddleslim.dygraph.L1NormFilterPruner( + model.pruner = paddleslim.dygraph.L1NormFilterPruner( model, [1] + config["Global"]["image_shape"]) # prune model - _prune_model(pruner, config, model) + _prune_model(config, model) else: - pruner = None + model.pruner = None - return pruner -def _prune_model(pruner, config, model): +def _prune_model(config, model): from paddleslim.analysis import dygraph_flops as flops logger.info("FLOPs before pruning: {}GFLOPs".format( flops(model, [1] + config["Global"]["image_shape"]) / 1e9)) @@ -53,7 +52,7 @@ def _prune_model(pruner, config, model): ratios = {} for param in params: ratios[param] = config["Slim"]["prune"]["pruned_ratio"] - plan = pruner.prune_vars(ratios, [0]) + plan = model.pruner.prune_vars(ratios, [0]) logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format( flops(model, [1] + config["Global"]["image_shape"]) / 1e9, diff --git a/ppcls/engine/slim/quant.py b/ppcls/engine/slim/quant.py index a6ef8a53..cfac07b6 100644 --- a/ppcls/engine/slim/quant.py +++ 
b/ppcls/engine/slim/quant.py @@ -40,16 +40,16 @@ QUANT_CONFIG = { } -def get_quaner(config, model): +def quantize_model(config, model): if config.get("Slim", False) and config["Slim"].get("quant", False): from paddleslim.dygraph.quant import QAT assert config["Slim"]["quant"]["name"].lower( ) == 'pact', 'Only PACT quantization method is supported now' QUANT_CONFIG["activation_preprocess_type"] = "PACT" -quanter = QAT(config=QUANT_CONFIG) -quanter.quantize(model) +model.quanter = QAT(config=QUANT_CONFIG) +model.quanter.quantize(model) logger.info("QAT model summary:") paddle.summary(model, (1, 3, 224, 224)) else: -quanter = None -return quanter +model.quanter = None +return model.quanter diff --git a/ppcls/static/program.py b/ppcls/static/program.py index 9075a359..7ecf6ee5 100644 --- a/ppcls/static/program.py +++ b/ppcls/static/program.py @@ -259,7 +259,7 @@ def build(config, # data_format should be assigned in arch-dict input_image_channel = config["Global"]["image_shape"][ 0] # default as [3, 224, 224] - model = build_model(config["Arch"]) + model = build_model(config) out = model(feeds["data"]) # end of build model -- GitLab