diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index ac3333c7d597f4e62e98a934eb16fdf6b61c0f07..a4430841d5de678f2bbf37a295622e04c3c4a8c8 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -200,6 +200,8 @@ class TensorflowConverter(base_converter.ConverterInterface): tf_graph_def = tf.GraphDef() with tf.gfile.Open(src_model_file, 'rb') as f: tf_graph_def.ParseFromString(f.read()) + + self._placeholders = {} self.add_shape_info(tf_graph_def) with tf.Session() as session: @@ -240,6 +242,8 @@ class TensorflowConverter(base_converter.ConverterInterface): tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in input_node.shape ]) + self._placeholders[node.name + ':0'] = \ + np.zeros(shape=input_node.shape, dtype=float) @staticmethod def get_scope(tensor_name): @@ -288,13 +292,18 @@ class TensorflowConverter(base_converter.ConverterInterface): # this function tries to infer tensor shape, but some dimension shape # may be undefined due to variance of input length - @staticmethod - def infer_tensor_shape(tensor): - shape = tensor.shape.as_list() - - def normalize_func(dim): - return dim if dim else - 1 - return [normalize_func(dim) for dim in shape] + def infer_tensor_shape(self, tensor): + inferred_tensor_shape = tensor.shape.as_list() + inferred_success = True + for _, dim in enumerate(inferred_tensor_shape): + if dim is None: + inferred_success = False + break + if inferred_success: + return inferred_tensor_shape + + tensor_shape = tf.shape(tensor).eval(feed_dict=self._placeholders) + return tensor_shape def convert_nop(self, tf_op): pass