diff --git a/python/tools/converter.py b/python/tools/converter.py index 55b89d0149eb601ca363b9488ab45f3e6af703b1..09b64dd75cb0a6738c3a7233f2ce891cadf1f620 100644 --- a/python/tools/converter.py +++ b/python/tools/converter.py @@ -20,29 +20,45 @@ def main(unused_args): print("Input graph file '" + FLAGS.model_file + "' does not exist!") return -1 - model_checksum = file_checksum(FLAGS.input) + model_checksum = file_checksum(FLAGS.model_file) if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum: print("Model checksum mismatch: %s != %s" % (model_checksum, FLAGS.model_checksum)) return -1 - if FLAGS.runtime == 'dsp': - from lib.python.tools import tf_dsp_converter_lib - output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb( - FLAGS.model_file, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode) - else: + if FLAGS.platform == 'caffe': + if not os.path.isfile(FLAGS.weight_file): + print("Input weight file '" + FLAGS.weight_file + "' does not exist!") + return -1 + + weight_checksum = file_checksum(FLAGS.weight_file) + if FLAGS.weight_checksum != "" and FLAGS.weight_checksum != weight_checksum: + print("Weight checksum mismatch: %s != %s" % (weight_checksum, FLAGS.weight_checksum)) + return -1 + + if FLAGS.runtime == 'dsp': + print("DSP not support caffe model yet.") + return -1 + input_shape = [] if FLAGS.input_shape != "": input_shape.extend([int(x) for x in FLAGS.input_shape.split(',')]) - if FLAGS.platform == 'tensorflow': + from lib.python.tools import caffe_converter_lib + output_graph_def = caffe_converter_lib.convert_to_mace_pb( + FLAGS.model_file, FLAGS.weight_file, FLAGS.input_node, input_shape, FLAGS.output_node, + FLAGS.data_type, FLAGS.runtime, FLAGS.winograd) + elif FLAGS.platform == 'tensorflow': + if FLAGS.runtime == 'dsp': + from lib.python.tools import tf_dsp_converter_lib + output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb( + FLAGS.model_file, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode) + else: + input_shape = [] + if FLAGS.input_shape != "": + input_shape.extend([int(x) for x in FLAGS.input_shape.split(',')]) from lib.python.tools import tf_converter_lib output_graph_def = tf_converter_lib.convert_to_mace_pb( FLAGS.model_file, FLAGS.input_node, input_shape, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime, FLAGS.winograd) - elif FLAGS.platform == 'caffe': - from lib.python.tools import caffe_converter_lib - output_graph_def = caffe_converter_lib.convert_to_mace_pb( - FLAGS.model_file, FLAGS.weight_file, FLAGS.input_node, input_shape, FLAGS.output_node, - FLAGS.data_type, FLAGS.runtime, FLAGS.winograd) if FLAGS.output_type == 'source': source_converter_lib.convert_to_source(output_graph_def, model_checksum, FLAGS.template, FLAGS.obfuscate, @@ -82,6 +98,11 @@ def parse_args(): type=str, default="", help="Model file sha256 checksum") + parser.add_argument( + "--weight_checksum", + type=str, + default="", + help="Weight file sha256 checksum") parser.add_argument( "--output", type=str,