diff --git a/python/tools/tf_converter.py b/python/tools/tf_converter.py index 2849725d2267d7217e13e0bee295c940dbfecbff..5ea9db4b39313877d6f0e4de8d6627c78ad02185 100644 --- a/python/tools/tf_converter.py +++ b/python/tools/tf_converter.py @@ -34,8 +34,10 @@ def main(unused_args): output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb( input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode) else: + input_shape = [int(x) for x in FLAGS.input_shape.split(',')] output_graph_def = tf_converter_lib.convert_to_mace_pb( - input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime, FLAGS.winograd) + input_graph_def, FLAGS.input_node, input_shape, FLAGS.output_node, + FLAGS.data_type, FLAGS.runtime, FLAGS.winograd) if FLAGS.output_type == 'source': source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate, @@ -124,6 +126,11 @@ def parse_args(): type=int, default=0, help="dsp run mode, defalut=0") + parser.add_argument( + "--input_shape", + type=str, + default="1,512,512,3", + help="input shape.") return parser.parse_known_args() diff --git a/python/tools/tf_converter_lib.py b/python/tools/tf_converter_lib.py index 7062221621d0c2fabd975c452e624935a032dbb5..d221bc54dc2661ea19f307d163ecafc4093fcb94 100644 --- a/python/tools/tf_converter_lib.py +++ b/python/tools/tf_converter_lib.py @@ -2,7 +2,12 @@ from lib.proto import mace_pb2 import tensorflow as tf import numpy as np import math +import copy from lib.python.tools import memory_optimizer +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.core.framework import node_def_pb2 # TODO: support NCHW formt, now only support NHWC. padding_mode = { @@ -903,10 +908,27 @@ class Optimizer: new_net = self.fold_batch_norm() return new_net -def convert_to_mace_pb(input_graph_def, input_node, output_node, data_type, device, winograd): +def add_shape_info(input_graph_def, input_node, input_shape): + inputs_replaced_graph = graph_pb2.GraphDef() + for node in input_graph_def.node: + if node.name == input_node: + placeholder_node = copy.deepcopy(node) + placeholder_node.attr.clear() + placeholder_node.attr['shape'].shape.dim.extend([ + tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in input_shape + ]) + placeholder_node.attr['dtype'].CopyFrom(node.attr['dtype']) + inputs_replaced_graph.node.extend([placeholder_node]) + else: + inputs_replaced_graph.node.extend([copy.deepcopy(node)]) + return inputs_replaced_graph + + +def convert_to_mace_pb(input_graph_def, input_node, input_shape, output_node, data_type, device, winograd): net_def = mace_pb2.NetDef() dt = data_type_map[data_type] + input_graph_def = add_shape_info(input_graph_def, input_node, input_shape) with tf.Session() as session: with session.graph.as_default() as graph: tf.import_graph_def(input_graph_def, name="")