From 775d6e6e9e7815100d2964241cc9785522bc95b9 Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Fri, 13 Nov 2020 18:07:58 +0800 Subject: [PATCH] rename delete_layer --- .../dygraph/transpose_elimination.py | 2 +- x2paddle/optimizer/pattern_matcher.py | 54 ++----------------- 2 files changed, 4 insertions(+), 52 deletions(-) diff --git a/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py b/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py index 8e5b61c..b2a8a67 100644 --- a/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py +++ b/x2paddle/optimizer/elimination/dygraph/transpose_elimination.py @@ -252,7 +252,7 @@ class DygraphTransposeElimination(FuseBase): continue for l in transpose_layers: - self.delete_layer_with_associated(_graph, l) + _graph.delete_layer(l) optimized_transpose_layers.extend(transpose_layers) optimized_reduce_layers.extend(reduce_layers) diff --git a/x2paddle/optimizer/pattern_matcher.py b/x2paddle/optimizer/pattern_matcher.py index 0427c40..826c900 100644 --- a/x2paddle/optimizer/pattern_matcher.py +++ b/x2paddle/optimizer/pattern_matcher.py @@ -268,7 +268,7 @@ class FuseBase(object): first_layer_id = list(match.keys())[0] subgraph = get_subgraph("", first_layer_id, graph) self.insert_new_layer(subgraph, parameters, match) - self.delete_layer(graph) + self.delete_match(graph) graph.build() def perform_pattern_matcher(self, graph, match_kind="topo"): @@ -283,7 +283,7 @@ class FuseBase(object): pattern_matcher = PatternMatcher(self.pattern) self.matches = pattern_matcher.operate(graph, match_kind) - def delete_layer(self, graph): + def delete_match(self, graph): """ 删除不需要的中间layer及其对应参数。 """ for match in self.matches: @@ -298,52 +298,4 @@ class FuseBase(object): if layer_id in subgraph.layers: # layer_id可能是属于子图的,此时删除父layer,即删除整个子图 subgraph.layers.pop(layer_id) - - def delete_layer_with_associated(self, graph, layer_id): - """ 删除不需要的中间layer及其相关连接点。 - """ - layer = graph.layers[layer_id] - outputs = graph.edges_out.get(layer_id, []) - inputs = graph.edges_in.get(layer_id, []) - - assert len( - inputs) <= 1, "There should be 0 or 1 input for deleted layer." - - if len(inputs) == 0: - for out in outputs: - while layer_id in graph.edges_in[out]: - index = graph.edges_in[out].index(layer_id) - del graph.edges_in[out][index] - - input_keys = list(graph.layers[out].inputs.keys()) - for k in input_keys: - if graph.layers[out].inputs[k] == layer.outputs[0]: - del graph.layers[out].inputs[k] - - del graph.layers[layer_id] - if layer_id in graph.edges_in: - del graph.edges_in[layer_id] - if layer_id in graph.edges_out: - del graph.edges_out[layer_id] - return - - # 将所有输出layer的输入layer进行替换 - for out in outputs: - for i in range(len(graph.edges_in[out])): - if graph.edges_in[out][i] == layer_id: - graph.edges_in[out][i] = inputs[0] - - # 将输出layer赋给输入layer的输出 - replace_index = graph.edges_out[inputs[0]].index(layer_id) - del graph.edges_out[inputs[0]][replace_index] - for i, out in enumerate(outputs): - graph.edges_out[inputs[0]].insert(replace_index + i, out) - for k, v in graph.layers[out].inputs.items(): - if v == layer.outputs[0]: - graph.layers[out].inputs[k] = list(layer.inputs.values())[0] - - del graph.layers[layer_id] - if layer_id in graph.edges_out: - del graph.edges_out[layer_id] - if layer_id in graph.edges_in: - del graph.edges_in[layer_id] + \ No newline at end of file -- GitLab