From c48cf6d8cb0fb52364f61e659c86672e388f78ed Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Thu, 17 Dec 2020 16:07:34 +0800 Subject: [PATCH] fix the tensorflow --- x2paddle/decoder/tf_decoder.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/x2paddle/decoder/tf_decoder.py b/x2paddle/decoder/tf_decoder.py index 3509d5e..eb687ab 100644 --- a/x2paddle/decoder/tf_decoder.py +++ b/x2paddle/decoder/tf_decoder.py @@ -130,7 +130,7 @@ class TFGraph(Graph): def __init__(self, model, data_format="NHWC"): super(TFGraph, self).__init__(model) self.identity_map = dict() - self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2'] + self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2', 'Unpack'] self.tf_data_format = data_format self.graph_name = "TFModel" @@ -172,7 +172,8 @@ class TFGraph(Graph): self._remove_isolated_node() self._optimize_dialiation_conv() self._remove_identity_node() - self._remove_cast_node() +# self._remove_cast_node() + def get_node(self, node_name, copy=False): items = node_name.strip().split(':') @@ -192,6 +193,8 @@ class TFGraph(Graph): def get_input_node(self, node, idx=0, copy=False): input_node_name = node.layer.input[idx] + if idx > 0: + copy = True return self.get_node(input_node_name, copy) def remove_node(self, node_name): @@ -402,7 +405,7 @@ class TFDecoder(object): right_shape_been_input = False while not right_shape_been_input: try: - shape = input( + shape = raw_input( "Shape of Input(e.g. None,224,224,3): ") except: shape = input("Shape of Input(e.g. None,224,224,3): ") -- GitLab