Unverified commit 03331331, authored by ceci3, committed by GitHub

fix act recover params and pattern recognition (#1695)

* fix

* update
Parent: b735a396
...
@@ -71,6 +71,8 @@ def _recover_param_attr(program, startup_program):
         if param.persistable is True and param.name != 'feed' and param.name != 'fetch']
     with paddle.static.program_guard(program, startup_program):
         for w in all_weights:
+            if w.dtype not in [paddle.float32]:
+                continue
             new_w = paddle.create_parameter(
                 shape=w.shape, dtype=w.dtype, name=w.name)
             new_w.set_value(w.get_value())
...
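Note on the first hunk: the added dtype guard means _recover_param_attr now recreates only float32 weights as trainable parameters and skips everything else (for example, already low-precision tensors). The snippet below is a minimal, paddle-free sketch of that filtering rule; FakeWeight, recoverable_weights, and the dtype strings are hypothetical stand-ins for illustration, not part of the Paddle or PaddleSlim API.

    from dataclasses import dataclass

    @dataclass
    class FakeWeight:
        # Hypothetical stand-in for a persistable variable in a static program.
        name: str
        dtype: str
        persistable: bool = True

    def recoverable_weights(all_vars):
        """Keep only weights the recover step would recreate as parameters."""
        kept = []
        for w in all_vars:
            # Same pre-existing filter as the original list comprehension.
            if not w.persistable or w.name in ('feed', 'fetch'):
                continue
            # Mirrors the new guard: `if w.dtype not in [paddle.float32]: continue`.
            if w.dtype != 'float32':
                continue
            kept.append(w)
        return kept

    weights = [
        FakeWeight('linear_0.w_0', 'float32'),
        FakeWeight('linear_0.w_0.quantized', 'int8'),
        FakeWeight('feed', 'float32'),
    ]
    print([w.name for w in recoverable_weights(weights)])  # ['linear_0.w_0']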
...
@@ -25,13 +25,15 @@ def _find_gemm_op(op, graph):
     return op
 
 
-def _append_transformer_prune_params(op, graph, block_num, params_dict):
-    for next_op in graph.next_ops(op):
+def _append_transformer_prune_params(op_lists, graph, block_num, params_dict):
+    first_op = op_lists[0]
+    for next_op in graph.next_ops(first_op):
         if next_op.type() == 'elementwise_add':
             continue
         next_op = _find_gemm_op(next_op, graph)
-        if next_op.type() in ['mul', 'matmul', 'matmul_v2'
-                              ] and has_trainable_var(next_op):
+        if next_op.type() in [
+                'mul', 'matmul', 'matmul_v2'
+        ] and has_trainable_var(next_op) and next_op in op_lists:
             if block_num not in params_dict:
                 params_dict[block_num] = {}
                 params_dict[block_num]['P1'] = [get_weight(next_op)]
@@ -41,7 +43,7 @@ def _append_transformer_prune_params(op, graph, block_num, params_dict):
                 get_weight(has_bias(next_op, graph)))
             op = next_op
     next_op = _find_gemm_op(find_weight_op(op, graph), graph)
-    if next_op:
+    if next_op and next_op in op_lists:
         params_dict[block_num]['P2'] = [get_weight(next_op)]
         params_dict[block_num]['P2'].append(
             get_weight(has_bias(next_op, graph)))
@@ -57,14 +59,14 @@ def preprocess_transformer_patterns(patterns, graph):
             continue
         block_num = int(pattern_name.split('$')[-1])
         if 'MHA' in pattern_name:
-            mha_weight = _append_transformer_prune_params(pattern_ops[0], graph,
-                                                          block_num, mha_weight)
+            mha_weight = _append_transformer_prune_params(
+                pattern_ops, graph, block_num, mha_weight)
             mha_weight[block_num]['reshape_op'] = []
             for op in pattern_ops:
                 if op.type() in ['reshape', 'reshape2']:
                     mha_weight[block_num]['reshape_op'].append(op)
         elif 'FFN' in pattern_name:
-            ffn_weight = _append_transformer_prune_params(pattern_ops[0], graph,
-                                                          block_num, ffn_weight)
+            ffn_weight = _append_transformer_prune_params(
+                pattern_ops, graph, block_num, ffn_weight)
 
     return mha_weight, ffn_weight
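Note on the second hunk: _append_transformer_prune_params now receives the full list of ops matched for a pattern and records a GEMM op as prunable only if it is contained in that list (next_op in op_lists), so the walk from the pattern's first op can no longer pick up weights belonging to ops outside the matched MHA/FFN block. The sketch below illustrates only that containment guard with toy stand-ins; ToyOp, ToyGraph, and collect_pattern_gemms are hypothetical and merely mimic the shape of the real graph wrapper API.

    class ToyOp:
        """Illustrative stand-in for a graph op wrapper."""
        def __init__(self, name, op_type):
            self.name, self._type = name, op_type
        def type(self):
            return self._type

    class ToyGraph:
        """Illustrative stand-in: maps each op to its downstream ops."""
        def __init__(self, edges):
            self._edges = edges
        def next_ops(self, op):
            return self._edges.get(op, [])

    def collect_pattern_gemms(op_lists, graph):
        # Walk successors of the pattern's first op, but keep a GEMM op only
        # when it belongs to the matched pattern itself (the `in op_lists`
        # guard added by this commit).
        first_op = op_lists[0]
        kept = []
        for next_op in graph.next_ops(first_op):
            if next_op.type() in ('mul', 'matmul', 'matmul_v2') and next_op in op_lists:
                kept.append(next_op)
        return kept

    ln = ToyOp('layer_norm_in', 'layer_norm')
    q = ToyOp('q_proj', 'matmul_v2')
    outside = ToyOp('classifier_head', 'matmul_v2')   # a GEMM outside the pattern
    graph = ToyGraph({ln: [q, outside]})
    pattern_ops = [ln, q]                             # ops matched as one block
    print([op.name for op in collect_pattern_gemms(pattern_ops, graph)])  # ['q_proj']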