提交 9142790a 编写于 作者: 刘琦

Merge branch 'master' into 'master'

Calculate shape if it cannot be inferred

See merge request !666
......@@ -200,6 +200,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
tf_graph_def = tf.GraphDef()
with tf.gfile.Open(src_model_file, 'rb') as f:
tf_graph_def.ParseFromString(f.read())
self._placeholders = {}
self.add_shape_info(tf_graph_def)
with tf.Session() as session:
......@@ -240,6 +242,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in
input_node.shape
])
self._placeholders[node.name + ':0'] = \
np.zeros(shape=input_node.shape, dtype=float)
@staticmethod
def get_scope(tensor_name):
......@@ -288,13 +292,18 @@ class TensorflowConverter(base_converter.ConverterInterface):
# this function tries to infer tensor shape, but some dimension shape
# may be undefined due to variance of input length
@staticmethod
def infer_tensor_shape(tensor):
shape = tensor.shape.as_list()
def normalize_func(dim):
return dim if dim else - 1
return [normalize_func(dim) for dim in shape]
def infer_tensor_shape(self, tensor):
inferred_tensor_shape = tensor.shape.as_list()
inferred_success = True
for _, dim in enumerate(inferred_tensor_shape):
if dim is None:
inferred_success = False
break
if inferred_success:
return inferred_tensor_shape
tensor_shape = tf.shape(tensor).eval(feed_dict=self._placeholders)
return tensor_shape
def convert_nop(self, tf_op):
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册