diff --git a/paddleslim/common/recover_program.py b/paddleslim/common/recover_program.py index efda1b47e08dbbd18d094cee52d2a3effb0259c7..46224a58d80006ad379dbd4c252a3de6194e1860 100644 --- a/paddleslim/common/recover_program.py +++ b/paddleslim/common/recover_program.py @@ -71,6 +71,8 @@ def _recover_param_attr(program, startup_program): if param.persistable is True and param.name != 'feed' and param.name != 'fetch'] with paddle.static.program_guard(program, startup_program): for w in all_weights: + if w.dtype not in [paddle.float32]: + continue new_w = paddle.create_parameter( shape=w.shape, dtype=w.dtype, name=w.name) new_w.set_value(w.get_value()) diff --git a/paddleslim/common/transformer_pattern.py b/paddleslim/common/transformer_pattern.py index 1ba9c977521a93fb00450c024d4c5dd1321b1e2d..c6b08d78a31e4220c69aee7a671a39907cb410d9 100644 --- a/paddleslim/common/transformer_pattern.py +++ b/paddleslim/common/transformer_pattern.py @@ -25,13 +25,15 @@ def _find_gemm_op(op, graph): return op -def _append_transformer_prune_params(op, graph, block_num, params_dict): - for next_op in graph.next_ops(op): +def _append_transformer_prune_params(op_lists, graph, block_num, params_dict): + first_op = op_lists[0] + for next_op in graph.next_ops(first_op): if next_op.type() == 'elementwise_add': continue next_op = _find_gemm_op(next_op, graph) - if next_op.type() in ['mul', 'matmul', 'matmul_v2' - ] and has_trainable_var(next_op): + if next_op.type() in [ + 'mul', 'matmul', 'matmul_v2' + ] and has_trainable_var(next_op) and next_op in op_lists: if block_num not in params_dict: params_dict[block_num] = {} params_dict[block_num]['P1'] = [get_weight(next_op)] @@ -41,7 +43,7 @@ def _append_transformer_prune_params(op, graph, block_num, params_dict): get_weight(has_bias(next_op, graph))) op = next_op next_op = _find_gemm_op(find_weight_op(op, graph), graph) - if next_op: + if next_op and next_op in op_lists: params_dict[block_num]['P2'] = [get_weight(next_op)] params_dict[block_num]['P2'].append( get_weight(has_bias(next_op, graph))) @@ -57,14 +59,14 @@ def preprocess_transformer_patterns(patterns, graph): continue block_num = int(pattern_name.split('$')[-1]) if 'MHA' in pattern_name: - mha_weight = _append_transformer_prune_params(pattern_ops[0], graph, - block_num, mha_weight) + mha_weight = _append_transformer_prune_params( + pattern_ops, graph, block_num, mha_weight) mha_weight[block_num]['reshape_op'] = [] for op in pattern_ops: if op.type() in ['reshape', 'reshape2']: mha_weight[block_num]['reshape_op'].append(op) elif 'FFN' in pattern_name: - ffn_weight = _append_transformer_prune_params(pattern_ops[0], graph, - block_num, ffn_weight) + ffn_weight = _append_transformer_prune_params( + pattern_ops, graph, block_num, ffn_weight) return mha_weight, ffn_weight