diff --git a/x2paddle/decoder/tf_decoder.py b/x2paddle/decoder/tf_decoder.py index d5d862ad602e5f93f1f7923ca8dd5f28573d7049..69a0fc1807bc3bc51e2f17dc36e73d044397d254 100644 --- a/x2paddle/decoder/tf_decoder.py +++ b/x2paddle/decoder/tf_decoder.py @@ -101,6 +101,7 @@ class TFGraphNode(GraphNode): @property def name(self): if hasattr(self, 'index'): + print(self.layer_type) return self.layer_name + "_p{}".format(self.index) return self.layer_name @@ -184,7 +185,7 @@ class TFGraph(Graph): node = super(TFGraph, self).get_node(new_node_name, copy) if node is None: return None - if node.layer_type == "Switch": + if node.layer_type in ["Switch", "Reshape", "Sub"]: if hasattr(node, 'index'): del node.index if len(items) == 1 and node.layer_type in self.multi_out_ops: @@ -284,6 +285,11 @@ class TFGraph(Graph): if node_name in self.output_nodes: idx = self.output_nodes.index(node_name) self.output_nodes[idx] = input_node.layer_name + if len(input_node.outputs) > 0: + self.output_nodes.pop(idx) + else: + self.output_nodes[idx] = input_node.layer_name + def _remove_cast_node(self): cast_node = list() diff --git a/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py b/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py index e860366b919e6c4e526bd83dfa7f7d85a38493cf..fad736595fbdb91aed5743476f661034660d0da1 100644 --- a/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py +++ b/x2paddle/op_mapper/dygraph/tf2paddle/tf_op_mapper.py @@ -248,8 +248,10 @@ class TFOpMapper(OpMapper): def Transpose(self, node): input = self.graph.get_input_node(node, 0) perm = self.graph.get_input_node(node, 1) - assert perm.layer_type == "Const", "Perm of transpose OP should be Const" - perm = perm.value.tolist() + if perm.layer_type == "Const": + perm = perm.value.tolist() + else: + perm = self.decoder.infer_tensor(perm, use_diff_inputs=False).tolist() self.paddle_graph.add_layer( "paddle.transpose", diff --git a/x2paddle/op_mapper/static/tf2paddle/tf_op_mapper.py b/x2paddle/op_mapper/static/tf2paddle/tf_op_mapper.py index 5370519671e1af32fae95569bb3c61e91e4e9222..527d1cca24dbb5c7b20ce2f1f1255f91f2044b62 100644 --- a/x2paddle/op_mapper/static/tf2paddle/tf_op_mapper.py +++ b/x2paddle/op_mapper/static/tf2paddle/tf_op_mapper.py @@ -238,8 +238,10 @@ class TFOpMapper(OpMapper): def Transpose(self, node): input = self.graph.get_node(node.layer.input[0]) perm = self.graph.get_node(node.layer.input[1]) - assert perm.layer_type == "Const", "Perm of transpose OP should be Const" - perm = perm.value.tolist() + if perm.layer_type == "Const": + perm = perm.value.tolist() + else: + perm = self.decoder.infer_tensor(perm, use_diff_inputs=False).tolist() self.paddle_graph.add_layer( kernel="paddle.transpose",