diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 81bf2027eb4e382df13146d3c8a102c67705cf8e..1d9619386338426a2ebc00f266ee8d7ef45a2c47 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -101,6 +101,7 @@ class TensorflowConverter(base_converter.ConverterInterface): 'AvgPool': self.convert_pooling, 'MaxPool': self.convert_pooling, 'Squeeze': self.convert_identity, + 'Identity': self.convert_identity, 'Reshape': self.convert_reshape, 'Shape': self.convert_nop, 'Softmax': self.convert_softmax, @@ -153,12 +154,14 @@ class TensorflowConverter(base_converter.ConverterInterface): def add_shape_info(self, tf_graph_def): for node in tf_graph_def.node: - if node.name in self._option.input_nodes: - del node.attr['shape'].shape.dim[:] - node.attr['shape'].shape.dim.extend([ - tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in - self._option.input_nodes[node.name].shape - ]) + for input_node in self._option.input_nodes.values(): + if node.name == input_node.name \ + or node.name + ':0' == input_node.name: + del node.attr['shape'].shape.dim[:] + node.attr['shape'].shape.dim.extend([ + tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in + input_node.shape + ]) @staticmethod def get_scope(tensor_name): diff --git a/mace/python/tools/tf_ops_stats.py b/mace/python/tools/tf_ops_stats.py index 21d2db80a9b7fd60a72dee059e3a4f02bfeec198..2d9152392cfd3a0802fc32d7433098979bc56115 100644 --- a/mace/python/tools/tf_ops_stats.py +++ b/mace/python/tools/tf_ops_stats.py @@ -45,8 +45,11 @@ def to_int_list(long_list): def add_shape_info(input_graph_def, input_nodes, input_shapes): inputs_replaced_graph = graph_pb2.GraphDef() for node in input_graph_def.node: - if node.name in input_nodes: - idx = input_nodes.index(node.name) + if node.name in input_nodes or node.name + ':0' in input_nodes: + if node.name in input_nodes: + idx = input_nodes.index(node.name) + else: + idx = input_nodes.index(node.name + ':0') input_shape = input_shapes[idx] print input_shape placeholder_node = copy.deepcopy(node) diff --git a/tools/validate.py b/tools/validate.py index dba6a3e2db2b2b38b53d4f1c925645f246ebd3a2..b98fd0894adff8f44bb9109231274c147d14c372 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -65,6 +65,13 @@ def compare_output(platform, device_type, output_name, mace_out_value, sys.exit(-1) +def normalize_tf_tensor_name(name): + if name.find(':') == -1: + return name + ':0' + else: + return name + + def validate_tf_model(platform, device_type, model_file, input_file, mace_out_file, input_names, input_shapes, output_names): import tensorflow as tf @@ -88,13 +95,14 @@ def validate_tf_model(platform, device_type, model_file, input_file, common.formatted_file_name(input_file, input_names[i])) input_value = input_value.reshape(input_shapes[i]) input_node = graph.get_tensor_by_name( - input_names[i] + ':0') + normalize_tf_tensor_name(input_names[i])) input_dict[input_node] = input_value output_nodes = [] for name in output_names: output_nodes.extend( - [graph.get_tensor_by_name(name + ':0')]) + [graph.get_tensor_by_name( + normalize_tf_tensor_name(name))]) output_values = session.run(output_nodes, feed_dict=input_dict) for i in range(len(output_names)): output_file_name = common.formatted_file_name(