diff --git a/x2paddle/optimizer/tf_optimizer.py b/x2paddle/optimizer/tf_optimizer.py index 848409cabb898d75c0b2d07c08e758caefcd4899..08f7b08a671301a319a2536200f120009102a59d 100644 --- a/x2paddle/optimizer/tf_optimizer.py +++ b/x2paddle/optimizer/tf_optimizer.py @@ -482,6 +482,69 @@ class TFOptimizer(object): 0].inputs del out_node.fluid_code.layers[0] + graph_copy = cp.deepcopy(self.graph) + for node_name in self.graph.topo_sort: + node = graph_copy.get_node(node_name) + if node is None: + continue + if node.layer_type in elementwise_ops: + can_be_removed = True + if len(node.fluid_code.layers) < 3: + continue + + numTranspose = 0 + numNotTranspose = 0 + + for i in range(len(node.fluid_code.layers)): + if node.fluid_code.layers[i].op == 'transpose': + numTranspose += 1 + elif node.fluid_code.layers[i].op != 'expand': + numNotTranspose += 1 + if numTranspose > numNotTranspose: + if node.fluid_code.layers[0].op == 'expand': + if node.fluid_code.layers[ + 1].op != 'transpose' or node.fluid_code.layers[ + 2].op != 'transpose': + continue + else: + true_node = self.graph.get_node(node_name) + true_node.fluid_code.layers[3].inputs[ + 'x'] = true_node.fluid_code.layers[1].inputs + true_node.fluid_code.layers[3].inputs[ + 'y'] = true_node.fluid_code.layers[2].inputs + + l = Layer() + l.op = 'transpose' + l.inputs = true_node.fluid_code.layers[3].output + l.param_attr = {'perm': [0, 3, 1, 2]} + if type(l.inputs) == str: + l.output = l.inputs + else: + l.output = l.inputs.layer_name + true_node.fluid_code.layers.append(l) + del true_node.fluid_code.layers[1] + del true_node.fluid_code.layers[1] + else: + if node.fluid_code.layers[ + 0].op != 'transpose' or node.fluid_code.layers[ + 1].op != 'transpose': + continue + else: + true_node = self.graph.get_node(node_name) + true_node.fluid_code.layers[2].inputs[ + 'x'] = true_node.fluid_code.layers[0].inputs + true_node.fluid_code.layers[2].inputs[ + 'y'] = true_node.fluid_code.layers[1].inputs + + l = Layer() + l.op = 'transpose' + l.inputs = true_node.fluid_code.layers[2].output + l.param_attr = {'perm': [0, 3, 1, 2]} + l.output = l.inputs.layer_name + true_node.fluid_code.layers.append(l) + del true_node.fluid_code.layers[0] + del true_node.fluid_code.layers[0] + def make_nchw_input_output(self): for i, name in enumerate(self.graph.input_nodes): node = self.graph.get_node(name)