From e29dcb56e7fc7bab7312c370c357d4ba04d3c202 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Tue, 17 Jul 2018 15:01:09 +0800 Subject: [PATCH] Calculate shape if it cannot be inferred --- .../converter_tool/tensorflow_converter.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index ac3333c7..a4430841 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -200,6 +200,8 @@ class TensorflowConverter(base_converter.ConverterInterface): tf_graph_def = tf.GraphDef() with tf.gfile.Open(src_model_file, 'rb') as f: tf_graph_def.ParseFromString(f.read()) + + self._placeholders = {} self.add_shape_info(tf_graph_def) with tf.Session() as session: @@ -240,6 +242,8 @@ class TensorflowConverter(base_converter.ConverterInterface): tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in input_node.shape ]) + self._placeholders[node.name + ':0'] = \ + np.zeros(shape=input_node.shape, dtype=float) @staticmethod def get_scope(tensor_name): @@ -288,13 +292,18 @@ class TensorflowConverter(base_converter.ConverterInterface): # this function tries to infer tensor shape, but some dimension shape # may be undefined due to variance of input length - @staticmethod - def infer_tensor_shape(tensor): - shape = tensor.shape.as_list() - - def normalize_func(dim): - return dim if dim else - 1 - return [normalize_func(dim) for dim in shape] + def infer_tensor_shape(self, tensor): + inferred_tensor_shape = tensor.shape.as_list() + inferred_success = True + for _, dim in enumerate(inferred_tensor_shape): + if dim is None: + inferred_success = False + break + if inferred_success: + return inferred_tensor_shape + + tensor_shape = tf.shape(tensor).eval(feed_dict=self._placeholders) + return tensor_shape def convert_nop(self, tf_op): pass -- GitLab