From ad89666ae2965b2d2d9a371333b7bc2a573b14e3 Mon Sep 17 00:00:00 2001 From: mamingjie-China Date: Fri, 11 Oct 2019 15:28:18 +0800 Subject: [PATCH] modify --- x2paddle/optimizer/tf_optimizer.py | 63 ++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/x2paddle/optimizer/tf_optimizer.py b/x2paddle/optimizer/tf_optimizer.py index 848409c..08f7b08 100644 --- a/x2paddle/optimizer/tf_optimizer.py +++ b/x2paddle/optimizer/tf_optimizer.py @@ -482,6 +482,69 @@ class TFOptimizer(object): 0].inputs del out_node.fluid_code.layers[0] + graph_copy = cp.deepcopy(self.graph) + for node_name in self.graph.topo_sort: + node = graph_copy.get_node(node_name) + if node is None: + continue + if node.layer_type in elementwise_ops: + can_be_removed = True + if len(node.fluid_code.layers) < 3: + continue + + numTranspose = 0 + numNotTranspose = 0 + + for i in range(len(node.fluid_code.layers)): + if node.fluid_code.layers[i].op == 'transpose': + numTranspose += 1 + elif node.fluid_code.layers[i].op != 'expand': + numNotTranspose += 1 + if numTranspose > numNotTranspose: + if node.fluid_code.layers[0].op == 'expand': + if node.fluid_code.layers[ + 1].op != 'transpose' or node.fluid_code.layers[ + 2].op != 'transpose': + continue + else: + true_node = self.graph.get_node(node_name) + true_node.fluid_code.layers[3].inputs[ + 'x'] = true_node.fluid_code.layers[1].inputs + true_node.fluid_code.layers[3].inputs[ + 'y'] = true_node.fluid_code.layers[2].inputs + + l = Layer() + l.op = 'transpose' + l.inputs = true_node.fluid_code.layers[3].output + l.param_attr = {'perm': [0, 3, 1, 2]} + if type(l.inputs) == str: + l.output = l.inputs + else: + l.output = l.inputs.layer_name + true_node.fluid_code.layers.append(l) + del true_node.fluid_code.layers[1] + del true_node.fluid_code.layers[1] + else: + if node.fluid_code.layers[ + 0].op != 'transpose' or node.fluid_code.layers[ + 1].op != 'transpose': + continue + else: + true_node = self.graph.get_node(node_name) + true_node.fluid_code.layers[2].inputs[ + 'x'] = true_node.fluid_code.layers[0].inputs + true_node.fluid_code.layers[2].inputs[ + 'y'] = true_node.fluid_code.layers[1].inputs + + l = Layer() + l.op = 'transpose' + l.inputs = true_node.fluid_code.layers[2].output + l.param_attr = {'perm': [0, 3, 1, 2]} + l.output = l.inputs.layer_name + true_node.fluid_code.layers.append(l) + del true_node.fluid_code.layers[0] + del true_node.fluid_code.layers[0] + def make_nchw_input_output(self): for i, name in enumerate(self.graph.input_nodes): node = self.graph.get_node(name) -- GitLab