提交 c48cf6d8 编写于 作者: S SunAhong1993

fix the tensorflow

上级 6ce10dc4
......@@ -130,7 +130,7 @@ class TFGraph(Graph):
def __init__(self, model, data_format="NHWC"):
super(TFGraph, self).__init__(model)
self.identity_map = dict()
self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2']
self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2', 'Unpack']
self.tf_data_format = data_format
self.graph_name = "TFModel"
......@@ -172,7 +172,8 @@ class TFGraph(Graph):
self._remove_isolated_node()
self._optimize_dialiation_conv()
self._remove_identity_node()
self._remove_cast_node()
# self._remove_cast_node()
def get_node(self, node_name, copy=False):
items = node_name.strip().split(':')
......@@ -192,6 +193,8 @@ class TFGraph(Graph):
def get_input_node(self, node, idx=0, copy=False):
input_node_name = node.layer.input[idx]
if idx > 0:
copy = True
return self.get_node(input_node_name, copy)
def remove_node(self, node_name):
......@@ -402,7 +405,7 @@ class TFDecoder(object):
right_shape_been_input = False
while not right_shape_been_input:
try:
shape = input(
shape = raw_input(
"Shape of Input(e.g. None,224,224,3): ")
except:
shape = input("Shape of Input(e.g. None,224,224,3): ")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册