提交 eafcc864 编写于 作者: D dongshuilong

add slim support

上级 4452565b
......@@ -18,12 +18,12 @@ Global:
# for paddleslim
Slim:
# for quantalization
quant:
name: pact
# quant:
# name: pact
## for prune
#prune:
# name: fpgm
# prune_ratio: 0.3
prune:
name: fpgm
pruned_ratio: 0.3
# model architecture
Arch:
......@@ -58,7 +58,7 @@ DataLoader:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
cls_label_path: ./dataset/ILSVRC2012/train.txt
transform_ops:
- DecodeImage:
to_rgb: True
......@@ -89,7 +89,7 @@ DataLoader:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
cls_label_path: ./dataset/ILSVRC2012/val.txt
transform_ops:
- DecodeImage:
to_rgb: True
......
......@@ -18,10 +18,14 @@ import os
import sys
import paddle
import paddleslim
from paddle.jit import to_static
from paddleslim.analysis import dygraph_flops as flops
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
from paddleslim.dygraph.quant import QAT
from ppcls.engine.trainer import Trainer
from ppcls.utils import config, logger
from ppcls.utils.save_load import load_dygraph_pretrain
......@@ -53,9 +57,12 @@ quant_config = {
class Trainer_slim(Trainer):
def __init__(self, config, mode="train"):
super().__init__(config, mode)
# self.pact = self.config["Slim"].get("pact", False)
self.pact = True
if self.pact:
pact = self.config["Slim"].get("quant", False)
self.pact = pact.get("name", False) if pact else pact
if self.pact and str(self.pact.lower()) != 'pact':
raise RuntimeError("The quantization only support 'PACT'!")
if pact:
quant_config["activation_preprocess_type"] = "PACT"
self.quanter = QAT(config=quant_config)
self.quanter.quantize(self.model)
......@@ -64,6 +71,31 @@ class Trainer_slim(Trainer):
else:
self.quanter = None
prune_config = self.config["Slim"].get("prune", False)
if prune_config:
if prune_config["name"].lower() not in ["fpgm", "l1_norm"]:
raise RuntimeError(
"The prune methods only support 'fpgm' and 'l1_norm'")
else:
logger.info("FLOPs before pruning: {}GFLOPs".format(
flops(self.model, [1] + self.config["Global"][
"image_shape"]) / 1000000))
self.model.eval()
if prune_config["name"].lower() == "fpgm":
self.model.eval()
self.pruner = paddleslim.dygraph.FPGMFilterPruner(
self.model, [1] + self.config["Global"]["image_shape"])
else:
self.pruner = paddleslim.dygraph.L1NormFilterPruner(
self.model, [1] + self.config["Global"]["image_shape"])
self.prune_model()
else:
self.pruner = None
if self.quanter is None and self.pruner is None:
logger.info("Training without slim")
def train(self):
super().train()
if self.config["Global"].get("save_inference_dir", None):
......@@ -86,17 +118,48 @@ class Trainer_slim(Trainer):
raise RuntimeError(
"The best_model or pretraine_model should exist to generate inference model"
)
save_path = os.path.join(self.config["Global"]["save_inference_dir"],
"inference")
if self.quanter:
self.quanter.save_quantized_model(
self.model,
save_path,
input_spec=[
paddle.static.InputSpec(
shape=[None] + config["Global"]["image_shape"],
dtype='float32')
])
else:
model = to_static(
self.model,
input_spec=[
paddle.static.InputSpec(
shape=[None] + self.config["Global"]["image_shape"],
dtype='float32',
name="image")
])
paddle.jit.save(model, save_path)
def prune_model(self):
params = []
for sublayer in self.model.sublayers():
for param in sublayer.parameters(include_sublayers=False):
if isinstance(sublayer, paddle.nn.Conv2D):
params.append(param.name)
ratios = {}
for param in params:
ratios[param] = self.config["Slim"]["prune"]["pruned_ratio"]
plan = self.pruner.prune_vars(ratios, [0])
logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
flops(self.model, [1] + self.config["Global"]["image_shape"]) /
1000000, plan.pruned_flops))
for param in self.model.parameters():
if "conv2d" in param.name:
logger.info("{}\t{}".format(param.name, param.shape))
assert self.quanter
self.quanter.save_quantized_model(
self.model,
os.path.join(self.config["Global"]["save_inference_dir"],
"inference"),
input_spec=[
paddle.static.InputSpec(
shape=[None] + config["Global"]["image_shape"],
dtype='float32')
])
self.model.train()
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册