提交 ad89666a 编写于 作者: M mamingjie-China

modify

上级 74ed7822
...@@ -482,6 +482,69 @@ class TFOptimizer(object): ...@@ -482,6 +482,69 @@ class TFOptimizer(object):
0].inputs 0].inputs
del out_node.fluid_code.layers[0] del out_node.fluid_code.layers[0]
graph_copy = cp.deepcopy(self.graph)
for node_name in self.graph.topo_sort:
node = graph_copy.get_node(node_name)
if node is None:
continue
if node.layer_type in elementwise_ops:
can_be_removed = True
if len(node.fluid_code.layers) < 3:
continue
numTranspose = 0
numNotTranspose = 0
for i in range(len(node.fluid_code.layers)):
if node.fluid_code.layers[i].op == 'transpose':
numTranspose += 1
elif node.fluid_code.layers[i].op != 'expand':
numNotTranspose += 1
if numTranspose > numNotTranspose:
if node.fluid_code.layers[0].op == 'expand':
if node.fluid_code.layers[
1].op != 'transpose' or node.fluid_code.layers[
2].op != 'transpose':
continue
else:
true_node = self.graph.get_node(node_name)
true_node.fluid_code.layers[3].inputs[
'x'] = true_node.fluid_code.layers[1].inputs
true_node.fluid_code.layers[3].inputs[
'y'] = true_node.fluid_code.layers[2].inputs
l = Layer()
l.op = 'transpose'
l.inputs = true_node.fluid_code.layers[3].output
l.param_attr = {'perm': [0, 3, 1, 2]}
if type(l.inputs) == str:
l.output = l.inputs
else:
l.output = l.inputs.layer_name
true_node.fluid_code.layers.append(l)
del true_node.fluid_code.layers[1]
del true_node.fluid_code.layers[1]
else:
if node.fluid_code.layers[
0].op != 'transpose' or node.fluid_code.layers[
1].op != 'transpose':
continue
else:
true_node = self.graph.get_node(node_name)
true_node.fluid_code.layers[2].inputs[
'x'] = true_node.fluid_code.layers[0].inputs
true_node.fluid_code.layers[2].inputs[
'y'] = true_node.fluid_code.layers[1].inputs
l = Layer()
l.op = 'transpose'
l.inputs = true_node.fluid_code.layers[2].output
l.param_attr = {'perm': [0, 3, 1, 2]}
l.output = l.inputs.layer_name
true_node.fluid_code.layers.append(l)
del true_node.fluid_code.layers[0]
del true_node.fluid_code.layers[0]
def make_nchw_input_output(self): def make_nchw_input_output(self):
for i, name in enumerate(self.graph.input_nodes): for i, name in enumerate(self.graph.input_nodes):
node = self.graph.get_node(name) node = self.graph.get_node(name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册