Unverified · Commit 51c3a342 authored by Wei Shengyu, committed by GitHub

Merge pull request #1540 from weisy11/refor_quant

Refor quant
@@ -26,15 +26,20 @@ from .utils import *
 from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
 from ppcls.utils import logger
 from ppcls.utils.save_load import load_dygraph_pretrain
+from ppcls.arch.slim import prune_model, quantize_model

 __all__ = ["build_model", "RecModel", "DistillationModel"]


 def build_model(config):
-    config = copy.deepcopy(config)
-    model_type = config.pop("name")
+    arch_config = copy.deepcopy(config["Arch"])
+    model_type = arch_config.pop("name")
     mod = importlib.import_module(__name__)
-    arch = getattr(mod, model_type)(**config)
+    arch = getattr(mod, model_type)(**arch_config)
+    if isinstance(arch, TheseusLayer):
+        prune_model(config, arch)
+        quantize_model(config, arch)
     return arch
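
With this refactor, `build_model` receives the full config rather than `config["Arch"]`, and slimming happens inside the build step. A minimal sketch of the new call pattern, assuming a hypothetical config (the backbone name, class count, and `Slim` values below are illustrative, not taken from this diff):

```python
from ppcls.arch import build_model

# Illustrative config: "Arch" selects the backbone; the optional "Slim"
# section is consumed by prune_model / quantize_model inside build_model.
config = {
    "Arch": {"name": "MobileNetV3_large_x1_0", "class_num": 1000},
    "Global": {"image_shape": [3, 224, 224]},
    "Slim": {"quant": {"name": "pact"}},
}

model = build_model(config)        # quantization is applied here, in place
print(model.quanter is not None)   # True: the QAT wrapper was attached
```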
@@ -51,7 +56,7 @@ def apply_to_static(config, model):
     return model


-class RecModel(nn.Layer):
+class RecModel(TheseusLayer):
     def __init__(self, **config):
         super().__init__()
         backbone_config = config["Backbone"]
...
@@ -16,6 +16,8 @@ class TheseusLayer(nn.Layer):
         super(TheseusLayer, self).__init__()
         self.res_dict = {}
         self.res_name = self.full_name()
+        self.pruner = None
+        self.quanter = None

     # stop doesn't work when stop layer has a parallel branch.
     def stop_after(self, stop_layer_name: str):
...
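
Defaulting `pruner` and `quanter` to `None` on every `TheseusLayer` means call sites (such as the export path further down) can test the attributes directly instead of guarding with `hasattr()`. A small sketch of the invariant this hunk establishes:

```python
def slim_applied(layer):
    # TheseusLayer.__init__ now guarantees both attributes exist
    # (None until prune_model / quantize_model overwrite them),
    # so no hasattr() guard is needed.
    return layer.pruner is not None or layer.quanter is not None
```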
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ppcls.engine.slim.prune import get_pruner
-from ppcls.engine.slim.quant import get_quaner
+from ppcls.arch.slim.prune import prune_model
+from ppcls.arch.slim.quant import quantize_model
@@ -17,7 +17,7 @@ import paddle
 from ppcls.utils import logger


-def get_pruner(config, model):
+def prune_model(config, model):
     if config.get("Slim", False) and config["Slim"].get("prune", False):
         import paddleslim
         prune_method_name = config["Slim"]["prune"]["name"].lower()
@@ -25,21 +25,20 @@ def get_pruner(config, model):
             "fpgm", "l1_norm"
         ], "The prune methods only support 'fpgm' and 'l1_norm'"
         if prune_method_name == "fpgm":
-            pruner = paddleslim.dygraph.FPGMFilterPruner(
+            model.pruner = paddleslim.dygraph.FPGMFilterPruner(
                 model, [1] + config["Global"]["image_shape"])
         else:
-            pruner = paddleslim.dygraph.L1NormFilterPruner(
+            model.pruner = paddleslim.dygraph.L1NormFilterPruner(
                 model, [1] + config["Global"]["image_shape"])

         # prune model
-        _prune_model(pruner, config, model)
+        _prune_model(config, model)
     else:
-        pruner = None
-    return pruner
+        model.pruner = None


-def _prune_model(pruner, config, model):
+def _prune_model(config, model):
     from paddleslim.analysis import dygraph_flops as flops
     logger.info("FLOPs before pruning: {}GFLOPs".format(
         flops(model, [1] + config["Global"]["image_shape"]) / 1e9))
@@ -53,7 +52,7 @@ def _prune_model(pruner, config, model):
     ratios = {}
     for param in params:
         ratios[param] = config["Slim"]["prune"]["pruned_ratio"]
-    plan = pruner.prune_vars(ratios, [0])
+    plan = model.pruner.prune_vars(ratios, [0])
     logger.info("FLOPs after pruning: {}GFLOPs; pruned ratio: {}".format(
         flops(model, [1] + config["Global"]["image_shape"]) / 1e9,
...
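
`prune_model` now works in place: the PaddleSlim pruner is stored as `model.pruner` and pruning runs immediately, so there is no return value to keep. A hedged sketch of a config that would take the `fpgm` branch above (the `pruned_ratio` value is illustrative):

```python
from ppcls.arch.slim import prune_model

config = {
    "Global": {"image_shape": [3, 224, 224]},
    "Slim": {"prune": {"name": "fpgm", "pruned_ratio": 0.3}},
}

# model: any TheseusLayer-based network, e.g. the result of build_model(...)
prune_model(config, model)
assert model.pruner is not None   # pruner attached, model already pruned
```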
@@ -40,16 +40,16 @@ QUANT_CONFIG = {
 }


-def get_quaner(config, model):
+def quantize_model(config, model):
     if config.get("Slim", False) and config["Slim"].get("quant", False):
         from paddleslim.dygraph.quant import QAT
         assert config["Slim"]["quant"]["name"].lower(
         ) == 'pact', 'Only PACT quantization method is supported now'
         QUANT_CONFIG["activation_preprocess_type"] = "PACT"
-        quanter = QAT(config=QUANT_CONFIG)
-        quanter.quantize(model)
+        model.quanter = QAT(config=QUANT_CONFIG)
+        model.quanter.quantize(model)
         logger.info("QAT model summary:")
         paddle.summary(model, (1, 3, 224, 224))
     else:
-        quanter = None
-    return quanter
+        model.quanter = None
+    return
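
`quantize_model` follows the same in-place pattern: with `Slim.quant` configured, the QAT wrapper is attached as `model.quanter` and applied immediately, and the function no longer returns it. A sketch under the same assumptions as above:

```python
from ppcls.arch.slim import quantize_model

# 'pact' is the only method accepted by the assert above.
config = {"Slim": {"quant": {"name": "pact"}}}

# model: any TheseusLayer-based network that accepts 1x3x224x224 input
# (quantize_model runs paddle.summary with that shape).
quantize_model(config, model)
if model.quanter is not None:
    # the same object is reused by save_quantized_model at export time
    print("model prepared for quantization-aware training")
```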
@@ -29,7 +29,7 @@ from ppcls.utils import logger
 from ppcls.utils.logger import init_logger
 from ppcls.utils.config import print_config
 from ppcls.data import build_dataloader
-from ppcls.arch import build_model, RecModel, DistillationModel
+from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
 from ppcls.arch import apply_to_static
 from ppcls.loss import build_loss
 from ppcls.metric import build_metrics
@@ -44,7 +44,6 @@ from ppcls.data import create_operators
 from ppcls.engine.train import train_epoch
 from ppcls.engine import evaluation
 from ppcls.arch.gears.identity_head import IdentityHead
-from ppcls.engine.slim import get_pruner, get_quaner


 class Engine(object):
@@ -186,14 +185,10 @@ class Engine(object):
             self.eval_metric_func = None

         # build model
-        self.model = build_model(self.config["Arch"])
+        self.model = build_model(self.config)
         # set @to_static for benchmark, skip this by default.
         apply_to_static(self.config, self.model)
-        # for slim
-        self.pruner = get_pruner(self.config, self.model)
-        self.quanter = get_quaner(self.config, self.model)
         # load_pretrain
         if self.config["Global"]["pretrained_model"] is not None:
             if self.config["Global"]["pretrained_model"].startswith("http"):
@@ -371,8 +366,8 @@ class Engine(object):
         model.eval()
         save_path = os.path.join(self.config["Global"]["save_inference_dir"],
                                  "inference")
-        if self.quanter:
-            self.quanter.save_quantized_model(
+        if model.quanter:
+            model.quanter.save_quantized_model(
                 model.base_model,
                 save_path,
                 input_spec=[
@@ -391,7 +386,7 @@
             paddle.jit.save(model, save_path)


-class ExportModel(nn.Layer):
+class ExportModel(TheseusLayer):
     """
     ExportModel: add softmax onto the model
     """
...
@@ -259,7 +259,7 @@ def build(config,
     # data_format should be assigned in arch-dict
     input_image_channel = config["Global"]["image_shape"][
         0]  # default as [3, 224, 224]
-    model = build_model(config["Arch"])
+    model = build_model(config)
     out = model(feeds["data"])
     # end of build model
...