提交 6c5d1ebc 编写于 作者: W weishengyu

add pruner and quanter for theseus

上级 0c8a082d
......@@ -26,15 +26,20 @@ from .utils import *
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils import logger
from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.engine.slim import prune_model, quantize_model
__all__ = ["build_model", "RecModel", "DistillationModel"]
def build_model(config):
config = copy.deepcopy(config)
model_type = config.pop("name")
arch_config = copy.deepcopy(config["Arch"])
model_type = arch_config.pop("name")
mod = importlib.import_module(__name__)
arch = getattr(mod, model_type)(**config)
arch = getattr(mod, model_type)(**arch_config)
if isinstance(arch, TheseusLayer):
prune_model(config, arch)
quantize_model(config, arch)
return arch
......@@ -51,7 +56,7 @@ def apply_to_static(config, model):
return model
class RecModel(nn.Layer):
class RecModel(TheseusLayer):
def __init__(self, **config):
super().__init__()
backbone_config = config["Backbone"]
......
......@@ -16,6 +16,8 @@ class TheseusLayer(nn.Layer):
super(TheseusLayer, self).__init__()
self.res_dict = {}
self.res_name = self.full_name()
self.pruner = None
self.quanter = None
# stop doesn't work when stop layer has a parallel branch.
def stop_after(self, stop_layer_name: str):
......
......@@ -44,7 +44,6 @@ from ppcls.data import create_operators
from ppcls.engine.train import train_epoch
from ppcls.engine import evaluation
from ppcls.arch.gears.identity_head import IdentityHead
from ppcls.engine.slim import get_pruner, get_quaner
class Engine(object):
......@@ -186,14 +185,12 @@ class Engine(object):
self.eval_metric_func = None
# build model
self.model = build_model(self.config["Arch"])
self.model = build_model(self.config)
self.quanted = self.config.get("Slim", {}).get("quant", False)
self.pruned = self.config.get("Slim", {}).get("prune", False)
# set @to_static for benchmark, skip this by default.
apply_to_static(self.config, self.model)
# for slim
self.pruner = get_pruner(self.config, self.model)
self.quanter = get_quaner(self.config, self.model)
# load_pretrain
if self.config["Global"]["pretrained_model"] is not None:
if self.config["Global"]["pretrained_model"].startswith("http"):
......@@ -371,8 +368,8 @@ class Engine(object):
model.eval()
save_path = os.path.join(self.config["Global"]["save_inference_dir"],
"inference")
if self.quanter:
self.quanter.save_quantized_model(
if self.quanted:
model.quanter.save_quantized_model(
model.base_model,
save_path,
input_spec=[
......
......@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ppcls.engine.slim.prune import get_pruner
from ppcls.engine.slim.quant import get_quaner
from ppcls.engine.slim.prune import prune_model
from ppcls.engine.slim.quant import quantize_model
......@@ -17,7 +17,7 @@ import paddle
from ppcls.utils import logger
def get_pruner(config, model):
def prune_model(config, model):
if config.get("Slim", False) and config["Slim"].get("prune", False):
import paddleslim
prune_method_name = config["Slim"]["prune"]["name"].lower()
......@@ -25,21 +25,20 @@ def get_pruner(config, model):
"fpgm", "l1_norm"
], "The prune methods only support 'fpgm' and 'l1_norm'"
if prune_method_name == "fpgm":
pruner = paddleslim.dygraph.FPGMFilterPruner(
model.pruner = paddleslim.dygraph.FPGMFilterPruner(
model, [1] + config["Global"]["image_shape"])
else:
pruner = paddleslim.dygraph.L1NormFilterPruner(
model.pruner = paddleslim.dygraph.L1NormFilterPruner(
model, [1] + config["Global"]["image_shape"])
# prune model
_prune_model(pruner, config, model)
_prune_model(config, model)
else:
pruner = None
model.pruner = None
return pruner
def _prune_model(pruner, config, model):
def _prune_model(config, model):
from paddleslim.analysis import dygraph_flops as flops
logger.info("FLOPs before pruning: {}GFLOPs".format(
flops(model, [1] + config["Global"]["image_shape"]) / 1e9))
......@@ -53,7 +52,7 @@ def _prune_model(pruner, config, model):
ratios = {}
for param in params:
ratios[param] = config["Slim"]["prune"]["pruned_ratio"]
plan = pruner.prune_vars(ratios, [0])
plan = model.pruner.prune_vars(ratios, [0])
logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
flops(model, [1] + config["Global"]["image_shape"]) / 1e9,
......
......@@ -40,16 +40,16 @@ QUANT_CONFIG = {
}
def get_quaner(config, model):
def quantize_model(config, model):
if config.get("Slim", False) and config["Slim"].get("quant", False):
from paddleslim.dygraph.quant import QAT
assert config["Slim"]["quant"]["name"].lower(
) == 'pact', 'Only PACT quantization method is supported now'
QUANT_CONFIG["activation_preprocess_type"] = "PACT"
quanter = QAT(config=QUANT_CONFIG)
quanter.quantize(model)
model.quanted = QAT(config=QUANT_CONFIG)
model.quanted.quantize_model(model)
logger.info("QAT model summary:")
paddle.summary(model, (1, 3, 224, 224))
else:
quanter = None
return quanter
model.quanted = None
return model.quanted
......@@ -259,7 +259,7 @@ def build(config,
# data_format should be assigned in arch-dict
input_image_channel = config["Global"]["image_shape"][
0] # default as [3, 224, 224]
model = build_model(config["Arch"])
model = build_model(config)
out = model(feeds["data"])
# end of build model
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册