Unverified commit dede4a80 authored by ceci3, committed by GitHub

update auto compressor (#1065)

* update dynabert

* update

* update

* update

* add transformer_pruner to autocompress

* add auto

* update transformer pruner

* ptq-hpo support eval function

* update

* update

* update

* update docs

* add comment

* fix patterns
Parent d0e11f4b
......@@ -27,7 +27,6 @@ add_arg('model_dir', str, None, "inference model di
add_arg('model_filename', str, None, "inference model filename.")
add_arg('params_filename', str, None, "inference params filename.")
add_arg('save_dir', str, None, "directory to save compressed model.")
add_arg('devices', str, 'gpu', "which device used to compress.")
add_arg('batch_size', int, 1, "train batch size.")
add_arg('task', str, 'sst-2', "task name in glue.")
add_arg('config_path', str, None, "path of compression strategy config.")
......@@ -161,6 +160,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
def apply_decay_param_fun(name):
if name.find("bias") > -1:
return True
elif name.find("b_0") > -1:
return True
elif name.find("norm") > -1:
return True
else:
......@@ -189,7 +190,9 @@ if __name__ == '__main__':
strategy_config=compress_config,
train_config=train_config,
train_dataloader=train_dataloader,
eval_callback=eval_function if 'HyperParameterOptimization' not in compress_config else eval_dataloader,
devices=args.devices)
eval_callback=eval_function
if 'HyperParameterOptimization' not in compress_config else
eval_dataloader,
eval_dataloader=eval_dataloader)
ac.compress()
......@@ -22,7 +22,6 @@ add_arg('model_dir', str, None, "inference model di
add_arg('model_filename', str, None, "inference model filename.")
add_arg('params_filename', str, None, "inference params filename.")
add_arg('save_dir', str, None, "directory to save compressed model.")
add_arg('devices', str, 'gpu', "which device used to compress.")
add_arg('batch_size', int, 1, "train batch size.")
add_arg('config_path', str, None, "path of compression strategy config.")
add_arg('data_dir', str, None, "path of dataset")
......@@ -37,10 +36,13 @@ def reader_wrapper(reader):
return gen
def eval_reader(data_dir, batch_size):
val_reader = paddle.batch(reader.val(data_dir=data_dir), batch_size=batch_size)
val_reader = paddle.batch(
reader.val(data_dir=data_dir), batch_size=batch_size)
return val_reader
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
val_reader = eval_reader(data_dir, batch_size=1)
image = paddle.static.data(
......@@ -104,7 +106,7 @@ if __name__ == '__main__':
strategy_config=compress_config,
train_config=train_config,
train_dataloader=train_dataloader,
eval_callback=eval_function if 'HyperParameterOptimization' not in compress_config else reader_wrapper(eval_reader(data_dir, 64)),
devices=args.devices)
eval_callback=eval_function,
    eval_dataloader=reader_wrapper(eval_reader(data_dir, 64)))
ac.compress()
......@@ -147,8 +147,7 @@ def main(args):
strategy_config=compress_config,
train_config=train_config,
train_dataloader=train_loader,
eval_callback=eval_func,
devices=args.devices)
eval_callback=eval_func)
ac.compress()
......
......@@ -15,20 +15,26 @@ AutoCompression
- **model_filename(str)** - Filename of the inference model to be compressed.
- **params_filename(str)** - Filename of the parameter file of the inference model to be compressed.
- **save_dir(str)** - Directory in which the compressed model is saved.
- **strategy_config(dict)** - The compression strategies to use. The dictionary keys must be chosen from: ``Quantization`` (quantization config, configurable parameters: `<>`_ ), ``Distillation`` (distillation config, configurable parameters: `<>`_),
``MultiTeacherDistillation`` (multi-teacher distillation config, configurable parameters: `<>`_), ``HyperParameterOptimization`` (hyper-parameter search config, configurable parameters: `<>`_),
``Prune`` (pruning config, configurable parameters: `<>`_), ``UnstructurePrune`` (unstructured sparsity config, configurable parameters: `<>`_). Currently only the following key configurations are supported:
- **train_dataloader(paddle.io.DataLoader)** - Training data iterator. Note: if the post-training-quantization hyper-parameter search strategy is selected, ``train_dataloader`` and ``eval_callback`` can simply be set to the same data reader.
- **train_config(dict)** - Training configuration. Configurable parameters: `<https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L103>`_ . Note: if the post-training-quantization hyper-parameter search strategy is selected, simply set ``train_config`` to ``None``.
- **strategy_config(dict, list(dict), optional)** - The compression strategies to use; several single strategies can be set and applied in parallel. The dictionary keys must be chosen from:
``Quantization`` (quantization config, configurable parameters: `<https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L24>`_ ),
``Distillation`` (distillation config, configurable parameters: `<https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L39>`_),
``MultiTeacherDistillation`` (multi-teacher distillation config, configurable parameters: `<https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L56>`_),
``HyperParameterOptimization`` (hyper-parameter search config, configurable parameters: `<https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L73>`_),
``Prune`` (pruning config, configurable parameters: `<https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L82>`_),
``UnstructurePrune`` (unstructured sparsity config, configurable parameters: `<https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/auto_compression/strategy_config.py#L91>`_).
Currently only the following combined or single strategies are supported:
1) ``Quantization`` & ``HyperParameterOptimization``: post-training quantization with hyper-parameter search;
2) ``Quantization`` & ``Distillation``: quantization-aware training with distillation;
3) ``Prune`` & ``Distillation``: structured pruning with distillation;
4) ``UnstructurePrune`` & ``Distillation``: unstructured sparsity with distillation;
5) ``Distillation``: single-teacher distillation only;
6) ``MultiTeacherDistillation``: multi-teacher distillation.
For the detailed parameters of each configuration, refer to: .
- **train_config(dict)** - Training configuration. Configurable parameters: `<>`_ . Note: if the post-training-quantization hyper-parameter search strategy is selected, simply set ``train_config`` to ``None``.
- **train_dataloader(paddle.io.DataLoader)** - Training data iterator. Note: if the post-training-quantization hyper-parameter search strategy is selected, ``train_dataloader`` and ``eval_callback`` can simply be set to the same data reader.
- **eval_callback(paddle.io.DataLoader|function)** - Exactly one of the eval callback function and the evaluation data iterator must be provided. If a callback function is passed, it is used to evaluate the model during training; see `<>`_ for how to write it. If an evaluation data iterator is passed, the ``EMD`` distance is used to measure the difference between the models before and after compression; currently only the post-training-quantization hyper-parameter search strategy supports judging the compressed model this way.
- **devices(str)** - The device to run on; one of ``cpu``, ``gpu``, ``npu``, ``gpu:x``, ``xpu:x``, ``npu:x``, where ``x`` is the index of the GPU, XPU or NPU. Default: ``gpu``.
If set to None, compression strategies are selected automatically. Default: None.
- **eval_callback(function, optional)** - Eval callback function used to evaluate the model during training; see `<//github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/api_cn/static/auto-compression/custom_function.rst>`_ for how to write it. ``eval_callback`` and ``eval_dataloader`` cannot both be None. Default: None.
- **eval_dataloader(paddle.io.Dataloader, optional)** - If an evaluation data iterator is passed, the ``EMD`` distance is used to measure the difference between the models before and after compression; currently only the post-training-quantization hyper-parameter search strategy supports judging the compressed model this way.
- **deploy_hardware(str, optional)** - The hardware on which the compressed model will be deployed. Default: ``gpu`` .
**Returns:** An instance of the AutoCompression class.
......@@ -37,28 +43,49 @@ AutoCompression
```python
import paddle
from paddleslim.auto_compression import AutoCompression, Quantization, HyperParameterOptimization
default_qat_config = {
"quantize_op_types": ["conv2d", "depthwise_conv2d", "mul"],
"weight_bits": 8,
"activation_bits": 8,
"is_full_quantize": False,
"not_quant_pattern": ["skip_quant"],
}
default_distill_config = {
"distill_loss": args.distill_loss,
"distill_node_pair": args.distill_node_pair,
"distill_lambda": args.distill_lambda,
"teacher_model_dir": args.teacher_model_dir,
"teacher_model_filename": args.teacher_model_filename,
"teacher_params_filename": args.teacher_params_filename,
}
train_dataloader = Cifar10(mode='train')
eval_dataloader = Cifar10(mode='eval')
ac = AutoCompression(model_path, model_filename, params_filename, save_dir, \
                     strategy_config={"Quantization": Quantization(**default_qat_config),
                                      "HyperParameterOptimization": HyperParameterOptimization(
                                          ptq_algo=["KL", "hist", "avg", "mse"],
                                          bias_correct=[True, False],
                                          batch_num=[10, 30],
                                          max_quant_count=20)}, \
                     train_config=None, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader)
```
......@@ -107,6 +134,7 @@ Quantization
**Parameters:**
- **quantize_op_types(list[str])** - Op types to be quantized.
- **weight_quantize_type(str)** - Weight quantization method; choices: ['channel_wise_abs_max', 'abs_max'].
- **weight_bits(int)** - Number of bits used to quantize weights.
- **activation_bits(int)** - Number of bits used to quantize activations.
- **is_full_quantize(bool)** - Whether to quantize all supported op types.
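For illustration, a quantization configuration covering these fields might look like the following sketch; the values shown are assumptions taken from the examples above, not prescribed defaults:
```python
from paddleslim.auto_compression import Quantization

# A minimal sketch of a quantization config; adjust op types and bit widths
# to match the target model and deployment backend.
quant_config = Quantization(
    quantize_op_types=["conv2d", "depthwise_conv2d", "mul"],
    weight_quantize_type="channel_wise_abs_max",
    weight_bits=8,
    activation_bits=8,
    is_full_quantize=False)
```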
......
......@@ -28,7 +28,7 @@
1.2 Output
##########
The callback function must return exactly one output:
**result(float)**: the evaluation metric of the model. Only the most important metric needs to be returned; it is used to check that data loading is correct and whether the optimization target set for training has been reached.
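A minimal sketch of such a callback, assuming an ``eval_dataloader`` that yields ``(feed_dict, labels)`` pairs; the helper names here are illustrative, not part of the API:
```python
import numpy as np

def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
    accs = []
    for feed_dict, labels in eval_dataloader():
        logits = exe.run(compiled_test_program,
                         feed=feed_dict,
                         fetch_list=test_fetch_list)[0]
        accs.append(np.mean(np.argmax(logits, axis=-1) == labels))
    # Return a single float: the metric that tracks compression quality.
    return float(np.mean(accs))
```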
......
......@@ -16,10 +16,11 @@ from __future__ import absolute_import
from .compressor import *
from .strategy_config import *
from .config_helpers import *
from .utils import *
__all__ = [
"AutoCompression", "Quantization", "Distillation",
"MultiTeacherDistillation", "HyperParameterOptimization", "Prune",
"UnstructurePrune", "ProgramInfo", "TrainConfig", "save_config",
"load_config"
"load_config", "predict_compressed_model"
]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import platform
from ..common import get_logger
from .utils.predict import predict_compressed_model
from .strategy_config import *
_logger = get_logger(__name__, level=logging.INFO)
__all__ = [
"prepare_strategy", "create_strategy_config", "get_final_quant_config"
]
### config tester to test the loss of quant_post
hpo_config_tester = {
"ptq_algo": ["avg", "mse", "KL"],
"weight_quantize_type": ['channel_wise_abs_max', 'abs_max'],
"bias_correct": [False],
"batch_num": [2, 3],
"max_quant_count": 1,
}
### default hpo config
default_hpo_config = {
"ptq_algo": ["KL", "hist", "avg", "mse"],
"weight_quantize_type": ['channel_wise_abs_max', 'abs_max'],
"bias_correct": [True, False],
"hist_percent": [0.98, 0.999],
"batch_num": [10, 30],
"max_quant_count": 20,
}
### default quant config, can be used by ptq&hpo and qat&distillation
default_quant_config = {
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul', 'matmul'],
'weight_bits': 8,
'activation_bits': 8
}
EXPERIENCE_STRATEGY_WITHOUT_LOSS = [
'sparse_0.75_fp32', 'prune_0.3_fp32', 'origin_int8', 'sparse_0.75_int8',
'prune_0.3_int8'
]
MAGIC_SPARSE_RATIO = 0.75
### TODO: 0.03 threshold maybe not suitable, need to check
MAGIC_EMD_DISTANCE = 0.03
DEFAULT_TRANSFORMER_STRATEGY = 'prune_0.25_int8'
DEFAULT_STRATEGY = 'origin_int8'
DEFAULT_QUANT_SPEEDUP = 0.7
def create_strategy_config(strategy_str, model_type):
""" create config according to string"""
tmp_s = strategy_str.split('_')
configs = []
dis_config = Distillation()
if len(tmp_s) == 3:
tmp_s[0] = tmp_s[0].replace('prune', 'Prune')
tmp_s[0] = tmp_s[0].replace('sparse', 'UnstructurePrune')
### TODO(ceci3): auto choose prune algo
default_prune_config = {
'pruned_ratio': float(tmp_s[1]),
'prune_algo': 'prune',
'criterion': 'l1_norm'
}
if model_type == 'transformer' and tmp_s[0] == 'Prune':
default_prune_config['prune_algo'] = 'transformer_pruner'
prune_config = eval(tmp_s[0])(**default_prune_config)
configs.append({tmp_s[0]: prune_config, 'Distillation': dis_config})
### TODO(ceci3): support skip some layer and full quant
if tmp_s[-1] == 'int8':
### only platform is linux can use smac to do hyperparameter optimization
### choose quant_aware to do quantization in other platform
if platform.system().lower() == 'linux':
quant_config = Quantization(**default_quant_config)
hpo_config = HyperParameterOptimization(**hpo_config_tester)
configs.append({
'Quantization': quant_config,
'HyperParameterOptimization': hpo_config
})
else:
quant_config = Quantization(**default_quant_config)
dis_config = Distillation()
configs.append({
'Quantization': quant_config,
'Distillation': dis_config
})
return configs
def prepare_strategy(model_dir,
model_filename,
params_filename,
target_speedup=None,
deploy_hardware=None,
model_type=None):
""" prepare compression config automatically """
final_strategy = None
### use the hardware latency table if supported
if deploy_hardware is not None:
compressed_time_dict = predict_compressed_model(
model_dir,
model_filename,
params_filename,
hardware=deploy_hardware)
baseline = compressed_time_dict['origin_fp32']
speedup_ratio = {}
for strategy, latency in compressed_time_dict.items():
speedup_ratio[strategy] = 1.0 - float(latency) / baseline
sorted_speedup_ratio = sorted(speedup_ratio.items(), key=lambda x: x[1])
### if target speedup is None, choose strategy by experience.
if target_speedup is None:
max_speedup = -1.0
for s in EXPERIENCE_STRATEGY_WITHOUT_LOSS:
if s not in speedup_ratio:
_logger.info(f"cannot get the speed up of strategy {s}")
continue
if speedup_ratio[s] > max_speedup:
max_speedup = speedup_ratio[s]
final_strategy = s
else:
candidate_s = []
pre_s = None
for strategy, ratio in sorted_speedup_ratio:
if abs(ratio - target_speedup) <= 0.1:
candidate_s.append(strategy)
### if there is no strategy that satisfies the target speedup,
### choose the one with the closest speedup
if ratio > target_speedup and len(candidate_s) == 0:
if pre_s is not None:
candidate_s.append(pre_s)
candidate_s.append(strategy)
pre_s = strategy
if 'origin_int8' in candidate_s:
final_strategy = candidate_s
else:
candidate_s = sorted(candidate_s, key=lambda x: x.split('_')[1])
for c in candidate_s:
if c.startswith('sparse') and float(c.split('_')[
1]) <= MAGIC_SPARSE_RATIO:
final_strategy = c
if final_strategy is None:
final_strategy = candidate_s[0]
### if deploy_hardware is None
else:
### default speedup ratio of quantization is 70% compared to fp32
### TODO(ceci3): full quant or skip some layer later
if target_speedup is None:
if model_type == 'transformer':
final_strategy = DEFAULT_TRANSFORMER_STRATEGY
else:
final_strategy = DEFAULT_STRATEGY
elif target_speedup > DEFAULT_QUANT_SPEEDUP:
prune_ratio = target_speedup - DEFAULT_QUANT_SPEEDUP
if prune_ratio > 1.0:
raise NotImplementedError(
"target_speedup {} is improper".format(target_speedup))
final_strategy = 'prune_{}_int8'.format(str(prune_ratio))
else:
raise NotImplementedError("target_speedup {} is improper".format(
target_speedup))
strategy_config = create_strategy_config(final_strategy, model_type)
return strategy_config
def get_final_quant_config(ptq_loss):
""" transform quantization tester config to real quantization config """
if ptq_loss <= MAGIC_EMD_DISTANCE:
quant_config = Quantization(**default_quant_config)
hpo_config = HyperParameterOptimization(**default_hpo_config)
configs = [{
'Quantization': quant_config,
'HyperParameterOptimization': hpo_config
}]
else:
quant_config = Quantization(**default_quant_config)
dis_config = Distillation()
configs = [{'Quantization': quant_config, 'Distillation': dis_config}]
return configs
if __name__ == '__main__':
create_strategy_config('sparse_0.75_int8', 'transformer')
......@@ -13,6 +13,7 @@
# limitations under the License.
import logging
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.optimizer as optimizer
......@@ -96,21 +97,21 @@ def _load_program_and_merge(executor,
params_filename,
teacher_idx=None,
feed_target_names=None):
scope = paddle.static.global_scope()
new_scope = paddle.static.Scope()
print(model_dir, model_filename, params_filename)
try:
with paddle.static.scope_guard(new_scope):
[teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.fluid.io.load_inference_model( \
dirname=model_dir, \
model_filename=model_filename, \
params_filename=params_filename, \
executor=executor)
dirname=model_dir, \
model_filename=model_filename, \
params_filename=params_filename, \
executor=executor)
except:
with paddle.static.scope_guard(new_scope):
[teacher_program, teacher_feed_target_names, teacher_fetch_targets]= paddle.static.load_inference_model( \
path_prefix=model_dir, \
executor=executor)
path_prefix=model_dir, \
executor=executor)
_remove_fetch_node(teacher_program)
......@@ -150,7 +151,8 @@ def build_distill_program(executor,
train_config,
train_program_info=None,
pruner=None,
dist_strategy=None):
dist_strategy=None,
default_distill_node_pair=None):
"""build distill program with infermodel"""
startup_program = paddle.static.Program()
if train_program_info is None:
......@@ -253,7 +255,7 @@ def build_distill_program(executor,
**train_config['amp_config'])
distill_loss, losses = _parse_distill_loss(
config['distill_node_pair'],
config.get('distill_node_pair') or default_distill_node_pair,
config.get('distill_loss') or
'l2_loss', ### default loss is l2_loss
config.get('distill_lambda') or 1.0) ### default lambda is 1.0
......@@ -324,7 +326,28 @@ def build_quant_program(executor, place, config, train_program_info,
return train_program_info, test_program_info, config
def build_prune_program(executor, place, config, train_program_info, strategy):
def _get_label_info(dataloader, feed_target_names):
label_info = {}
for data in dataloader():
for key, value in data[0].items():
if key in feed_target_names:
continue
label_info['name'] = key
label_info['dtype'] = np.array(value).dtype
label_info['shape'] = list(np.array(value).shape)
label_info['shape'][0] = -1
break
break
return label_info
def build_prune_program(executor,
place,
config,
train_program_info,
strategy,
patterns,
eval_dataloader=None):
if 'unstructure' in strategy:
from ..prune.unstructured_pruner import UnstructuredPruner, GMPUnstructuredPruner
if config["prune_strategy"] is None:
......@@ -340,16 +363,16 @@ def build_prune_program(executor, place, config, train_program_info, strategy):
pruner = GMPUnstructuredPruner(
train_program_info.program,
ratio=config['pruned_ratio'],
threshold=config['threshold'],
prune_params_type=config['prune_params_type'],
place=place,
local_sparsity=config['local_sparsity'],
config=config['gmp_config'])
configs=config['gmp_config'])
else:
if config['prune_algo'] == 'prune':
from ..prune import Pruner
pruner = Pruner(config["criterion"])
params = []
### TODO(ceci3): set default prune weight
for param in train_program_info.program.global_block(
).all_parameters():
if config[
......@@ -369,6 +392,7 @@ def build_prune_program(executor, place, config, train_program_info, strategy):
from paddle.static import sparsity
pruner = sparsity
excluded_params_name = []
### TODO(ceci3): set default prune weight
for param in train_program_info.program.global_block(
).all_parameters():
if config[
......@@ -377,6 +401,24 @@ def build_prune_program(executor, place, config, train_program_info, strategy):
excluded_params_name.append(param.name)
pruner.set_excluded_layers(train_program_info.program,
excluded_params_name)
elif config['prune_algo'] == 'transformer_pruner':
from .transformer_pruner import TransformerPruner
assert eval_dataloader is not None, "transformer_pruner must set eval_dataloader"
label_info = _get_label_info(eval_dataloader,
train_program_info.feed_target_names)
assert len(label_info) != 0, \
"maybe something wrong in get label name from eval_dataloader, please check your eval_dataloader"
pruner = TransformerPruner(
executor,
place,
train_program_info.program,
patterns,
label_info,
width_mult=(1.0 - config['pruned_ratio']),
dataloader=eval_dataloader,
fetch_targets=train_program_info.fetch_targets)
pruned_program = pruner.prune()
train_program_info.program = pruned_program
else:
raise NotImplementedError(
"prune_algo must be choice in [\"prune\", \"asp\"], {} is not support".
......
......@@ -94,7 +94,7 @@ UnstructurePrune = namedtuple("UnstructurePrune", [
"prune_strategy",
"prune_mode",
"threshold",
"prune_ratio",
"pruned_ratio",
"gmp_config",
"prune_params_type",
"local_sparsity",
......@@ -121,10 +121,10 @@ TrainConfig.__new__.__defaults__ = (None, ) * len(TrainConfig._fields)
def merge_config(*args):
fields = tuple()
fields = set()
cfg = dict()
for arg in args:
fields += arg._fields
fields = fields.union(arg._fields)
cfg.update(dict(arg._asdict()))
MergeConfig = namedtuple("MergeConfig", fields)
return MergeConfig(**cfg)
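As a quick illustration of why the field container changed from a tuple to a set: with the old ``+=`` on tuples, two configs sharing a field name would produce duplicate namedtuple fields and raise; the set union keeps each field once. A hypothetical sketch (config names are illustrative):
```python
from collections import namedtuple

A = namedtuple("A", ["epochs", "lr"])
B = namedtuple("B", ["epochs", "batch_size"])

# With fields as a set, the shared "epochs" field appears only once and the
# later config wins, so this call succeeds instead of raising ValueError.
merged = merge_config(A(epochs=1, lr=0.1), B(epochs=2, batch_size=64))
print(merged.epochs, merged.lr, merged.batch_size)  # 2 0.1 64
```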
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import paddle
from ..core import GraphWrapper
from ..common import get_logger
from ..common.recover_program import recover_inference_program
from ..common.transformer_pattern import preprocess_transformer_patterns
_logger = get_logger(__name__, level=logging.INFO)
global_idx = 0
### start to create trainable program with head mask
def _feed_op_num(program):
""" Get the numeber of feed op """
num = 0
for block in program.blocks:
ops = list(block.ops)
for op in ops:
if op.type == "feed":
num += 1
return num
def find_next_ops(block, var_name):
"""
Find all followed ops for the input variable.
"""
res_ops = []
for op in block.ops:
if var_name in op.input_arg_names:
res_ops.append(op)
return res_ops
def insert_eltmul_op(block, op, head_mask, block_num):
""" Insert elementwise mul op to matmul input_mask and head_mask to program"""
op_idx = block.ops.index(op)
var_name = op.output_arg_names
for var_name in op.output_arg_names:
next_op = find_next_ops(block, var_name)
score_name = var_name
if len(next_op) > 0:
break
next_op = next_op[0]
### start to insert matmul op
score = block.var(score_name)
matmul_out_var = block.create_var(
type=score.type,
name="{}_eltmul_mask".format(score.name),
shape=score.shape,
dtype=score.dtype)
mask = slice_op(block, block_num, head_mask, op_idx + 1)
inputs = {"X": score, "Y": mask}
outputs = {"Out": matmul_out_var}
block._insert_op(
op_idx + 2, type='elementwise_mul', inputs=inputs, outputs=outputs)
next_op_new_input = matmul_out_var.name
next_op._rename_input(score_name, next_op_new_input)
def fill_constant_op(block,
op_idx,
shape,
value,
force_cpu=False,
out=None,
stop_gradient=True):
""" Insert fill_constant op to program"""
block._insert_op(
op_idx,
type='fill_constant',
outputs={'Out': out},
attrs={
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'force_cpu': force_cpu
})
out.stop_gradient = stop_gradient
return out
def unsqueeze_op(block, axis, inputs, op_idx):
""" Insert unsqueeze op to program"""
out_name = inputs.name
out_shape = list(inputs.shape)
out_shape.insert(axis, 1)
global global_idx
out = block.create_var(
name='{}.unsqueeze_out.tmp_{}'.format(out_name, global_idx),
shape=out_shape,
dtype=inputs.dtype)
global_idx += 1
block._insert_op(
op_idx,
type='unsqueeze',
inputs={'X': inputs},
outputs={'Out': out},
attrs={"axes": [axis]})
return out
def feed_op(block, op_idx, out):
""" Insert feed op to program"""
feed_var = block.var('feed')
block._prepend_op(
op_idx,
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': op_idx})
return out
def slice_op(block, axis, inputs, op_idx):
""" Insert slice op to program"""
out_name = inputs.name
out_shape = list(inputs.shape)
out_shape.pop(0)
global global_idx
out = block.create_var(
name='{}.slice_out.tmp_{}'.format(out_name, global_idx),
shape=out_shape,
dtype=inputs.dtype)
global_idx += 1
attrs = {
"axes": [0],
"starts": [axis],
"ends": [axis + 1],
"decrease_axis": [0]
}
block._insert_op(
op_idx,
type='slice',
inputs={'Input': inputs},
attrs=attrs,
outputs={'Out': out})
return out
def softmax_with_cross_entropy_op(block, logits, labels):
""" Insert softmax_with_cross_entropy op to program"""
global global_idx
softmax = block.create_var(
name='{}.sce.softmax_tmp_{}'.format(logits.name, global_idx),
shape=logits.shape,
dtype=logits.dtype)
loss = block.create_var(
name='{}.sce.loss_tmp_{}'.format(logits.name, global_idx),
shape=logits.shape,
dtype=logits.dtype)
global_idx += 1
attrs = {
'soft_label': False,
'ignore_index': -100,
'numeric_stable_mode': True,
'axis': -1
}
inputs = {'Logits': logits, 'Label': labels}
outputs = {'Softmax': softmax, 'Loss': loss}
block.append_op(
type='softmax_with_cross_entropy',
inputs=inputs,
outputs=outputs,
attrs=attrs)
return loss, softmax
def mean_op(block, inputs, axis=None, keepdim=False):
""" Insert mean op to program"""
global global_idx
if isinstance(axis, int):
axis = [axis]
reduce_all = True if axis is None \
or len(axis)==0 \
or len(axis) == len(inputs.shape) else False
if axis is None or len(axis) == 0:
axis = [0]
if reduce_all == True:
out_shape = [1]
else:
out_shape = list(inputs.shape)
for idx in sorted(axis, reverse=True):
out_shape.pop(idx)
out = block.create_var(
name='{}.mean_tmp_{}'.format(inputs.name, global_idx),
shape=out_shape,
dtype=inputs.dtype)
attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all}
block.append_op(
type='reduce_mean',
inputs={'X': inputs},
outputs={'Out': out},
attrs=attrs)
return out
class TransformerPruner:
def __init__(self, exe, places, inference_program, patterns, label_info,
width_mult, fetch_targets, dataloader):
self.exe = exe
self.places = places
self.inference_program = inference_program
self.graph = GraphWrapper(inference_program)
self.patterns = patterns
self.label_info = label_info
self.width_mult = width_mult
self.fetch_targets = fetch_targets
self.dataloader = dataloader
self.scope = paddle.static.global_scope()
input_mask_op, layer_num, head_num, mha_weight, ffn_weight = self._preprocess_patterns(
patterns, self.graph)
self.input_mask_op = input_mask_op
self.mha_weight = mha_weight
self.ffn_weight = ffn_weight
_logger.info("start to reorder weight in program")
self.scope = self.reorder(inference_program, self.scope, patterns,
layer_num, head_num, mha_weight, ffn_weight)
def _preprocess_patterns(self, patterns, graph):
""" Preprocess pattern of the program, get some info need by reorder"""
input_mask_op = patterns['input_mask']
layer_num = int((len(patterns) - 1) / 2)
head_num = len(input_mask_op.input_arg_names)
mha_weight, ffn_weight = preprocess_transformer_patterns(patterns,
graph)
return input_mask_op, layer_num, head_num, mha_weight, ffn_weight
def _program_add_mask(self, program, patterns, layer_num, head_num,
label_info, fetch_targets):
""" Add head mask for program to compute the importance of weight and head """
fetch_list = []
for ft in fetch_targets:
fetch_list.append(ft.name)
program = recover_inference_program(program)
block = program.global_block()
head_mask = block.create_var(
name='head_mask',
shape=[layer_num, head_num],
dtype='float32',
persistable=True)
feed_num = _feed_op_num(program)
fill_constant_op(
block,
feed_num, [layer_num, head_num],
1.0,
out=head_mask,
stop_gradient=False)
head_mask = unsqueeze_op(
block, -1,
unsqueeze_op(block, -1,
unsqueeze_op(block, 1, head_mask, feed_num + 1),
feed_num + 2), feed_num + 3)
for pattern_name, pattern in patterns.items():
if 'MHA' in pattern_name:
block_num = int(pattern_name.split('$')[-1])
for op in pattern:
if op.type() == 'softmax':
var_name = op._op.output_arg_names[0]
next_op = find_next_ops(block, var_name)
if next_op[0].type == 'dropout':
op = next_op[0]
insert_eltmul_op(block, op, head_mask, block_num)
logits = block.var(fetch_list[0])
labels = block.create_var(
name=label_info['name'],
shape=label_info['shape'],
dtype=label_info['dtype'],
persistable=False)
labels = feed_op(block, feed_num, labels)
ce_loss, probs = softmax_with_cross_entropy_op(
block, logits=logits, labels=labels)
loss = mean_op(block, ce_loss)
program._sync_with_cpp()
paddle.static.append_backward(loss)
program._sync_with_cpp()
return program
def compute_importance(self, exe, program, patterns, ffn_weight, layer_num,
head_num, label_info, fetch_targets, dataloader):
""" Compute weight importance according weights and gradients of weight
Compute head importance according gradients of head_mask"""
program = self._program_add_mask(program, patterns, layer_num, head_num,
label_info, fetch_targets)
### define importance matrix
head_importance = np.zeros(shape=[layer_num, head_num], dtype='float32')
neuron_importance = []
intermediate_weight = []
intermediate_bias = []
output_weight = []
fetch_list = ['head_mask@GRAD']
### append weight name to fetch list
for l, wp in ffn_weight.items():
intermediate_weight.append(wp['P1'][0])
intermediate_bias.append(wp['P1'][1])
output_weight.append(wp['P2'][0])
fetch_list.extend(intermediate_weight)
fetch_list.extend(intermediate_bias)
fetch_list.extend(output_weight)
for out_ws in [intermediate_weight, intermediate_bias, output_weight]:
for out_w in out_ws:
fetch_list.append(out_w + '@GRAD')
for w_name in intermediate_weight:
neuron_importance.append(
np.zeros(
shape=[program.global_block().var(w_name).shape[1]],
dtype='float32'))
exe.run(paddle.static.default_startup_program())
### need to send a dataloader with label
for batch_id, data in enumerate(dataloader()):
outs = exe.run(program, feed=data, fetch_list=fetch_list)
hm_grad_value = outs.pop(0)
head_importance += np.abs(hm_grad_value)
part_len = int(len(outs) / 6)
t_intermediate_weight = outs[:part_len]
t_intermediate_bias = outs[part_len:2 * part_len]
t_output_weight = outs[2 * part_len:3 * part_len]
t_intermediate_weight_grad = outs[3 * part_len:4 * part_len]
t_intermediate_bias_grad = outs[4 * part_len:5 * part_len]
t_output_weight_grad = outs[5 * part_len:]
for w1, w1_g, b1, b1_g, w2, w2_g, current_importance in zip(
t_intermediate_weight, t_intermediate_weight_grad,
t_intermediate_bias, t_intermediate_bias_grad,
t_output_weight, t_output_weight_grad, neuron_importance):
current_importance += np.abs(
(np.sum(w1 * w1_g, axis=0) + b1 * b1_g))
current_importance += np.abs(np.sum(w2 * w2_g, axis=1))
return program, head_importance, neuron_importance
### REORDER
def _reorder_head(self, scope, place, weight, head_num, idx):
""" Start to reorder head according to importance"""
qkv = weight['P1']
attn_out = weight['P2']
attn_out_t = scope.find_var(qkv[0]).get_tensor()
num_per_head = int(attn_out_t.shape()[0] / head_num)
index = np.reshape(
np.take(
np.reshape(
np.arange(
0, head_num * num_per_head, dtype='int64'),
(head_num, num_per_head)),
idx,
axis=0), (-1))
def reorder_head_matrix(w_name, index, dim):
pd_w = scope.find_var(w_name).get_tensor()
np_w = np.array(pd_w)
new_w = np.take(np_w, index, axis=dim)
pd_w.set(new_w, place)
for w_idx, weight_name in enumerate(qkv):
if w_idx % 2 == 0:
### reorder qkv weight
reorder_head_matrix(weight_name, index, dim=1)
else:
### reorder qkv bias
reorder_head_matrix(weight_name, index, dim=0)
### reorder attention output weight
reorder_head_matrix(attn_out[0], index, dim=0)
def _reorder_neuron(self, scope, place, weight, idx):
""" Start to weight according to importance"""
ffn_i = weight['P1']
ffn_o = weight['P2']
def reorder_neurons_matrix(w_name, index, dim):
pd_w = scope.find_var(w_name).get_tensor()
np_w = np.array(pd_w)
new_w = np.take(np_w, index, axis=dim)
pd_w.set(new_w, place)
reorder_neurons_matrix(ffn_i[0], idx, dim=1)
reorder_neurons_matrix(ffn_i[1], idx, dim=0)
reorder_neurons_matrix(ffn_o[0], idx, dim=0)
def reorder_neuron_head(self, scope, place, mha_weight, ffn_weight,
head_importance, neuron_importance, head_num):
""" Start to weight and head according to importance"""
for layer, current_importance in enumerate(neuron_importance):
### reorder heads
idx = np.argsort(head_importance[layer])[::-1]
self._reorder_head(scope, place, mha_weight[layer], head_num, idx)
### reorder neurons
idx = np.argsort(current_importance)[::-1]
self._reorder_neuron(scope, place, ffn_weight[layer], idx)
def reorder(self, inference_program, scope, patterns, layer_num, head_num,
mha_weight, ffn_weight):
compute_program = inference_program.clone()
########################### COMPUTE IMPORTANCE ################################
compute_program, head_importance, neuron_importance = self.compute_importance(
self.exe, compute_program, patterns, ffn_weight, layer_num,
head_num, self.label_info, self.fetch_targets, self.dataloader)
############################### REORDER ##################################
self.reorder_neuron_head(scope, self.places, mha_weight, ffn_weight,
head_importance, neuron_importance, head_num)
return scope
### PRUNE
def _update_input_mask_inputs(self, program, op, new_inputs_len):
""" Prune input mask op """
input_var_name = op.input_arg_names
block = program.blocks[0]
var = block.var(input_var_name[0])
op.desc.set_input(
'X', input_var_name[:int(len(input_var_name) * new_inputs_len)])
def _prune_weight(self, graph, scope, place, pruned_name, pruned_ratio):
""" Prune every weight in program """
param = graph.var(pruned_name)
_var = scope.find_var(param.name())
if _var is None:
return
param_t = _var.get_tensor()
pruned_ratio = [pruned_ratio[1]] if len(param_t.shape(
)) == 1 else pruned_ratio
pruned_shape = np.multiply(param_t.shape(), pruned_ratio)
pruned_shape = list(map(int, pruned_shape))
param.set_shape(pruned_shape)
if len(pruned_shape) == 2:
pruned_param = np.array(param_t)[:pruned_shape[0], :pruned_shape[1]]
else:
pruned_param = np.array(param_t)[:pruned_shape[0]]
param_t.set(pruned_param, place)
def _prune_transformer(self, scope, place, graph, pruned_dict):
""" Prune transformer program """
for name, value in pruned_dict.items():
### prune weight
self._prune_weight(graph, scope, place, name, value)
graph.infer_shape()
return graph.program
def prune(self):
### get input_mask op and start to prune input_mask op
if self.input_mask_op.type == 'stack':
self._update_input_mask_inputs(self.inference_program,
self.input_mask_op, self.width_mult)
pruned_params = []
pruned_ratio = []
for partern_weight in [self.mha_weight, self.ffn_weight]:
for block, part in partern_weight.items():
pruned_params.extend(part['P1'])
pruned_ratio.extend(len(part['P1']) * [[1.0, self.width_mult]])
pruned_params.extend(part['P2'])
pruned_ratio.extend(len(part['P2']) * [[self.width_mult, 1.0]])
if 'reshape_op' in part:
for op in part['reshape_op']:
origin_shape = op.attr('shape')
pruned_shape = origin_shape
if len(origin_shape) == 3:
pruned_shape[-1] = int(origin_shape[-1] *
self.width_mult)
op.set_attr('shape', pruned_shape)
elif len(origin_shape) == 4:
pruned_shape[-2] = int(origin_shape[-2] *
self.width_mult)
op.set_attr('shape', pruned_shape)
else:
raise IndexError
pruned_dict = dict(zip(pruned_params, pruned_ratio))
### start to prune weight
pruned_program = self._prune_transformer(self.scope, self.places,
self.graph, pruned_dict)
return pruned_program
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .predict import predict_compressed_model
__all__ = ["predict_compressed_model"]
......@@ -6,19 +6,30 @@ from .fake_ptq import post_quant_fake
import shutil
def predict_compressed_model(model_file, param_file, hardware='SD710'):
def predict_compressed_model(model_dir,
model_filename,
params_filename,
hardware='SD710'):
"""
Evaluating the latency of the model under various compression strategies.
Args:
model_file(str), param_file(str): The inference model to be compressed.
model_dir(str): The path of inference model that will be compressed, and
the model and params that saved by ``paddle.static.io.save_inference_model``
are under the path.
model_filename(str, optional): The name of model file. If parameters
are saved in separate files, set it as 'None'. Default: 'None'.
params_filename(str, optional): The name of params file.
When all parameters are saved in a single file, set it
as filename. If parameters are saved in separate files,
set it as 'None'. Default : 'None'.
hardware(str): Target device.
Returns:
        latency_dict(dict): The latency of the model under various compression strategies.
"""
latency_dict = {}
model_filename = model_file.split('/')[-1]
param_filename = param_file.split('/')[-1]
model_file = os.path.join(model_dir, model_filename)
param_file = os.path.join(model_dir, params_filename)
predictor = TableLatencyPredictor(hardware)
latency = predictor.predict(
......@@ -29,22 +40,22 @@ def predict_compressed_model(model_file, param_file, hardware='SD710'):
exe = paddle.static.Executor(place)
post_quant_fake(
exe,
model_dir=os.path.dirname(model_file),
model_dir=model_dir,
model_filename=model_filename,
params_filename=param_filename,
params_filename=params_filename,
save_model_path='quant_model',
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
activation_bits=8,
weight_bits=8)
quant_model_file = os.path.join('quant_model', model_filename)
quant_param_file = os.path.join('quant_model', param_filename)
quant_param_file = os.path.join('quant_model', params_filename)
latency = predictor.predict(
model_file=quant_model_file,
param_file=quant_param_file,
data_type='int8')
latency_dict.update({f'origin_int8': latency})
latency_dict.update({'origin_int8': latency})
for prune_ratio in [0.3, 0.4, 0.5, 0.6]:
get_prune_model(
......@@ -53,7 +64,7 @@ def predict_compressed_model(model_file, param_file, hardware='SD710'):
ratio=prune_ratio,
save_path='prune_model')
prune_model_file = os.path.join('prune_model', model_filename)
prune_param_file = os.path.join('prune_model', param_filename)
prune_param_file = os.path.join('prune_model', params_filename)
latency = predictor.predict(
model_file=prune_model_file,
......@@ -65,14 +76,14 @@ def predict_compressed_model(model_file, param_file, hardware='SD710'):
exe,
model_dir='prune_model',
model_filename=model_filename,
params_filename=param_filename,
params_filename=params_filename,
save_model_path='quant_model',
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
activation_bits=8,
weight_bits=8)
quant_model_file = os.path.join('quant_model', model_filename)
quant_param_file = os.path.join('quant_model', param_filename)
quant_param_file = os.path.join('quant_model', params_filename)
latency = predictor.predict(
model_file=quant_model_file,
......@@ -87,7 +98,7 @@ def predict_compressed_model(model_file, param_file, hardware='SD710'):
ratio=sparse_ratio,
save_path='sparse_model')
sparse_model_file = os.path.join('sparse_model', model_filename)
sparse_param_file = os.path.join('sparse_model', param_filename)
sparse_param_file = os.path.join('sparse_model', params_filename)
latency = predictor.predict(
model_file=sparse_model_file,
......@@ -99,20 +110,20 @@ def predict_compressed_model(model_file, param_file, hardware='SD710'):
exe,
model_dir='sparse_model',
model_filename=model_filename,
params_filename=param_filename,
params_filename=params_filename,
save_model_path='quant_model',
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
activation_bits=8,
weight_bits=8)
quant_model_file = os.path.join('quant_model', model_filename)
quant_param_file = os.path.join('quant_model', param_filename)
quant_param_file = os.path.join('quant_model', params_filename)
latency = predictor.predict(
model_file=quant_model_file,
param_file=quant_param_file,
data_type='int8')
latency_dict.update({f'sparse_{prune_ratio}_int8': latency})
latency_dict.update({f'sparse_{sparse_ratio}_int8': latency})
# Delete temporary model files
shutil.rmtree('./quant_model')
......
......@@ -24,6 +24,7 @@ from .meter import AvgrageMeter
from .analyze_helper import VarCollector
from . import wrapper_function
from . import recover_program
from . import patterns
__all__ = [
'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer',
......@@ -33,3 +34,4 @@ __all__ = [
__all__ += wrapper_function.__all__
__all__ += recover_program.__all__
__all__ += patterns.__all__
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import numpy as np
import warnings
import paddle
from ..core import GraphWrapper
from .patterns_common import *
__all__ = ['find_final_nodes', 'get_patterns']
def find_final_nodes(program):
""" Find the output of the final op with weights in the program """
final_nodes = []
graph = GraphWrapper(program)
for op in sorted(graph.ops()):
if op.type() in ALL_WEIGHT_OP and is_output_weight_ops(op, graph):
n_op = has_bias(op, graph)
if n_op is not None:
final_nodes.extend(n_op.all_outputs())
else:
if op.type() == 'batch_norm':
out_var = op.outputs('Y')
else:
out_var = op.all_outputs()
final_nodes.extend(out_var)
return final_nodes
def _is_mha(pattern_ops, pattern_ops_type):
""" judge whether this pattern is multihead attention """
if pattern_ops_type.count('softmax') != 1 or pattern_ops_type.count(
'fetch') > 0:
return False
matmul_num = 0
for op in pattern_ops:
if op.type() in ['matmul', 'matmul_v2']:
if not is_dynamic_weight_op(op):
matmul_num += 1
if matmul_num == 2:
return True
return False
def _is_ffn(pattern_ops, pattern_ops_type):
""" judge whether this pattern is feed forward network """
if pattern_ops_type.count('layer_norm') != 1:
return False
linear_num = 0
act_num = 0
for op in pattern_ops:
if op.type() in ['mul', 'matmul', 'matmul_v2']:
if is_dynamic_weight_op(op):
linear_num += 1
if op.type() in ['relu', 'gelu']:
act_num += 1
if linear_num == 2 and act_num == 1:
return True
return False
def get_patterns(program, only_final_node=True):
""" distinguish the pattern in the program and get distillation node """
distill_node = []
patterns = {}
graph = GraphWrapper(program)
block_num = 0
model_type = None
for op in graph.ops():
belonged_teacher = False
for inp in op.all_inputs():
if 'teacher' in inp._var.name:
belonged_teacher = True
break
if belonged_teacher:
continue
if op.type() == 'elementwise_add':
inp1, inp2 = op.all_inputs()[0], op.all_inputs()[1]
if (not inp1._var.persistable) and (not inp2._var.persistable):
sc_path = []
shortcut_start_op = []
is_sc = is_shortcut(op, graph, sc_path, shortcut_start_op)
if is_sc:
out_var_name = op.all_outputs()[0]._var.name
shortcut_start_op = shortcut_start_op[0]
pattern_ops, pattern_ops_type = traversal_ops(
shortcut_start_op, graph, op.idx())
pattern_name = shortcut_start_op.type() + '$' + str(op.idx(
))
if _is_mha(pattern_ops, pattern_ops_type):
model_type = 'transformer'
pattern_name = 'MHA$' + str(block_num)
if model_type == 'transformer' and _is_ffn(
pattern_ops, pattern_ops_type):
pattern_name = 'FFN$' + str(block_num)
block_num += 1
if not only_final_node:
distill_node.append('teacher_' + out_var_name)
distill_node.append(out_var_name)
if model_type == 'transformer' and (
'fetch' in pattern_ops_type or
pattern_ops_type[-1] == 'scale'):
if 'input_mask' not in patterns:
patterns['input_mask'] = pattern_ops[0]._op
if 'fetch' in pattern_ops_type or pattern_ops_type[
-1] == 'scale':
continue
patterns[pattern_name] = pattern_ops
if model_type != 'transformer' and (not only_final_node):
distill_node.append('teacher_' + out_var_name)
distill_node.append(out_var_name)
### add the output of final weight node to distill node
final_weight_node = find_final_nodes(program)
for out_var in final_weight_node:
distill_node.append('teacher_' + out_var.name())
distill_node.append(out_var.name())
return patterns, distill_node, model_type
import os
ALL_WEIGHT_OP = [
'conv2d', 'mul', 'matmul', 'embedding', 'conv2d_transpose',
'depthwise_conv2d', 'batch_norm', 'layer_norm', 'instance_norm',
'sync_batch_norm', 'matmul_v2'
]
def traversal_ops(op, graph, target_op_idx):
""" Get all operators in the multi-path from op to target op. """
pattern_ops = []
pattern_ops_type = []
visited = []
pq = [op]
while pq:
cnt = len(pq)
level = []
for _ in range(cnt):
cur = pq.pop(0)
level.append(cur.type())
if cur.idx() not in visited:
### first op must be start op
pattern_ops.append(cur)
pattern_ops_type.append(cur.type())
visited.append(cur.idx())
for n_op in graph.next_ops(cur):
if n_op.is_opt_op() or n_op.is_bwd_op():
break
if n_op.idx() == target_op_idx or n_op.idx() in visited:
continue
pq.append(n_op)
return pattern_ops, pattern_ops_type
def find_weight_op(op, graph):
""" Find operators with weight."""
next_ops = sorted(graph.next_ops(op))
for next_op in next_ops:
if is_dynamic_weight_op(next_op):
return next_op
else:
return find_weight_op(next_op, graph)
def get_weight(op, return_name=True):
""" get the weight of operators with weight."""
for inp in op.all_inputs():
if inp._var.persistable == True:
if return_name:
return inp.name()
else:
return inp
def is_dynamic_weight_op(op):
weight_ops = ALL_WEIGHT_OP
if op.type() in weight_ops:
if op.type() in ['mul', 'matmul', 'matmul_v2']:
for inp in sorted(op.all_inputs()):
if inp._var.persistable == True:
return True
return False
return True
return False
def is_output_weight_ops(op, graph):
""" Judge whether is the final op with weights in the graph """
next_ops = sorted(graph.next_ops(op))
for next_op in next_ops:
if is_dynamic_weight_op(next_op):
return False
return is_output_weight_ops(next_op, graph)
return True
def has_bias(op, graph):
""" Get the bias of the op if exists """
n_op = graph.next_ops(op)[0]
if op.type() in ALL_WEIGHT_OP:
if n_op.type() == 'elementwise_add':
for inp in n_op.all_inputs():
if inp._var.persistable == True:
return n_op
return None
def _find_next_target_op(op, graph, target_op_idx, sc_path):
""" Find the target op from other branch in the shortcut """
if op.idx() == target_op_idx:
return True
n_ops = graph.next_ops(op)
for n_op in n_ops:
sc_path.append(n_op.type())
return _find_next_target_op(n_op, graph, target_op_idx, sc_path)
return False
def is_shortcut(op, graph, sc_path, shortcut_start_op):
"""
op /```````````````````\ add
\____op1___op2__..._/
"""
inps = op.all_inputs()
pre_ops = graph.pre_ops(op)
for p_op in pre_ops:
n_ops = graph.next_ops(p_op)
if len(n_ops) == 1:
continue
        ### note: only supports the case where one of the branches has no ops
has_sc = False
for n_op in n_ops:
if n_op.idx() == op.idx():
shortcut_start_op.append(p_op)
has_sc = True
if has_sc:
for n_op in n_ops:
if n_op.idx() != op.idx():
sc_path.append(p_op.type())
sc_path.append(n_op.type())
return _find_next_target_op(n_op, graph, op.idx(), sc_path)
return False
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..core import GraphWrapper
from .patterns_common import *
__all__ = ['preprocess_transformer_patterns']
def _append_transformer_prune_params(op, graph, block_num, params_dict):
for next_op in graph.next_ops(op):
if next_op.type() in ['mul', 'matmul', 'matmul_v2'
] and is_dynamic_weight_op(next_op):
if block_num not in params_dict:
params_dict[block_num] = {}
params_dict[block_num]['P1'] = [get_weight(next_op)]
else:
params_dict[block_num]['P1'].append(get_weight(next_op))
params_dict[block_num]['P1'].append(
get_weight(has_bias(next_op, graph)))
op = next_op
next_op = find_weight_op(op, graph)
if next_op:
params_dict[block_num]['P2'] = [get_weight(next_op)]
params_dict[block_num]['P2'].append(
get_weight(has_bias(next_op, graph)))
return params_dict
def preprocess_transformer_patterns(patterns, graph):
""" """
mha_weight = {}
ffn_weight = {}
for pattern_name, pattern_ops in patterns.items():
if pattern_name == 'input_mask':
continue
block_num = int(pattern_name.split('$')[-1])
if 'MHA' in pattern_name:
mha_weight = _append_transformer_prune_params(pattern_ops[0], graph,
block_num, mha_weight)
mha_weight[block_num]['reshape_op'] = []
for op in pattern_ops:
if op.type() in ['reshape', 'reshape2']:
mha_weight[block_num]['reshape_op'].append(op)
elif 'FFN' in pattern_name:
ffn_weight = _append_transformer_prune_params(pattern_ops[0], graph,
block_num, ffn_weight)
return mha_weight, ffn_weight
......@@ -340,8 +340,9 @@ class GraphWrapper(object):
for p in self.ops():
for in_var in op.all_inputs():
if in_var in p.all_outputs():
ops.append(p)
return ops
if p.idx() != op.idx():
ops.append(p)
return sorted(ops)
def next_ops(self, op):
"""
......@@ -357,8 +358,9 @@ class GraphWrapper(object):
for p in self.ops():
for out_var in op.all_outputs():
if out_var in p.all_inputs():
ops.append(p)
return ops
if p.idx() != op.idx():
ops.append(p)
return sorted(ops)
def get_param_by_op(self, op):
"""
......
......@@ -100,10 +100,17 @@ class QuantConfig:
g_quant_config = None
g_min_emd_loss = float('inf')
g_quant_model_cache_path = "quant_model_tmp"
def emd_loss_init():
global g_min_emd_loss
g_min_emd_loss = float('inf')
emd_loss_init()
def make_feed_dict(feed_target_names, data):
"""construct feed dictionary"""
feed_dict = {}
......@@ -236,6 +243,8 @@ def eval_quant_model():
def quantize(cfg):
"""model quantize job"""
algo = cfg["algo"] if 'algo' in cfg else g_quant_config.algo[0][0]
if g_quant_config.hist_percent[0] is None:
g_quant_config.hist_percent = [g_quant_config.hist_percent]
hist_percent = cfg[
"hist_percent"] if "hist_percent" in cfg else g_quant_config.hist_percent[
0][0]
......@@ -272,12 +281,32 @@ def quantize(cfg):
batch_nums=batch_num)
global g_min_emd_loss
### if eval_function is not None, use eval function provided by user.
### TODO(ceci3): fix eval_function
if g_quant_config.eval_function is not None:
emd_loss = g_quant_config.eval_function()
else:
try:
emd_loss = eval_quant_model()
except:
            ### fall back to comparing the float and quant metrics with the user-provided eval function.
float_scope = paddle.static.Scope()
quant_scope = paddle.static.Scope()
with paddle.static.scope_guard(float_scope):
[float_inference_program, feed_target_names, fetch_targets]= fluid.io.load_inference_model( \
dirname=g_quant_config.model_filename, \
model_filename=g_quant_config.model_filename, params_filename=g_quant_config.params_filename,
executor=g_quant_config.executor)
float_metric = g_quant_config.eval_function(
g_quant_config.executor, float_inference_program,
feed_target_names, fetch_targets)
with paddle.static.scope_guard(quant_scope):
[quant_inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model( \
dirname=g_quant_model_cache_path, \
model_filename=g_quant_config.model_filename, params_filename=g_quant_config.params_filename,
executor=g_quant_config.executor)
quant_metric = g_quant_config.eval_function(
                g_quant_config.executor, quant_inference_program, feed_target_names,
fetch_targets)
emd_loss = float(abs(float_metric - quant_metric)) / float_metric
print("emd loss: ", emd_loss)
if emd_loss < g_min_emd_loss:
g_min_emd_loss = emd_loss
......