提交 9142790a 编写于 作者: 刘琦

Merge branch 'master' into 'master'

Calculate shape if it cannot be inferred

See merge request !666
...@@ -200,6 +200,8 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -200,6 +200,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
tf_graph_def = tf.GraphDef() tf_graph_def = tf.GraphDef()
with tf.gfile.Open(src_model_file, 'rb') as f: with tf.gfile.Open(src_model_file, 'rb') as f:
tf_graph_def.ParseFromString(f.read()) tf_graph_def.ParseFromString(f.read())
self._placeholders = {}
self.add_shape_info(tf_graph_def) self.add_shape_info(tf_graph_def)
with tf.Session() as session: with tf.Session() as session:
...@@ -240,6 +242,8 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -240,6 +242,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in
input_node.shape input_node.shape
]) ])
self._placeholders[node.name + ':0'] = \
np.zeros(shape=input_node.shape, dtype=float)
@staticmethod @staticmethod
def get_scope(tensor_name): def get_scope(tensor_name):
...@@ -288,13 +292,18 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -288,13 +292,18 @@ class TensorflowConverter(base_converter.ConverterInterface):
# this function tries to infer tensor shape, but some dimension shape # this function tries to infer tensor shape, but some dimension shape
# may be undefined due to variance of input length # may be undefined due to variance of input length
@staticmethod def infer_tensor_shape(self, tensor):
def infer_tensor_shape(tensor): inferred_tensor_shape = tensor.shape.as_list()
shape = tensor.shape.as_list() inferred_success = True
for _, dim in enumerate(inferred_tensor_shape):
def normalize_func(dim): if dim is None:
return dim if dim else - 1 inferred_success = False
return [normalize_func(dim) for dim in shape] break
if inferred_success:
return inferred_tensor_shape
tensor_shape = tf.shape(tensor).eval(feed_dict=self._placeholders)
return tensor_shape
def convert_nop(self, tf_op): def convert_nop(self, tf_op):
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册