提交 eafcc864 编写于 作者: D dongshuilong

add slim support

上级 4452565b
...@@ -18,12 +18,12 @@ Global: ...@@ -18,12 +18,12 @@ Global:
# for paddleslim # for paddleslim
Slim: Slim:
 # for quantization # for quantization
quant: # quant:
name: pact # name: pact
## for prune ## for prune
#prune: prune:
# name: fpgm name: fpgm
# prune_ratio: 0.3 pruned_ratio: 0.3
# model architecture # model architecture
Arch: Arch:
...@@ -58,7 +58,7 @@ DataLoader: ...@@ -58,7 +58,7 @@ DataLoader:
dataset: dataset:
name: ImageNetDataset name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/ image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt cls_label_path: ./dataset/ILSVRC2012/train.txt
transform_ops: transform_ops:
- DecodeImage: - DecodeImage:
to_rgb: True to_rgb: True
...@@ -89,7 +89,7 @@ DataLoader: ...@@ -89,7 +89,7 @@ DataLoader:
dataset: dataset:
name: ImageNetDataset name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/ image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt cls_label_path: ./dataset/ILSVRC2012/val.txt
transform_ops: transform_ops:
- DecodeImage: - DecodeImage:
to_rgb: True to_rgb: True
......
...@@ -18,10 +18,14 @@ import os ...@@ -18,10 +18,14 @@ import os
import sys import sys
import paddle import paddle
import paddleslim
from paddle.jit import to_static
from paddleslim.analysis import dygraph_flops as flops
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
from paddleslim.dygraph.quant import QAT from paddleslim.dygraph.quant import QAT
from ppcls.engine.trainer import Trainer from ppcls.engine.trainer import Trainer
from ppcls.utils import config, logger from ppcls.utils import config, logger
from ppcls.utils.save_load import load_dygraph_pretrain from ppcls.utils.save_load import load_dygraph_pretrain
...@@ -53,9 +57,12 @@ quant_config = { ...@@ -53,9 +57,12 @@ quant_config = {
class Trainer_slim(Trainer): class Trainer_slim(Trainer):
def __init__(self, config, mode="train"): def __init__(self, config, mode="train"):
super().__init__(config, mode) super().__init__(config, mode)
# self.pact = self.config["Slim"].get("pact", False) pact = self.config["Slim"].get("quant", False)
self.pact = True self.pact = pact.get("name", False) if pact else pact
if self.pact:
if self.pact and str(self.pact.lower()) != 'pact':
raise RuntimeError("The quantization only support 'PACT'!")
if pact:
quant_config["activation_preprocess_type"] = "PACT" quant_config["activation_preprocess_type"] = "PACT"
self.quanter = QAT(config=quant_config) self.quanter = QAT(config=quant_config)
self.quanter.quantize(self.model) self.quanter.quantize(self.model)
...@@ -64,6 +71,31 @@ class Trainer_slim(Trainer): ...@@ -64,6 +71,31 @@ class Trainer_slim(Trainer):
else: else:
self.quanter = None self.quanter = None
prune_config = self.config["Slim"].get("prune", False)
if prune_config:
if prune_config["name"].lower() not in ["fpgm", "l1_norm"]:
raise RuntimeError(
"The prune methods only support 'fpgm' and 'l1_norm'")
else:
logger.info("FLOPs before pruning: {}GFLOPs".format(
flops(self.model, [1] + self.config["Global"][
"image_shape"]) / 1000000))
self.model.eval()
if prune_config["name"].lower() == "fpgm":
self.model.eval()
self.pruner = paddleslim.dygraph.FPGMFilterPruner(
self.model, [1] + self.config["Global"]["image_shape"])
else:
self.pruner = paddleslim.dygraph.L1NormFilterPruner(
self.model, [1] + self.config["Global"]["image_shape"])
self.prune_model()
else:
self.pruner = None
if self.quanter is None and self.pruner is None:
logger.info("Training without slim")
def train(self): def train(self):
super().train() super().train()
if self.config["Global"].get("save_inference_dir", None): if self.config["Global"].get("save_inference_dir", None):
...@@ -86,17 +118,48 @@ class Trainer_slim(Trainer): ...@@ -86,17 +118,48 @@ class Trainer_slim(Trainer):
raise RuntimeError( raise RuntimeError(
"The best_model or pretraine_model should exist to generate inference model" "The best_model or pretraine_model should exist to generate inference model"
) )
save_path = os.path.join(self.config["Global"]["save_inference_dir"],
"inference")
if self.quanter:
self.quanter.save_quantized_model(
self.model,
save_path,
input_spec=[
paddle.static.InputSpec(
shape=[None] + config["Global"]["image_shape"],
dtype='float32')
])
else:
model = to_static(
self.model,
input_spec=[
paddle.static.InputSpec(
shape=[None] + self.config["Global"]["image_shape"],
dtype='float32',
name="image")
])
paddle.jit.save(model, save_path)
def prune_model(self):
    """Prune the Conv2D filters of ``self.model`` in place.

    Collects the parameter names of every ``paddle.nn.Conv2D`` sublayer,
    assigns each the uniform ratio from ``Slim.prune.pruned_ratio`` in the
    config, and applies ``self.pruner.prune_vars`` along axis 0 (the output
    channel axis). Logs the post-pruning FLOPs and achieved pruned ratio,
    logs the resulting shape of every conv parameter, then switches the
    model back to train mode (the pruner was constructed with the model in
    eval mode).
    """
    # Only Conv2D filters are prunable; check the layer type once per
    # sublayer instead of once per parameter of every sublayer.
    conv_param_names = [
        param.name
        for sublayer in self.model.sublayers()
        if isinstance(sublayer, paddle.nn.Conv2D)
        for param in sublayer.parameters(include_sublayers=False)
    ]

    # One uniform pruning ratio for all conv parameters, taken from config.
    pruned_ratio = self.config["Slim"]["prune"]["pruned_ratio"]
    ratios = {name: pruned_ratio for name in conv_param_names}

    # Prune along axis 0 (output channels) and report the effect.
    # NOTE(review): divisor 1000000 with a "GFLOPs" label looks like a unit
    # mismatch (1e6 is MFLOPs) — confirm what unit paddleslim's flops()
    # returns before trusting the logged number.
    plan = self.pruner.prune_vars(ratios, [0])
    logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
        flops(self.model, [1] + self.config["Global"]["image_shape"]) /
        1000000, plan.pruned_flops))

    # Show the post-pruning shape of every conv parameter.
    for param in self.model.parameters():
        if "conv2d" in param.name:
            logger.info("{}\t{}".format(param.name, param.shape))

    # Restore train mode after the eval-mode pruning analysis.
    self.model.train()
self.quanter.save_quantized_model(
self.model,
os.path.join(self.config["Global"]["save_inference_dir"],
"inference"),
input_spec=[
paddle.static.InputSpec(
shape=[None] + config["Global"]["image_shape"],
dtype='float32')
])
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册