提交 4fc44d15 编写于 作者: L liuqi

TF converter support multiple inputs or outputs.

上级 16023300
......@@ -49,12 +49,9 @@ def main(unused_args):
output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
FLAGS.model_file, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode)
else:
input_shape = []
if FLAGS.input_shape != "":
input_shape.extend([int(x) for x in FLAGS.input_shape.split(',')])
from lib.python.tools import tf_converter_lib
output_graph_def = tf_converter_lib.convert_to_mace_pb(
FLAGS.model_file, FLAGS.input_node, input_shape, FLAGS.output_node,
FLAGS.model_file, FLAGS.input_node, FLAGS.input_shape, FLAGS.output_node,
FLAGS.data_type, FLAGS.runtime, FLAGS.winograd)
if FLAGS.output_type == 'source':
......
......@@ -118,34 +118,41 @@ class TFConverter(object):
arg.i = self.dt
return output_name
def add_input_transform(self, names, is_single):
for name in names:
if is_single:
new_input_name = MACE_INPUT_NODE_NAME + ":0"
else:
new_input_name = MACE_INPUT_NODE_NAME + '_' + name + ":0"
op_def = self.net_def.op.add()
op_def.name = name
op_def.type = 'BufferToImage'
op_def.input.extend([new_input_name])
op_def.output.extend([name+':0'])
epsilon_arg = op_def.arg.add()
epsilon_arg.name = 'buffer_type'
epsilon_arg.i = buffer_type_map['IN_OUT_CHANNEL']
arg = op_def.arg.add()
arg.name = 'T'
arg.i = self.dt
def add_output_transform(self, names, is_single):
for name in names:
if is_single:
output_name = MACE_OUTPUT_NODE_NAME + ":0"
else:
output_name = MACE_OUTPUT_NODE_NAME + '_' + name + ":0"
op_def = self.net_def.op.add()
op_def.name = output_name[:-2]
op_def.type = 'ImageToBuffer'
op_def.input.extend([name+':0'])
op_def.output.extend([output_name])
def add_input_transform(self, name):
new_input_name = MACE_INPUT_NODE_NAME + ":0"
op_def = self.net_def.op.add()
op_def.name = name
op_def.type = 'BufferToImage'
op_def.input.extend([new_input_name])
op_def.output.extend([name+':0'])
epsilon_arg = op_def.arg.add()
epsilon_arg.name = 'buffer_type'
epsilon_arg.i = buffer_type_map['IN_OUT_CHANNEL']
arg = op_def.arg.add()
arg.name = 'T'
arg.i = self.dt
def add_output_transform(self, name):
output_name = MACE_OUTPUT_NODE_NAME + ":0"
op_def = self.net_def.op.add()
op_def.name = output_name[:-2]
op_def.type = 'ImageToBuffer'
op_def.input.extend([name+':0'])
op_def.output.extend([output_name])
epsilon_arg = op_def.arg.add()
epsilon_arg.name = 'buffer_type'
epsilon_arg.i = buffer_type_map['IN_OUT_CHANNEL']
epsilon_arg = op_def.arg.add()
epsilon_arg.name = 'buffer_type'
epsilon_arg.i = buffer_type_map['IN_OUT_CHANNEL']
@staticmethod
def add_output_shape(outputs, op):
......@@ -794,19 +801,26 @@ class TFConverter(object):
self.add_output_shape(op.outputs, op_def)
self.resolved_ops[op.name] = 1
def replace_in_out_name(self, input_name, output_name):
input_name = input_name + ":0"
output_name = output_name + ":0"
for op in self.net_def.op:
if len(op.input) > 0 and op.input[0] == input_name:
op.input[0] = MACE_INPUT_NODE_NAME + ":0"
if len(op.output) > 0 and op.output[0] == output_name:
op.output[0] = MACE_OUTPUT_NODE_NAME + ":0"
def convert(self, input_node, output_node):
def replace_in_out_name(self, input_names, output_names, is_single):
in_names = set([input_name + ":0" for input_name in input_names])
out_names = set([output_name + ":0" for output_name in output_names])
if is_single:
for op in self.net_def.op:
if len(op.input) > 0 and op.input[0] in in_names:
op.input[0] = MACE_INPUT_NODE_NAME + ':0'
if len(op.output) > 0 and op.output[0] in out_names:
op.output[0] = MACE_OUTPUT_NODE_NAME + ':0'
else:
for op in self.net_def.op:
if len(op.input) > 0 and op.input[0] in in_names:
op.input[0] = MACE_INPUT_NODE_NAME + '_' + op.input[0]
if len(op.output) > 0 and op.output[0] in out_names:
op.output[0] = MACE_OUTPUT_NODE_NAME + '_' + op.output[0]
def convert(self, input_nodes, output_nodes):
is_single = len(input_nodes) == 1 and len(output_nodes) == 1
if self.device == 'gpu':
self.add_input_transform(input_node)
self.add_input_transform(input_nodes, is_single)
for op in self.tf_ops:
if self.resolved_ops[op.name] == 1:
......@@ -874,10 +888,10 @@ class TFConverter(object):
raise Exception('Unknown Op: %s, type: %s' % (op.name, op.type))
if self.device == 'gpu':
self.add_output_transform(output_node)
self.add_output_transform(output_nodes, is_single)
if self.device == 'cpu':
self.replace_in_out_name(input_node, output_node)
self.replace_in_out_name(input_nodes, output_nodes, is_single)
for key in self.resolved_ops:
if self.resolved_ops[key] != 1:
......@@ -978,10 +992,12 @@ class Optimizer:
new_net = self.fold_batch_norm()
return new_net
def add_shape_info(input_graph_def, input_node, input_shape):
def add_shape_info(input_graph_def, input_nodes, input_shapes):
inputs_replaced_graph = graph_pb2.GraphDef()
for node in input_graph_def.node:
if node.name == input_node:
if node.name in input_nodes:
idx = input_nodes.index(node.name)
input_shape = input_shapes[idx]
placeholder_node = copy.deepcopy(node)
placeholder_node.attr.clear()
placeholder_node.attr['shape'].shape.dim.extend([
......@@ -1003,13 +1019,22 @@ def convert_to_mace_pb(model_file, input_node, input_shape, output_node, data_ty
data = f.read()
input_graph_def.ParseFromString(data)
input_graph_def = add_shape_info(input_graph_def, input_node, input_shape)
input_nodes = [x for x in input_node.split(',')]
input_shapes = []
if input_shape != "":
input_shape_strs = [x for x in input_shape.split(':')]
for shape_str in input_shape_strs:
input_shapes.extend([[int(x) for x in shape_str.split(',')]])
output_nodes = [x for x in output_node.split(',')]
assert len(input_nodes) == len(input_shapes)
input_graph_def = add_shape_info(input_graph_def, input_nodes, input_shapes)
with tf.Session() as session:
with session.graph.as_default() as graph:
tf.import_graph_def(input_graph_def, name="")
ops = graph.get_operations()
converter = TFConverter(ops, net_def, dt, device, winograd)
converter.convert(input_node, output_node)
converter.convert(input_nodes, output_nodes)
optimizer = Optimizer(net_def, device)
net_def = optimizer.optimize()
print "Model Converted."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册