提交 8c13edc1 编写于 作者: S SunAhong1993

fix the tf

上级 4f892dd9
......@@ -101,6 +101,7 @@ class TFGraphNode(GraphNode):
@property
def name(self):
if hasattr(self, 'index'):
print(self.layer_type)
return self.layer_name + "_p{}".format(self.index)
return self.layer_name
......@@ -184,7 +185,7 @@ class TFGraph(Graph):
node = super(TFGraph, self).get_node(new_node_name, copy)
if node is None:
return None
if node.layer_type == "Switch":
if node.layer_type in ["Switch", "Reshape", "Sub"]:
if hasattr(node, 'index'):
del node.index
if len(items) == 1 and node.layer_type in self.multi_out_ops:
......@@ -284,6 +285,11 @@ class TFGraph(Graph):
if node_name in self.output_nodes:
idx = self.output_nodes.index(node_name)
self.output_nodes[idx] = input_node.layer_name
if len(input_node.outputs) > 0:
self.output_nodes.pop(idx)
else:
self.output_nodes[idx] = input_node.layer_name
def _remove_cast_node(self):
cast_node = list()
......
......@@ -248,8 +248,10 @@ class TFOpMapper(OpMapper):
def Transpose(self, node):
input = self.graph.get_input_node(node, 0)
perm = self.graph.get_input_node(node, 1)
assert perm.layer_type == "Const", "Perm of transpose OP should be Const"
perm = perm.value.tolist()
if perm.layer_type == "Const":
perm = perm.value.tolist()
else:
perm = self.decoder.infer_tensor(perm, use_diff_inputs=False).tolist()
self.paddle_graph.add_layer(
"paddle.transpose",
......
......@@ -238,8 +238,10 @@ class TFOpMapper(OpMapper):
def Transpose(self, node):
input = self.graph.get_node(node.layer.input[0])
perm = self.graph.get_node(node.layer.input[1])
assert perm.layer_type == "Const", "Perm of transpose OP should be Const"
perm = perm.value.tolist()
if perm.layer_type == "Const":
perm = perm.value.tolist()
else:
perm = self.decoder.infer_tensor(perm, use_diff_inputs=False).tolist()
self.paddle_graph.add_layer(
kernel="paddle.transpose",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册