diff --git a/x2paddle/decoder/tf_decoder.py b/x2paddle/decoder/tf_decoder.py index 3509d5ee9d5399fa485916e38ad3eecc6edc227f..eb687ab3b9c1367d618d61bc32dd8f998a37d0de 100644 --- a/x2paddle/decoder/tf_decoder.py +++ b/x2paddle/decoder/tf_decoder.py @@ -130,7 +130,7 @@ class TFGraph(Graph): def __init__(self, model, data_format="NHWC"): super(TFGraph, self).__init__(model) self.identity_map = dict() - self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2'] + self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2', 'Unpack'] self.tf_data_format = data_format self.graph_name = "TFModel" @@ -172,7 +172,8 @@ class TFGraph(Graph): self._remove_isolated_node() self._optimize_dialiation_conv() self._remove_identity_node() - self._remove_cast_node() +# self._remove_cast_node() + def get_node(self, node_name, copy=False): items = node_name.strip().split(':') @@ -192,6 +193,8 @@ class TFGraph(Graph): def get_input_node(self, node, idx=0, copy=False): input_node_name = node.layer.input[idx] + if idx > 0: + copy = True return self.get_node(input_node_name, copy) def remove_node(self, node_name): @@ -402,7 +405,7 @@ class TFDecoder(object): right_shape_been_input = False while not right_shape_been_input: try: - shape = input( + shape = raw_input( "Shape of Input(e.g. None,224,224,3): ") except: shape = input("Shape of Input(e.g. None,224,224,3): ")