提交 606f9dca 编写于 作者: W wjj19950828

rm params

上级 952726a4
......@@ -113,11 +113,13 @@ class TraceFcFuser(FuseBase):
attrs["out_features"] = parameters[weight_name].shape[0]
linear_name = "linear{}".format(self.linear_index)
self.linear_index += 1
weight_numpy = parameters.pop(weight_name)
weight_numpy = parameters[weight_name]
parameters["{}.weight".format(linear_name)] = weight_numpy.transpose(
(1, 0))
bias_numpy = parameters.pop(bias_name)
self.rm_params.add(weight_name)
bias_numpy = parameters[bias_name]
parameters["{}.bias".format(linear_name)] = np.squeeze(bias_numpy)
self.rm_params.add(bias_name)
new_layer = PaddleLayer(
layers_id[0],
"paddle.nn.Linear",
......
......@@ -325,6 +325,7 @@ class FuseBase(object):
def __init__(self):
self.pattern = PaddleGraph()
self.patterns = list()
self.rm_params = set()
def operate(self, graph, match_kind="topo"):
parameters = graph.parameters
......@@ -335,6 +336,8 @@ class FuseBase(object):
subgraph = get_subgraph("", first_layer_id, graph)
self.insert_new_layer(subgraph, parameters, match)
self.delete_match(graph)
for param_name in self.rm_params:
parameters.pop(param_name)
graph.build()
def perform_pattern_matcher(self, graph, match_kind="topo"):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册