提交 5f10b2c3 编写于 作者: 刘琦

Merge branch 'transform' into 'master'

Add identity op

See merge request !484
......@@ -101,6 +101,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
'AvgPool': self.convert_pooling,
'MaxPool': self.convert_pooling,
'Squeeze': self.convert_identity,
'Identity': self.convert_identity,
'Reshape': self.convert_reshape,
'Shape': self.convert_nop,
'Softmax': self.convert_softmax,
......@@ -153,12 +154,14 @@ class TensorflowConverter(base_converter.ConverterInterface):
def add_shape_info(self, tf_graph_def):
for node in tf_graph_def.node:
if node.name in self._option.input_nodes:
del node.attr['shape'].shape.dim[:]
node.attr['shape'].shape.dim.extend([
tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in
self._option.input_nodes[node.name].shape
])
for input_node in self._option.input_nodes.values():
if node.name == input_node.name \
or node.name + ':0' == input_node.name:
del node.attr['shape'].shape.dim[:]
node.attr['shape'].shape.dim.extend([
tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in
input_node.shape
])
@staticmethod
def get_scope(tensor_name):
......
......@@ -45,8 +45,11 @@ def to_int_list(long_list):
def add_shape_info(input_graph_def, input_nodes, input_shapes):
inputs_replaced_graph = graph_pb2.GraphDef()
for node in input_graph_def.node:
if node.name in input_nodes:
idx = input_nodes.index(node.name)
if node.name in input_nodes or node.name + ':0' in input_nodes:
if node.name in input_nodes:
idx = input_nodes.index(node.name)
else:
idx = input_nodes.index(node.name + ':0')
input_shape = input_shapes[idx]
print input_shape
placeholder_node = copy.deepcopy(node)
......
......@@ -65,6 +65,13 @@ def compare_output(platform, device_type, output_name, mace_out_value,
sys.exit(-1)
def normalize_tf_tensor_name(name):
if name.find(':') == -1:
return name + ':0'
else:
return name
def validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes, output_names):
import tensorflow as tf
......@@ -88,13 +95,14 @@ def validate_tf_model(platform, device_type, model_file, input_file,
common.formatted_file_name(input_file, input_names[i]))
input_value = input_value.reshape(input_shapes[i])
input_node = graph.get_tensor_by_name(
input_names[i] + ':0')
normalize_tf_tensor_name(input_names[i]))
input_dict[input_node] = input_value
output_nodes = []
for name in output_names:
output_nodes.extend(
[graph.get_tensor_by_name(name + ':0')])
[graph.get_tensor_by_name(
normalize_tf_tensor_name(name))])
output_values = session.run(output_nodes, feed_dict=input_dict)
for i in range(len(output_names)):
output_file_name = common.formatted_file_name(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册