From 8c13edc16b4bb918cda6b6cf85ba63b21ef54911 Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Fri, 9 Apr 2021 14:09:37 +0800 Subject: [PATCH] fix the tf --- x2paddle/decoder/tf_decoder.py | 8 +++++++- x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py | 6 ++++-- x2paddle/op_mapper/static/tf2paddle/tf_op_mapper.py | 6 ++++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/x2paddle/decoder/tf_decoder.py b/x2paddle/decoder/tf_decoder.py index d5d862a..69a0fc1 100644 --- a/x2paddle/decoder/tf_decoder.py +++ b/x2paddle/decoder/tf_decoder.py @@ -101,6 +101,7 @@ class TFGraphNode(GraphNode): @property def name(self): if hasattr(self, 'index'): + print(self.layer_type) return self.layer_name + "_p{}".format(self.index) return self.layer_name @@ -184,7 +185,7 @@ class TFGraph(Graph): node = super(TFGraph, self).get_node(new_node_name, copy) if node is None: return None - if node.layer_type == "Switch": + if node.layer_type in ["Switch", "Reshape", "Sub"]: if hasattr(node, 'index'): del node.index if len(items) == 1 and node.layer_type in self.multi_out_ops: @@ -284,6 +285,11 @@ class TFGraph(Graph): if node_name in self.output_nodes: idx = self.output_nodes.index(node_name) self.output_nodes[idx] = input_node.layer_name + if len(input_node.outputs) > 0: + self.output_nodes.pop(idx) + else: + self.output_nodes[idx] = input_node.layer_name + def _remove_cast_node(self): cast_node = list() diff --git a/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py b/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py index e860366..fad7365 100644 --- a/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py +++ b/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py @@ -248,8 +248,10 @@ class TFOpMapper(OpMapper): def Transpose(self, node): input = self.graph.get_input_node(node, 0) perm = self.graph.get_input_node(node, 1) - assert perm.layer_type == "Const", "Perm of transpose OP should be Const" - perm = perm.value.tolist() + if perm.layer_type == "Const": + perm = perm.value.tolist() + else: + perm = self.decoder.infer_tensor(perm, use_diff_inputs=False).tolist() self.paddle_graph.add_layer( "paddle.transpose", diff --git a/x2paddle/op_mapper/static/tf2paddle/tf_op_mapper.py b/x2paddle/op_mapper/static/tf2paddle/tf_op_mapper.py index 5370519..527d1cc 100644 --- a/x2paddle/op_mapper/static/tf2paddle/tf_op_mapper.py +++ b/x2paddle/op_mapper/static/tf2paddle/tf_op_mapper.py @@ -238,8 +238,10 @@ class TFOpMapper(OpMapper): def Transpose(self, node): input = self.graph.get_node(node.layer.input[0]) perm = self.graph.get_node(node.layer.input[1]) - assert perm.layer_type == "Const", "Perm of transpose OP should be Const" - perm = perm.value.tolist() + if perm.layer_type == "Const": + perm = perm.value.tolist() + else: + perm = self.decoder.infer_tensor(perm, use_diff_inputs=False).tolist() self.paddle_graph.add_layer( kernel="paddle.transpose", -- GitLab