Unverified commit 6dac07f6, authored by littletomatodonkey, committed by GitHub

fix quant logic (#1941)

* fix quant logic

* add support for trt+int8 inference

* add support for swin quant

* fix swin and quant

* fix assert info

* fix assert info

* fix log
Parent 821e5509
@@ -49,10 +49,15 @@ class ClsPredictor(Predictor):
         pid = os.getpid()
         size = config["PreProcess"]["transform_ops"][1]["CropImage"][
             "size"]
+        if config["Global"].get("use_int8", False):
+            precision = "int8"
+        elif config["Global"].get("use_fp16", False):
+            precision = "fp16"
+        else:
+            precision = "fp32"
         self.auto_logger = auto_log.AutoLogger(
             model_name=config["Global"].get("model_name", "cls"),
-            model_precision='fp16'
-            if config["Global"]["use_fp16"] else 'fp32',
+            model_precision=precision,
             batch_size=config["Global"].get("batch_size", 1),
             data_shape=[3, size, size],
             save_path=config["Global"].get("save_log_path",
......
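The change above replaces the inline fp16/fp32 ternary with a three-way precision string, so the auto_log report can also say "int8". A minimal sketch of the selection order, assuming only the config layout shown in the hunk (the helper name is illustrative, not part of the repo):

    def select_precision(global_config):
        # int8 wins over fp16; with neither flag set the report says fp32.
        if global_config.get("use_int8", False):
            return "int8"
        elif global_config.get("use_fp16", False):
            return "fp16"
        return "fp32"

    assert select_precision({"use_int8": True}) == "int8"
    assert select_precision({"use_fp16": True}) == "fp16"
    assert select_precision({}) == "fp32"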
@@ -42,8 +42,22 @@ class Predictor(object):
     def create_paddle_predictor(self, args, inference_model_dir=None):
         if inference_model_dir is None:
             inference_model_dir = args.inference_model_dir
-        params_file = os.path.join(inference_model_dir, "inference.pdiparams")
-        model_file = os.path.join(inference_model_dir, "inference.pdmodel")
+        if "inference_int8.pdiparams" in os.listdir(inference_model_dir):
+            params_file = os.path.join(inference_model_dir,
+                                       "inference_int8.pdiparams")
+            model_file = os.path.join(inference_model_dir,
+                                      "inference_int8.pdmodel")
+            assert args.get(
+                "use_fp16", False
+            ) is False, "fp16 mode is not supported for int8 model inference, please set use_fp16 as False during inference."
+        else:
+            params_file = os.path.join(inference_model_dir,
+                                       "inference.pdiparams")
+            model_file = os.path.join(inference_model_dir, "inference.pdmodel")
+            assert args.get(
+                "use_int8", False
+            ) is False, "int8 mode is not supported for fp32 model inference, please set use_int8 as False during inference."
         config = Config(model_file, params_file)

         if args.use_gpu:
@@ -63,12 +77,18 @@ class Predictor(object):
         config.disable_glog_info()
         config.switch_ir_optim(args.ir_optim)  # default true
         if args.use_tensorrt:
+            precision = Config.Precision.Float32
+            if args.get("use_int8", False):
+                precision = Config.Precision.Int8
+            elif args.get("use_fp16", False):
+                precision = Config.Precision.Half
             config.enable_tensorrt_engine(
-                precision_mode=Config.Precision.Half
-                if args.use_fp16 else Config.Precision.Float32,
+                precision_mode=precision,
                 max_batch_size=args.batch_size,
                 workspace_size=1 << 30,
-                min_subgraph_size=30)
+                min_subgraph_size=30,
+                use_calib_mode=False)

         config.enable_memory_optim()
         # use zero copy
......
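Two decisions are made in this file. First, quantized models are detected by filename: if an inference_int8.* pair exists in the model directory it is loaded (and use_fp16 is rejected), otherwise the plain inference.* pair is loaded (and use_int8 is rejected). Second, the TensorRT precision is resolved with the same int8-over-fp16 priority, and use_calib_mode=False is passed since a QAT-exported model already carries its quantization scales, so TensorRT needs no offline calibration pass. A condensed sketch of both decisions, treating args as a plain dict (the standalone function names are illustrative):

    import os
    from paddle.inference import Config

    def pick_model_files(inference_model_dir, args):
        # Quantized exports are recognized purely by filename.
        if "inference_int8.pdiparams" in os.listdir(inference_model_dir):
            assert not args.get("use_fp16", False), \
                "fp16 is not supported for int8 model inference"
            prefix = "inference_int8"
        else:
            assert not args.get("use_int8", False), \
                "int8 is not supported for fp32 model inference"
            prefix = "inference"
        return (os.path.join(inference_model_dir, prefix + ".pdmodel"),
                os.path.join(inference_model_dir, prefix + ".pdiparams"))

    def resolve_trt_precision(args):
        # Same priority order as the logging path: int8 > fp16 > fp32.
        if args.get("use_int8", False):
            return Config.Precision.Int8
        if args.get("use_fp16", False):
            return Config.Precision.Half
        return Config.Precision.Float32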
@@ -33,7 +33,7 @@ from ppcls.arch.distill.afd_attention import LinearTransformStudent, LinearTrans
 __all__ = ["build_model", "RecModel", "DistillationModel", "AttentionModel"]

-def build_model(config):
+def build_model(config, mode="train"):
     arch_config = copy.deepcopy(config["Arch"])
     model_type = arch_config.pop("name")
     use_sync_bn = arch_config.pop("use_sync_bn", False)
@@ -44,7 +44,7 @@ def build_model(config):
     if isinstance(arch, TheseusLayer):
         prune_model(config, arch)
-        quantize_model(config, arch)
+        quantize_model(config, arch, mode)

     logger.info("The FLOPs and Params of Arch:")
     try:
......
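build_model now takes the running mode and forwards it to quantize_model, so quantization can behave differently at export/inference time than during training. A toy illustration of the threading (stub functions, not the real classes):

    def quantize_model(config, model, mode="train"):
        print(f"quantizing in mode={mode}")

    def build_model(config, mode="train"):
        model = object()  # stand-in for the instantiated arch
        quantize_model(config, model, mode)
        return model

    build_model({}, mode="export")  # prints: quantizing in mode=export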
@@ -52,7 +52,7 @@ from ppcls.arch.backbone.model_zoo.darknet import DarkNet53
 from ppcls.arch.backbone.model_zoo.regnet import RegNetX_200MF, RegNetX_4GF, RegNetX_32GF, RegNetY_200MF, RegNetY_4GF, RegNetY_32GF
 from ppcls.arch.backbone.model_zoo.vision_transformer import ViT_small_patch16_224, ViT_base_patch16_224, ViT_base_patch16_384, ViT_base_patch32_384, ViT_large_patch16_224, ViT_large_patch16_384, ViT_large_patch32_384
 from ppcls.arch.backbone.model_zoo.distilled_vision_transformer import DeiT_tiny_patch16_224, DeiT_small_patch16_224, DeiT_base_patch16_224, DeiT_tiny_distilled_patch16_224, DeiT_small_distilled_patch16_224, DeiT_base_distilled_patch16_224, DeiT_base_patch16_384, DeiT_base_distilled_patch16_384
-from ppcls.arch.backbone.model_zoo.swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384
+from ppcls.arch.backbone.legendary_models.swin_transformer import SwinTransformer_tiny_patch4_window7_224, SwinTransformer_small_patch4_window7_224, SwinTransformer_base_patch4_window7_224, SwinTransformer_base_patch4_window12_384, SwinTransformer_large_patch4_window7_224, SwinTransformer_large_patch4_window12_384
 from ppcls.arch.backbone.model_zoo.cswin_transformer import CSWinTransformer_tiny_224, CSWinTransformer_small_224, CSWinTransformer_base_224, CSWinTransformer_large_224, CSWinTransformer_base_384, CSWinTransformer_large_384
 from ppcls.arch.backbone.model_zoo.mixnet import MixNet_S, MixNet_M, MixNet_L
 from ppcls.arch.backbone.model_zoo.rexnet import ReXNet_1_0, ReXNet_1_3, ReXNet_1_5, ReXNet_2_0, ReXNet_3_0
......
@@ -21,8 +21,8 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle.nn.initializer import TruncatedNormal, Constant

-from .vision_transformer import trunc_normal_, zeros_, ones_, to_2tuple, DropPath, Identity
+from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
+from ppcls.arch.backbone.model_zoo.vision_transformer import trunc_normal_, zeros_, ones_, to_2tuple, DropPath, Identity
 from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

 MODEL_URLS = {
@@ -589,7 +589,7 @@ class PatchEmbed(nn.Layer):
         return flops

-class SwinTransformer(nn.Layer):
+class SwinTransformer(TheseusLayer):
     """ Swin Transformer
     A PaddlePaddle impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
         https://arxiv.org/pdf/2103.14030
......
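Rebasing SwinTransformer from nn.Layer onto TheseusLayer is what actually turns on "swin quant": the build_model hunk above gates the slim hooks on isinstance(arch, TheseusLayer), so an nn.Layer backbone was silently skipped before. A sketch of that gate (import paths assumed from the repo layout; the wrapper function is illustrative):

    from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
    from ppcls.arch.slim import prune_model, quantize_model

    def apply_slim(config, arch, mode="train"):
        # Only TheseusLayer subclasses reach the prune/quant hooks;
        # the rebased SwinTransformer now qualifies.
        if isinstance(arch, TheseusLayer):
            prune_model(config, arch)
            quantize_model(config, arch, mode)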
@@ -40,12 +40,14 @@ QUANT_CONFIG = {
 }

-def quantize_model(config, model):
+def quantize_model(config, model, mode="train"):
     if config.get("Slim", False) and config["Slim"].get("quant", False):
         from paddleslim.dygraph.quant import QAT
         assert config["Slim"]["quant"]["name"].lower(
         ) == 'pact', 'Only PACT quantization method is supported now'
         QUANT_CONFIG["activation_preprocess_type"] = "PACT"
+        if mode in ["infer", "export"]:
+            QUANT_CONFIG['activation_preprocess_type'] = None
         model.quanter = QAT(config=QUANT_CONFIG)
         model.quanter.quantize(model)
         logger.info("QAT model summary:")
......
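During training, PACT adds a learnable clipping preprocess in front of each fake-quant op; at export/inference time that preprocess is dropped (activation_preprocess_type = None), leaving only the standard fake-quant ops that the inference engines understand. A config fragment that takes this branch, with keys from the hunk and values chosen for illustration:

    config = {
        "Slim": {
            "quant": {
                "name": "pact",  # the assert accepts only PACT for now
            },
        },
    }
    # Training: quantize_model(config, model)            -> PACT preprocess on
    # Export:   quantize_model(config, model, "export")  -> preprocess dropped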
@@ -189,7 +189,7 @@ class Engine(object):
             self.eval_metric_func = None

         # build model
-        self.model = build_model(self.config)
+        self.model = build_model(self.config, self.mode)

         # set @to_static for benchmark, skip this by default.
         apply_to_static(self.config, self.model)
@@ -472,23 +472,19 @@ class Engine(object):
         save_path = os.path.join(self.config["Global"]["save_inference_dir"],
                                  "inference")

-        if model.quanter:
-            model.quanter.save_quantized_model(
-                model.base_model,
-                save_path,
-                input_spec=[
-                    paddle.static.InputSpec(
-                        shape=[None] + self.config["Global"]["image_shape"],
-                        dtype='float32')
-                ])
+        model = paddle.jit.to_static(
+            model,
+            input_spec=[
+                paddle.static.InputSpec(
+                    shape=[None] + self.config["Global"]["image_shape"],
+                    dtype='float32')
+            ])
+        if hasattr(model.base_model,
+                   "quanter") and model.base_model.quanter is not None:
+            model.base_model.quanter.save_quantized_model(model,
+                                                          save_path + "_int8")
         else:
-            model = paddle.jit.to_static(
-                model,
-                input_spec=[
-                    paddle.static.InputSpec(
-                        shape=[None] + self.config["Global"]["image_shape"],
-                        dtype='float32')
-                ])
             paddle.jit.save(model, save_path)
         logger.info(
             f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
......
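Export now always goes through paddle.jit.to_static first; if the wrapped net carries a quanter (i.e. QAT was applied), the quantized program is saved with an "_int8" suffix, which is exactly the inference_int8.pdiparams filename the predictor hunk above probes for, while a plain model falls back to paddle.jit.save. A condensed sketch (assumes the export wrapper exposes the trained net as model.base_model, as in the hunk):

    import paddle

    def export_model(model, image_shape, save_path):
        model = paddle.jit.to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None] + image_shape, dtype="float32")
            ])
        quanter = getattr(model.base_model, "quanter", None)
        if quanter is not None:
            # Written as inference_int8.pdmodel / inference_int8.pdiparams.
            quanter.save_quantized_model(model, save_path + "_int8")
        else:
            paddle.jit.save(model, save_path)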