Unverified commit 03331331 authored by ceci3, committed by GitHub

fix act recover params and pattern recognition (#1695)

* fix

* update
Parent b735a396
@@ -71,6 +71,8 @@ def _recover_param_attr(program, startup_program):
         if param.persistable is True and param.name != 'feed' and param.name != 'fetch']
     with paddle.static.program_guard(program, startup_program):
         for w in all_weights:
+            if w.dtype not in [paddle.float32]:
+                continue
             new_w = paddle.create_parameter(
                 shape=w.shape, dtype=w.dtype, name=w.name)
             new_w.set_value(w.get_value())
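Context for the added guard: ACT's parameter recovery re-creates every persistable weight of the loaded program as a trainable parameter, and weights that earlier passes have already cast to non-float32 dtypes (for example int8 quantized weights) cannot be rebuilt that way, so they are now skipped. A minimal, standalone sketch of that dtype filter follows (illustrative only; the toy fc program and the variable names are assumptions, not part of this commit):

    import paddle

    paddle.enable_static()

    # Toy static program standing in for a loaded inference model.
    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
        out = paddle.static.nn.fc(x, size=2)

    # Same selection logic as the patched _recover_param_attr: only persistable
    # float32 weights are kept for recovery; tensors already converted to other
    # dtypes (e.g. quantized int8 weights) are skipped.
    recoverable = [
        v for v in main_prog.list_vars()
        if v.persistable and v.name not in ('feed', 'fetch')
        and v.dtype in [paddle.float32]
    ]
    print([w.name for w in recoverable])  # e.g. ['fc_0.w_0', 'fc_0.b_0']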
@@ -25,13 +25,15 @@ def _find_gemm_op(op, graph):
     return op
 
 
-def _append_transformer_prune_params(op, graph, block_num, params_dict):
-    for next_op in graph.next_ops(op):
+def _append_transformer_prune_params(op_lists, graph, block_num, params_dict):
+    first_op = op_lists[0]
+    for next_op in graph.next_ops(first_op):
         if next_op.type() == 'elementwise_add':
             continue
         next_op = _find_gemm_op(next_op, graph)
-        if next_op.type() in ['mul', 'matmul', 'matmul_v2'
-                              ] and has_trainable_var(next_op):
+        if next_op.type() in [
+                'mul', 'matmul', 'matmul_v2'
+        ] and has_trainable_var(next_op) and next_op in op_lists:
             if block_num not in params_dict:
                 params_dict[block_num] = {}
             params_dict[block_num]['P1'] = [get_weight(next_op)]
@@ -41,7 +43,7 @@ def _append_transformer_prune_params(op, graph, block_num, params_dict):
                 get_weight(has_bias(next_op, graph)))
             op = next_op
             next_op = _find_gemm_op(find_weight_op(op, graph), graph)
-            if next_op:
+            if next_op and next_op in op_lists:
                 params_dict[block_num]['P2'] = [get_weight(next_op)]
                 params_dict[block_num]['P2'].append(
                     get_weight(has_bias(next_op, graph)))
@@ -57,14 +59,14 @@ def preprocess_transformer_patterns(patterns, graph):
             continue
         block_num = int(pattern_name.split('$')[-1])
         if 'MHA' in pattern_name:
-            mha_weight = _append_transformer_prune_params(pattern_ops[0], graph,
-                                                          block_num, mha_weight)
+            mha_weight = _append_transformer_prune_params(
+                pattern_ops, graph, block_num, mha_weight)
             mha_weight[block_num]['reshape_op'] = []
             for op in pattern_ops:
                 if op.type() in ['reshape', 'reshape2']:
                     mha_weight[block_num]['reshape_op'].append(op)
         elif 'FFN' in pattern_name:
-            ffn_weight = _append_transformer_prune_params(pattern_ops[0], graph,
-                                                          block_num, ffn_weight)
+            ffn_weight = _append_transformer_prune_params(
+                pattern_ops, graph, block_num, ffn_weight)
 
     return mha_weight, ffn_weight
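The three hunks above implement one fix: _append_transformer_prune_params now receives the full list of ops that the pattern matcher assigned to an MHA/FFN block, and a matmul found while walking the graph only contributes weights when it is also a member of that list, so a projection from a neighbouring block can no longer leak into the wrong params_dict entry. A minimal sketch of that membership guard follows, using toy Op/Graph classes (illustrative only; these classes and collect_pattern_weights are not PaddleSlim APIs):

    class Op:
        def __init__(self, name, type_, weight=None):
            self.name, self._type, self.weight = name, type_, weight

        def type(self):
            return self._type


    class Graph:
        def __init__(self, edges):
            self._edges = edges  # maps an op to its successor ops

        def next_ops(self, op):
            return self._edges.get(op, [])


    def collect_pattern_weights(op_lists, graph, block_num, params_dict):
        first_op = op_lists[0]
        for next_op in graph.next_ops(first_op):
            # The guard added by this commit: skip matmuls that the pattern
            # matcher did not assign to this block.
            if next_op.type() in ['mul', 'matmul', 'matmul_v2'] and next_op in op_lists:
                params_dict.setdefault(block_num, {})['P1'] = [next_op.weight]
        return params_dict


    ln = Op('layer_norm_0', 'layer_norm')
    q_proj = Op('q_proj', 'matmul_v2', weight='block0.q.w_0')
    stray = Op('next_block.q_proj', 'matmul_v2', weight='block1.q.w_0')
    graph = Graph({ln: [q_proj, stray]})

    print(collect_pattern_weights([ln, q_proj], graph, 0, {}))
    # {0: {'P1': ['block0.q.w_0']}} -- 'stray' is outside the pattern and ignored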