diff --git a/python/tools/converter.py b/python/tools/converter.py index a2be72359db23d4f760e06f7bfe405b28898ee61..3305d1e7e26f19058bce48ac2d6304cb152ca946 100644 --- a/python/tools/converter.py +++ b/python/tools/converter.py @@ -49,12 +49,9 @@ def main(unused_args): output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb( FLAGS.model_file, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode) else: - input_shape = [] - if FLAGS.input_shape != "": - input_shape.extend([int(x) for x in FLAGS.input_shape.split(',')]) from lib.python.tools import tf_converter_lib output_graph_def = tf_converter_lib.convert_to_mace_pb( - FLAGS.model_file, FLAGS.input_node, input_shape, FLAGS.output_node, + FLAGS.model_file, FLAGS.input_node, FLAGS.input_shape, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime, FLAGS.winograd) if FLAGS.output_type == 'source': diff --git a/python/tools/tf_converter_lib.py b/python/tools/tf_converter_lib.py index 4e3f1021c2d551a3a741ad0792f58ee6668df55e..5d1ea941160d0c6bfee2113d37e3e837d8b38fb4 100644 --- a/python/tools/tf_converter_lib.py +++ b/python/tools/tf_converter_lib.py @@ -118,34 +118,41 @@ class TFConverter(object): arg.i = self.dt return output_name + def add_input_transform(self, names, is_single): + for name in names: + if is_single: + new_input_name = MACE_INPUT_NODE_NAME + ":0" + else: + new_input_name = MACE_INPUT_NODE_NAME + '_' + name + ":0" + op_def = self.net_def.op.add() + op_def.name = name + op_def.type = 'BufferToImage' + op_def.input.extend([new_input_name]) + op_def.output.extend([name+':0']) + + epsilon_arg = op_def.arg.add() + epsilon_arg.name = 'buffer_type' + epsilon_arg.i = buffer_type_map['IN_OUT_CHANNEL'] + + arg = op_def.arg.add() + arg.name = 'T' + arg.i = self.dt + + def add_output_transform(self, names, is_single): + for name in names: + if is_single: + output_name = MACE_OUTPUT_NODE_NAME + ":0" + else: + output_name = MACE_OUTPUT_NODE_NAME + '_' + name + ":0" + op_def = self.net_def.op.add() + op_def.name = output_name[:-2] + op_def.type = 'ImageToBuffer' + op_def.input.extend([name+':0']) + op_def.output.extend([output_name]) - def add_input_transform(self, name): - new_input_name = MACE_INPUT_NODE_NAME + ":0" - op_def = self.net_def.op.add() - op_def.name = name - op_def.type = 'BufferToImage' - op_def.input.extend([new_input_name]) - op_def.output.extend([name+':0']) - - epsilon_arg = op_def.arg.add() - epsilon_arg.name = 'buffer_type' - epsilon_arg.i = buffer_type_map['IN_OUT_CHANNEL'] - - arg = op_def.arg.add() - arg.name = 'T' - arg.i = self.dt - - def add_output_transform(self, name): - output_name = MACE_OUTPUT_NODE_NAME + ":0" - op_def = self.net_def.op.add() - op_def.name = output_name[:-2] - op_def.type = 'ImageToBuffer' - op_def.input.extend([name+':0']) - op_def.output.extend([output_name]) - - epsilon_arg = op_def.arg.add() - epsilon_arg.name = 'buffer_type' - epsilon_arg.i = buffer_type_map['IN_OUT_CHANNEL'] + epsilon_arg = op_def.arg.add() + epsilon_arg.name = 'buffer_type' + epsilon_arg.i = buffer_type_map['IN_OUT_CHANNEL'] @staticmethod def add_output_shape(outputs, op): @@ -794,19 +801,26 @@ class TFConverter(object): self.add_output_shape(op.outputs, op_def) self.resolved_ops[op.name] = 1 - def replace_in_out_name(self, input_name, output_name): - input_name = input_name + ":0" - output_name = output_name + ":0" - for op in self.net_def.op: - if len(op.input) > 0 and op.input[0] == input_name: - op.input[0] = MACE_INPUT_NODE_NAME + ":0" - if len(op.output) > 0 and op.output[0] == output_name: - op.output[0] = MACE_OUTPUT_NODE_NAME + ":0" - - - def convert(self, input_node, output_node): + def replace_in_out_name(self, input_names, output_names, is_single): + in_names = set([input_name + ":0" for input_name in input_names]) + out_names = set([output_name + ":0" for output_name in output_names]) + if is_single: + for op in self.net_def.op: + if len(op.input) > 0 and op.input[0] in in_names: + op.input[0] = MACE_INPUT_NODE_NAME + ':0' + if len(op.output) > 0 and op.output[0] in out_names: + op.output[0] = MACE_OUTPUT_NODE_NAME + ':0' + else: + for op in self.net_def.op: + if len(op.input) > 0 and op.input[0] in in_names: + op.input[0] = MACE_INPUT_NODE_NAME + '_' + op.input[0] + if len(op.output) > 0 and op.output[0] in out_names: + op.output[0] = MACE_OUTPUT_NODE_NAME + '_' + op.output[0] + + def convert(self, input_nodes, output_nodes): + is_single = len(input_nodes) == 1 and len(output_nodes) == 1 if self.device == 'gpu': - self.add_input_transform(input_node) + self.add_input_transform(input_nodes, is_single) for op in self.tf_ops: if self.resolved_ops[op.name] == 1: @@ -874,10 +888,10 @@ class TFConverter(object): raise Exception('Unknown Op: %s, type: %s' % (op.name, op.type)) if self.device == 'gpu': - self.add_output_transform(output_node) + self.add_output_transform(output_nodes, is_single) if self.device == 'cpu': - self.replace_in_out_name(input_node, output_node) + self.replace_in_out_name(input_nodes, output_nodes, is_single) for key in self.resolved_ops: if self.resolved_ops[key] != 1: @@ -978,10 +992,12 @@ class Optimizer: new_net = self.fold_batch_norm() return new_net -def add_shape_info(input_graph_def, input_node, input_shape): +def add_shape_info(input_graph_def, input_nodes, input_shapes): inputs_replaced_graph = graph_pb2.GraphDef() for node in input_graph_def.node: - if node.name == input_node: + if node.name in input_nodes: + idx = input_nodes.index(node.name) + input_shape = input_shapes[idx] placeholder_node = copy.deepcopy(node) placeholder_node.attr.clear() placeholder_node.attr['shape'].shape.dim.extend([ @@ -1003,13 +1019,22 @@ def convert_to_mace_pb(model_file, input_node, input_shape, output_node, data_ty data = f.read() input_graph_def.ParseFromString(data) - input_graph_def = add_shape_info(input_graph_def, input_node, input_shape) + input_nodes = [x for x in input_node.split(',')] + input_shapes = [] + if input_shape != "": + input_shape_strs = [x for x in input_shape.split(':')] + for shape_str in input_shape_strs: + input_shapes.extend([[int(x) for x in shape_str.split(',')]]) + output_nodes = [x for x in output_node.split(',')] + assert len(input_nodes) == len(input_shapes) + + input_graph_def = add_shape_info(input_graph_def, input_nodes, input_shapes) with tf.Session() as session: with session.graph.as_default() as graph: tf.import_graph_def(input_graph_def, name="") ops = graph.get_operations() converter = TFConverter(ops, net_def, dt, device, winograd) - converter.convert(input_node, output_node) + converter.convert(input_nodes, output_nodes) optimizer = Optimizer(net_def, device) net_def = optimizer.optimize() print "Model Converted."