From 46f46aaabaecb1e03606058ea974de6a53cf90fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= Date: Wed, 14 Mar 2018 16:02:48 +0800 Subject: [PATCH] Converter compatible for TF 1.6.0 --- mace/python/tools/tf_converter_lib.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 5b488b1e..dc989796 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -158,10 +158,15 @@ class TFConverter(object): def add_output_shape(outputs, op): output_shapes = [] for output in outputs: - if output.shape.num_elements() is not None: - output_shape = mace_pb2.OutputShape() - output_shape.dims.extend(output.shape.as_list()) - output_shapes.append(output_shape) + output_shape = mace_pb2.OutputShape() + if isinstance(output, list): + output_shape.dims.extend(output) + elif isinstance(output, tf.Tensor): + if output.shape.num_elements() is not None: + output_shape.dims.extend(output.shape.as_list()) + else: + raise ValueError('output type not supported: ', type(output)) + output_shapes.append(output_shape) op.output_shape.extend(output_shapes) def add_tensor(self, name, shape, tf_dt, value): @@ -782,11 +787,11 @@ class TFConverter(object): self.unused_tensor.add(get_input_tensor(reshape_op, 1).name) if reshape_op.outputs[0].shape.ndims == 2: - shape = reshape_op.outputs[0].shape - from tensorflow.python.framework.tensor_shape import as_shape - reshape_op.outputs[0]._shape = as_shape([1, 1, shape[0], shape[1]]) + shape = [dim.value for dim in reshape_op.outputs[0].shape] + if len(shape) == 2: + shape = [1, 1, shape[0], shape[1]] op_def.output.extend([output.name for output in reshape_op.outputs]) - self.add_output_shape(reshape_op.outputs, op_def) + self.add_output_shape([shape], op_def) self.resolved_ops[reshape_op.name] = 1 def convert_normal_op(self, op): -- GitLab