diff --git a/python/tools/tf_converter.py b/python/tools/tf_converter.py index 5ea9db4b39313877d6f0e4de8d6627c78ad02185..0da5420fd954646679511335647d63a63867c1c1 100644 --- a/python/tools/tf_converter.py +++ b/python/tools/tf_converter.py @@ -34,7 +34,9 @@ def main(unused_args): output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb( input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode) else: - input_shape = [int(x) for x in FLAGS.input_shape.split(',')] + input_shape = [] + if FLAGS.input_shape != "": + input_shape.extend([int(x) for x in FLAGS.input_shape.split(',')]) output_graph_def = tf_converter_lib.convert_to_mace_pb( input_graph_def, FLAGS.input_node, input_shape, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime, FLAGS.winograd) @@ -129,7 +131,7 @@ def parse_args(): parser.add_argument( "--input_shape", type=str, - default="1,512,512,3", + default="", help="input shape.") return parser.parse_known_args() diff --git a/python/tools/tf_converter_lib.py b/python/tools/tf_converter_lib.py index d221bc54dc2661ea19f307d163ecafc4093fcb94..bf714231efd4a0a491af41e34a615254489b0e3e 100644 --- a/python/tools/tf_converter_lib.py +++ b/python/tools/tf_converter_lib.py @@ -4,10 +4,8 @@ import numpy as np import math import copy from lib.python.tools import memory_optimizer -from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import tensor_shape_pb2 -from tensorflow.core.framework import node_def_pb2 # TODO: support NCHW formt, now only support NHWC. padding_mode = {