Unverified commit e72ce197, authored by ceci3, committed by GitHub

support vit prune (#1590)

* support vit prune

* update

* add unittest
Parent 5662a660
@@ -287,8 +287,10 @@ class TransformerPruner:
     def _preprocess_patterns(self, patterns, graph):
         """ Preprocess pattern of the program, get some info need by reorder"""
-        input_mask_op = patterns['input_mask']
-        layer_num = int((len(patterns) - 1) / 2)
+        input_mask_op = patterns.get('input_mask', None)
+        layer_num = int(
+            (len(patterns) - 1) / 2) if input_mask_op is not None else int(
+                (len(patterns) / 2))
         ### get real head number
         head_num = -1
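A note on the two formulas above: for an encoder model with an attention mask (BERT-style) the pattern dict carries one `input_mask` entry plus one MHA and one FFN pattern per layer, while a ViT program has no mask op at all. A minimal sketch of that arithmetic, using hypothetical pattern keys (real keys look like `<op_type>$<op_idx>`):

```python
# Hypothetical pattern dicts illustrating the layer_num computation.
bert_patterns = {'input_mask': ..., 'MHA$0': ..., 'FFN$0': ...}
vit_patterns = {'MHA$0': ..., 'FFN$0': ..., 'MHA$1': ..., 'FFN$1': ...}

def count_layers(patterns):
    # One MHA pattern plus one FFN pattern per encoder layer;
    # the optional 'input_mask' entry belongs to no layer.
    n = len(patterns) - 1 if 'input_mask' in patterns else len(patterns)
    return n // 2

assert count_layers(bert_patterns) == 1
assert count_layers(vit_patterns) == 2
```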
@@ -395,8 +397,6 @@ class TransformerPruner:
                     shape=[program.global_block().var(w_name).shape[1]],
                     dtype='float32'))
         exe.run(paddle.static.default_startup_program())
-        ### need to send a dataloader with label
         for batch_id, data in enumerate(dataloader()):
             outs = exe.run(program, feed=data, fetch_list=fetch_list)
@@ -445,13 +445,21 @@ class TransformerPruner:
             new_w = np.take(np_w, index, axis=dim)
             pd_w.set(new_w, place)
 
+        if int(len(qkv) / 2) == 1:
+            q_index = index
+            k_index = index + 768
+            v_index = index + (768 * 2)
+            qkv_index = np.append(np.append(q_index, k_index), v_index)
+        else:
+            qkv_index = index
+
         for w_idx, weight_name in enumerate(qkv):
             if w_idx % 2 == 0:
                 ### reorder qkv weight
-                reorder_head_matrix(weight_name, index, dim=1)
+                reorder_head_matrix(weight_name, qkv_index, dim=1)
             else:
                 ### reorder qkv bias
-                reorder_head_matrix(weight_name, index, dim=0)
+                reorder_head_matrix(weight_name, qkv_index, dim=0)
 
         ### reorder attention output weight
         reorder_head_matrix(attn_out[0], index, dim=0)
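In the fused-QKV case (`int(len(qkv) / 2) == 1`, i.e. one weight/bias pair per layer instead of three, as in the ViT_base_patch16_224 model used by the new unit test), the head-reorder index has to be replicated for the K and V column blocks; 768 is the hidden size of ViT-Base. A rough, self-contained sketch of the index construction with a toy hidden size, assuming Q, K and V are stored as consecutive column blocks:

```python
import numpy as np

hidden = 8                                   # toy hidden size; 768 for ViT-Base above
index = np.array([4, 5, 6, 7, 0, 1, 2, 3])   # example column order after importance sort

# Repeat the permutation for the K and V blocks of a fused
# [hidden, 3 * hidden] QKV weight, then reorder all columns at once.
qkv_index = np.append(np.append(index, index + hidden), index + 2 * hidden)

fused_qkv = np.random.rand(hidden, 3 * hidden)
reordered = np.take(fused_qkv, qkv_index, axis=1)
assert reordered.shape == fused_qkv.shape
```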
@@ -507,7 +515,13 @@ class TransformerPruner:
             op.desc.set_input(
                 'X', input_var_name[:int(len(input_var_name) * new_inputs_len)])
 
-    def _prune_weight(self, graph, scope, place, pruned_name, pruned_ratio):
+    def _prune_weight(self,
+                      graph,
+                      scope,
+                      place,
+                      pruned_name,
+                      pruned_ratio,
+                      fuse_qkv=False):
         """ Prune every weight in program """
         param = graph.var(pruned_name)
         _var = scope.find_var(param.name())
@@ -516,26 +530,62 @@ class TransformerPruner:
         param_t = _var.get_tensor()
         pruned_ratio = [pruned_ratio[1]] if len(param_t.shape(
         )) == 1 else pruned_ratio
-        pruned_shape = np.multiply(param_t.shape(), pruned_ratio)
-        pruned_shape = list(map(int, pruned_shape))
-        param.set_shape(pruned_shape)
-        if len(pruned_shape) == 2:
-            pruned_param = np.array(param_t)[:pruned_shape[0], :pruned_shape[1]]
-        else:
-            pruned_param = np.array(param_t)[:pruned_shape[0]]
-        param_t.set(pruned_param, place)
+        origin_shape = param_t.shape()
+
+        def process_qkv(qkv_param, pruned_ratio):
+            qkv_param_shape = qkv_param.shape()
+            if len(qkv_param_shape) == 2:
+                tmp_qkv_param_shape = [qkv_param_shape[0], -1, 3]
+            else:
+                tmp_qkv_param_shape = [-1, 3]
+            tmp_param = np.reshape(qkv_param, tmp_qkv_param_shape)
+            tmp_pruned_ratio = pruned_ratio + [1.0]
+            tmp_pruned_shape = np.multiply(tmp_param.shape, tmp_pruned_ratio)
+            tmp_pruned_shape = list(map(int, tmp_pruned_shape))
+            if len(qkv_param_shape) == 2:
+                tmp_prune_qkv_param = tmp_param[:tmp_pruned_shape[
+                    0], :tmp_pruned_shape[1], :tmp_pruned_shape[2]]
+                pruned_param = np.reshape(tmp_prune_qkv_param,
+                                          (qkv_param_shape[0], -1))
+            else:
+                tmp_prune_qkv_param = tmp_param[:tmp_pruned_shape[0], :
+                                                tmp_pruned_shape[1]]
+                pruned_param = np.reshape(tmp_prune_qkv_param, (-1))
+            return pruned_param
+
+        if fuse_qkv:
+            pruned_param = process_qkv(param_t, pruned_ratio)
+            param.set_shape(pruned_param.shape)
+            param_t.set(pruned_param, place)
+        else:
+            pruned_shape = np.multiply(param_t.shape(), pruned_ratio)
+            pruned_shape = list(map(int, pruned_shape))
+            param.set_shape(pruned_shape)
+            if len(pruned_shape) == 2:
+                pruned_param = np.array(param_t)[:pruned_shape[0], :
+                                                 pruned_shape[1]]
+            else:
+                pruned_param = np.array(param_t)[:pruned_shape[0]]
+            param_t.set(pruned_param, place)
 
     def _prune_transformer(self, scope, place, graph, pruned_dict):
         """ Prune transformer program """
+        qkv_weights_name = []
+        if (len(self.mha_weight[0]['P1']) // 2 == 1):
+            for _, mha_weights_name in self.mha_weight.items():
+                qkv_weights_name.extend(mha_weights_name['P1'])
         for name, value in pruned_dict.items():
             ### prune weight
-            self._prune_weight(graph, scope, place, name, value)
+            fuse_qkv = False
+            if name in qkv_weights_name:
+                fuse_qkv = True
+            self._prune_weight(graph, scope, place, name, value, fuse_qkv)
         graph.infer_shape()
         return graph.program
 
     def prune(self):
         ### get input_mask op and start to prune input_mask op
-        if self.input_mask_op.type == 'stack':
+        if self.input_mask_op is not None and self.input_mask_op.type == 'stack':
             self._update_input_mask_inputs(self.inference_program,
                                            self.input_mask_op, self.width_mult)
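`process_qkv` above handles a fused QKV parameter instead of the separate-weight path: the tensor is reshaped so a trailing axis of size 3 is split off, the keep ratio is applied to the remaining axes (ratio 1.0 on the size-3 axis), and the slice is flattened back to the original rank. A standalone numpy sketch of those mechanics, assuming a 2-D fused weight whose last dimension is a multiple of 3:

```python
import numpy as np

def prune_fused_qkv(param, pruned_ratio):
    # Pull the factor of 3 out as a trailing axis so it is never pruned.
    tmp = np.reshape(param, [param.shape[0], -1, 3])
    keep = list(map(int, np.multiply(tmp.shape, pruned_ratio + [1.0])))
    pruned = tmp[:keep[0], :keep[1], :keep[2]]
    # Flatten back to 2-D; the pruned width stays a multiple of 3.
    return np.reshape(pruned, (param.shape[0], -1))

w = np.random.rand(8, 24)                 # toy fused QKV weight
out = prune_fused_qkv(w, [1.0, 0.75])     # keep all rows, 75% of the columns
assert out.shape == (8, 18)
```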
@@ -555,7 +605,7 @@ class TransformerPruner:
                     pruned_shape[-1] = int(origin_shape[-1] *
                                            self.width_mult)
                     op.set_attr('shape', pruned_shape)
-                elif len(origin_shape) == 4:
+                elif len(origin_shape) == 4 or len(origin_shape) == 5:
                     pruned_shape[-2] = int(origin_shape[-2] *
                                            self.width_mult)
                     op.set_attr('shape', pruned_shape)
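The extra `len(origin_shape) == 5` branch lets reshape attributes in a ViT program follow the head pruning: ViT reshapes the fused QKV output to [batch, tokens, 3, heads, head_dim], so, as in the 4-D BERT case, the head count sits at index -2. A small sketch of the shape update, with illustrative ViT-Base numbers:

```python
width_mult = 0.75

def prune_reshape_attr(origin_shape, width_mult):
    pruned_shape = list(origin_shape)
    pruned_shape[-2] = int(origin_shape[-2] * width_mult)   # shrink the head count
    return pruned_shape

bert_shape = [-1, 128, 12, 64]        # [batch, tokens, heads, head_dim]
vit_qkv_shape = [-1, 197, 3, 12, 64]  # [batch, tokens, qkv, heads, head_dim]
print(prune_reshape_attr(bert_shape, width_mult))     # [-1, 128, 9, 64]
print(prune_reshape_attr(vit_qkv_shape, width_mult))  # [-1, 197, 3, 9, 64]
```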
......
@@ -101,14 +101,15 @@ def get_patterns(program, only_final_node=True):
             if (not inp1._var.persistable) and (not inp2._var.persistable):
                 sc_path = []
                 shortcut_start_op = []
-                is_sc = is_shortcut(op, graph, sc_path, shortcut_start_op)
+                is_sc, target_op_idx = is_shortcut(op, graph, sc_path,
+                                                   shortcut_start_op)
                 if is_sc:
                     out_var_name = op.all_outputs()[0]._var.name
                     shortcut_start_op = shortcut_start_op[0]
-                    next_op = graph.next_ops(op)
+                    next_ops = graph.next_ops(op)
                     pattern_ops, pattern_ops_type = traversal_ops(
-                        shortcut_start_op, graph, next_op[0].idx())
+                        shortcut_start_op, graph, target_op_idx)
 
                     pattern_name = shortcut_start_op.type() + '$' + str(op.idx(
                     ))
......
@@ -132,5 +132,6 @@ def is_shortcut(op, graph, sc_path, shortcut_start_op):
             if n_op.idx() != op.idx():
                 sc_path.append(p_op.type())
                 sc_path.append(n_op.type())
-                return _find_next_target_op(n_op, graph, op.idx(), sc_path)
-    return False
+                return _find_next_target_op(n_op, graph, op.idx(),
+                                            sc_path), op.idx()
+    return False, -1
@@ -18,8 +18,18 @@ from .patterns_common import *
 __all__ = ['preprocess_transformer_patterns']
 
 
+def _find_gemm_op(op, graph):
+    while op.type() not in ['mul', 'matmul', 'matmul_v2']:
+        next_op = find_weight_op(op, graph)
+        op = next_op
+    return op
+
+
 def _append_transformer_prune_params(op, graph, block_num, params_dict):
     for next_op in graph.next_ops(op):
+        if next_op.type() == 'elementwise_add':
+            continue
+        next_op = _find_gemm_op(next_op, graph)
         if next_op.type() in ['mul', 'matmul', 'matmul_v2'
                               ] and is_dynamic_weight_op(next_op):
             if block_num not in params_dict:
@@ -30,7 +40,7 @@ def _append_transformer_prune_params(op, graph, block_num, params_dict):
             params_dict[block_num]['P1'].append(
                 get_weight(has_bias(next_op, graph)))
             op = next_op
-    next_op = find_weight_op(op, graph)
+    next_op = _find_gemm_op(find_weight_op(op, graph), graph)
     if next_op:
         params_dict[block_num]['P2'] = [get_weight(next_op)]
         params_dict[block_num]['P2'].append(
......
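The new `_find_gemm_op` helper repeatedly calls `find_weight_op` to step forward until a GEMM-like op (`mul`/`matmul`/`matmul_v2`) is reached, presumably so the pattern collector is not thrown off by the reshape/transpose/scale ops a ViT graph places between a layer norm and the next projection. A toy, self-contained sketch of the same forward walk over op type names:

```python
GEMM_TYPES = {'mul', 'matmul', 'matmul_v2'}

def find_gemm(op_types, start):
    # Stand-in for the graph walk: stepping to the next op in a linear
    # toy graph replaces find_weight_op(op, graph).
    idx = start
    while op_types[idx] not in GEMM_TYPES:
        idx += 1
    return idx

ops = ['layer_norm', 'reshape2', 'transpose2', 'matmul_v2', 'softmax']
assert ops[find_gemm(ops, 1)] == 'matmul_v2'
```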
@@ -428,8 +428,7 @@ def quant_aware(program,
             quant_bits=config['activation_bits'],
             skip_pattern=config['not_quant_pattern'],
             quantizable_op_type=quant_dequant_ops,
-            is_test=is_test,
-            scale_dict=scale_dict)
+            is_test=is_test)
         quant_dequant_pass.apply(main_graph)
......
@@ -274,5 +274,60 @@ class ACTChannelPrune(unittest.TestCase):
         os.system('rm -rf asp_output')
 
 
+class ACTViTPrune(ACTChannelPrune):
+    def __init__(self, *args, **kwargs):
+        super(ACTViTPrune, self).__init__(*args, **kwargs)
+        if not os.path.exists('ViT_base_patch16_224_infer'):
+            os.system(
+                'wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ViT_base_patch16_224_infer.tar'
+            )
+            os.system('tar -xf ViT_base_patch16_224_infer.tar')
+        if not os.path.exists('ILSVRC2012_data_demo'):
+            os.system(
+                'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
+            )
+            os.system('tar -xf ILSVRC2012_data_demo.tar.gz')
+        self.train_dataloader, self.eval_dataloader = self.create_dataloader()
+
+    def test_act_vit_transformer_prune(self):
+        def eval_function(exe, compiled_test_program, test_feed_names,
+                          test_fetch_list):
+            res = eval_func(compiled_test_program, exe, test_feed_names,
+                            test_fetch_list, self.eval_dataloader)
+            return res
+
+        configs = {
+            'Distillation': {},
+            'TransformerPrune': {
+                'pruned_ratio': 0.1
+            },
+            'TrainConfig': {
+                'epochs': 1,
+                'eval_iter': 1000,
+                'learning_rate': 5.0e-03,
+                'optimizer_builder': {
+                    'optimizer': {
+                        'type': 'SGD'
+                    },
+                    "weight_decay": 0.0005,
+                }
+            }
+        }
+
+        ac = AutoCompression(
+            model_dir='./ViT_base_patch16_224_infer',
+            model_filename="inference.pdmodel",
+            params_filename="inference.pdiparams",
+            save_dir="vit_prune_output",
+            config=configs,
+            train_dataloader=self.train_dataloader,
+            eval_callback=eval_function,
+            eval_dataloader=self.eval_dataloader)  # eval_function to verify accuracy
+        ac.compress()
+        os.system('rm -rf vit_prune_output')
+
+
 if __name__ == '__main__':
     unittest.main()