提交 775d6e6e 编写于 作者: S SunAhong1993

rename delete_layer

上级 934ee6a8
...@@ -252,7 +252,7 @@ class DygraphTransposeElimination(FuseBase): ...@@ -252,7 +252,7 @@ class DygraphTransposeElimination(FuseBase):
continue continue
for l in transpose_layers: for l in transpose_layers:
self.delete_layer_with_associated(_graph, l) _graph.delete_layer(l)
optimized_transpose_layers.extend(transpose_layers) optimized_transpose_layers.extend(transpose_layers)
optimized_reduce_layers.extend(reduce_layers) optimized_reduce_layers.extend(reduce_layers)
......
...@@ -268,7 +268,7 @@ class FuseBase(object): ...@@ -268,7 +268,7 @@ class FuseBase(object):
first_layer_id = list(match.keys())[0] first_layer_id = list(match.keys())[0]
subgraph = get_subgraph("", first_layer_id, graph) subgraph = get_subgraph("", first_layer_id, graph)
self.insert_new_layer(subgraph, parameters, match) self.insert_new_layer(subgraph, parameters, match)
self.delete_layer(graph) self.delete_match(graph)
graph.build() graph.build()
def perform_pattern_matcher(self, graph, match_kind="topo"): def perform_pattern_matcher(self, graph, match_kind="topo"):
...@@ -283,7 +283,7 @@ class FuseBase(object): ...@@ -283,7 +283,7 @@ class FuseBase(object):
pattern_matcher = PatternMatcher(self.pattern) pattern_matcher = PatternMatcher(self.pattern)
self.matches = pattern_matcher.operate(graph, match_kind) self.matches = pattern_matcher.operate(graph, match_kind)
def delete_layer(self, graph): def delete_match(self, graph):
""" 删除不需要的中间layer及其对应参数。 """ 删除不需要的中间layer及其对应参数。
""" """
for match in self.matches: for match in self.matches:
...@@ -298,52 +298,4 @@ class FuseBase(object): ...@@ -298,52 +298,4 @@ class FuseBase(object):
if layer_id in subgraph.layers: if layer_id in subgraph.layers:
# layer_id可能是属于子图的,此时删除父layer,即删除整个子图 # layer_id可能是属于子图的,此时删除父layer,即删除整个子图
subgraph.layers.pop(layer_id) subgraph.layers.pop(layer_id)
def delete_layer_with_associated(self, graph, layer_id): \ No newline at end of file
""" 删除不需要的中间layer及其相关连接点。
"""
layer = graph.layers[layer_id]
outputs = graph.edges_out.get(layer_id, [])
inputs = graph.edges_in.get(layer_id, [])
assert len(
inputs) <= 1, "There should be 0 or 1 input for deleted layer."
if len(inputs) == 0:
for out in outputs:
while layer_id in graph.edges_in[out]:
index = graph.edges_in[out].index(layer_id)
del graph.edges_in[out][index]
input_keys = list(graph.layers[out].inputs.keys())
for k in input_keys:
if graph.layers[out].inputs[k] == layer.outputs[0]:
del graph.layers[out].inputs[k]
del graph.layers[layer_id]
if layer_id in graph.edges_in:
del graph.edges_in[layer_id]
if layer_id in graph.edges_out:
del graph.edges_out[layer_id]
return
# 将所有输出layer的输入layer进行替换
for out in outputs:
for i in range(len(graph.edges_in[out])):
if graph.edges_in[out][i] == layer_id:
graph.edges_in[out][i] = inputs[0]
# 将输出layer赋给输入layer的输出
replace_index = graph.edges_out[inputs[0]].index(layer_id)
del graph.edges_out[inputs[0]][replace_index]
for i, out in enumerate(outputs):
graph.edges_out[inputs[0]].insert(replace_index + i, out)
for k, v in graph.layers[out].inputs.items():
if v == layer.outputs[0]:
graph.layers[out].inputs[k] = list(layer.inputs.values())[0]
del graph.layers[layer_id]
if layer_id in graph.edges_out:
del graph.edges_out[layer_id]
if layer_id in graph.edges_in:
del graph.edges_in[layer_id]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册