未验证 提交 e19de504 编写于 作者: C ceci3 提交者: GitHub

limit prune ratio in transformer pruner (#1087)

上级 dede4a80
......@@ -39,6 +39,9 @@ def load_config(config_path):
else:
train_config = None
if len(compress_config) == 0:
compress_config = None
return compress_config, train_config
......
......@@ -19,6 +19,7 @@ from ..core import GraphWrapper
from ..common import get_logger
from ..common.recover_program import recover_inference_program
from ..common.transformer_pattern import preprocess_transformer_patterns
from ..common.patterns_common import is_dynamic_weight_op
_logger = get_logger(__name__, level=logging.INFO)
......@@ -228,13 +229,21 @@ class TransformerPruner:
self.graph = GraphWrapper(inference_program)
self.patterns = patterns
self.label_info = label_info
self.width_mult = width_mult
self.fetch_targets = fetch_targets
self.dataloader = dataloader
self.scope = paddle.static.global_scope()
input_mask_op, layer_num, head_num, mha_weight, ffn_weight = self._preprocess_patterns(
patterns, self.graph)
### width_mult * head_num must be an integer, so round it to the nearest whole number of heads.
pruned_head = round(width_mult * head_num)
self.width_mult = float(pruned_head) / head_num
if self.width_mult != width_mult:
_logger.info(
"the prune ratio * head_num need to be an integer. so change prune ratio from {} to {}".
format(str(1.0 - width_mult), str(1.0 - self.width_mult)))
self.input_mask_op = input_mask_op
self.mha_weight = mha_weight
self.ffn_weight = ffn_weight
......@@ -247,7 +256,15 @@ class TransformerPruner:
""" Preprocess the patterns of the program and collect the info needed for reordering."""
input_mask_op = patterns['input_mask']
layer_num = int((len(patterns) - 1) / 2)
head_num = len(input_mask_op.input_arg_names)
### get the real head number: dim 1 of the "X" input of the first matmul in MHA$0 that is not a dynamic-weight op
head_num = -1
tmp_mha_ops = patterns['MHA$0']
for op in tmp_mha_ops:
if op.type() in ['matmul', 'matmul_v2'] and (
not is_dynamic_weight_op(op)) and head_num == -1:
inp_var = op.inputs("X")
head_num = inp_var[0].shape()[1]
mha_weight, ffn_weight = preprocess_transformer_patterns(patterns,
graph)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册