From 81d83a95fe07bb6668d388c8346389601d31bf53 Mon Sep 17 00:00:00 2001 From: liuqi Date: Mon, 5 Mar 2018 16:12:20 +0800 Subject: [PATCH] Remove unused code. --- python/tools/converter.py | 20 +++-- python/tools/tf_converter.py | 155 ----------------------------------- 2 files changed, 14 insertions(+), 161 deletions(-) delete mode 100644 python/tools/tf_converter.py diff --git a/python/tools/converter.py b/python/tools/converter.py index 6ba52fbf..55b89d01 100644 --- a/python/tools/converter.py +++ b/python/tools/converter.py @@ -8,19 +8,22 @@ from lib.python.tools import source_converter_lib FLAGS = None -def md5(fname): - hash_md5 = hashlib.md5() +def file_checksum(fname): + hash_func = hashlib.sha256() with open(fname, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): - hash_md5.update(chunk) - return hash_md5.hexdigest() + hash_func.update(chunk) + return hash_func.hexdigest() def main(unused_args): if not os.path.isfile(FLAGS.model_file): print("Input graph file '" + FLAGS.model_file + "' does not exist!") return -1 - mode_pb_checksum = md5(FLAGS.model_file) + model_checksum = file_checksum(FLAGS.input) + if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum: + print("Model checksum mismatch: %s != %s" % (model_checksum, FLAGS.model_checksum)) + return -1 if FLAGS.runtime == 'dsp': from lib.python.tools import tf_dsp_converter_lib @@ -42,7 +45,7 @@ def main(unused_args): FLAGS.data_type, FLAGS.runtime, FLAGS.winograd) if FLAGS.output_type == 'source': - source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate, + source_converter_lib.convert_to_source(output_graph_def, model_checksum, FLAGS.template, FLAGS.obfuscate, FLAGS.model_tag, FLAGS.output, FLAGS.runtime, FLAGS.embed_model_data) else: with open(FLAGS.output, "wb") as f: @@ -74,6 +77,11 @@ def parse_args(): type=str, default="", help="Caffe data file to load.") + parser.add_argument( + "--model_checksum", + type=str, + default="", + help="Model file sha256 checksum") parser.add_argument( "--output", type=str, diff --git a/python/tools/tf_converter.py b/python/tools/tf_converter.py deleted file mode 100644 index 131d5f4c..00000000 --- a/python/tools/tf_converter.py +++ /dev/null @@ -1,155 +0,0 @@ -import argparse -import sys -import hashlib -import tensorflow as tf -from tensorflow import gfile -from lib.proto import mace_pb2 -from lib.python.tools import tf_converter_lib -from lib.python.tools import tf_dsp_converter_lib -from lib.python.tools import source_converter_lib - -# ./bazel-bin/mace/python/tools/tf_converter --input quantized_test.pb --output quantized_test_dsp.pb --runtime dsp --input_dim input_node,1,28,28,3 - -FLAGS = None - -def file_checksum(fname): - hash_func = hashlib.sha256() - with open(fname, "rb") as f: - for chunk in iter(lambda: f.read(4096), b""): - hash_func.update(chunk) - return hash_func.hexdigest() - -def main(unused_args): - if not gfile.Exists(FLAGS.input): - print("Input graph file '" + FLAGS.input + "' does not exist!") - return -1 - - model_checksum = file_checksum(FLAGS.input) - if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum: - print("Model checksum mismatch: %s != %s" % (model_checksum, FLAGS.model_checksum)) - return -1 - - input_graph_def = tf.GraphDef() - with gfile.Open(FLAGS.input, "rb") as f: - data = f.read() - input_graph_def.ParseFromString(data) - - if FLAGS.runtime == 'dsp': - output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb( - input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode) - else: - input_shape = [] - if FLAGS.input_shape != "": - input_shape.extend([int(x) for x in FLAGS.input_shape.split(',')]) - output_graph_def = tf_converter_lib.convert_to_mace_pb( - input_graph_def, FLAGS.input_node, input_shape, FLAGS.output_node, - FLAGS.data_type, FLAGS.runtime, FLAGS.winograd) - - if FLAGS.output_type == 'source': - source_converter_lib.convert_to_source(output_graph_def, model_checksum, FLAGS.template, FLAGS.obfuscate, - FLAGS.model_tag, FLAGS.output, FLAGS.runtime, FLAGS.embed_model_data) - else: - with gfile.GFile(FLAGS.output, "wb") as f: - f.write(output_graph_def.SerializeToString()) - with gfile.GFile(FLAGS.output + '_txt', "wb") as f: - # output_graph_def.ClearField('tensors') - f.write(str(output_graph_def)) - print("Model conversion is completed.") - -def str2bool(v): - if v.lower() in ('yes', 'true', 't', 'y', '1'): - return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): - return False - else: - raise argparse.ArgumentTypeError('Boolean value expected.') - -def parse_args(): - """Parses command line arguments.""" - parser = argparse.ArgumentParser() - parser.register("type", "bool", lambda v: v.lower() == "true") - parser.add_argument( - "--input", - type=str, - default="", - help="TensorFlow \'GraphDef\' file to load.") - parser.add_argument( - "--model_checksum", - type=str, - default="", - help="Model file sha256 checksum") - parser.add_argument( - "--output", - type=str, - default="", - help="File to save the output graph to.") - parser.add_argument( - "--runtime", - type=str, - default="cpu", - help="Runtime: cpu/gpu/dsp") - parser.add_argument( - "--input_node", - type=str, - default="input_node", - help="e.g., input_node") - parser.add_argument( - "--output_node", - type=str, - default="softmax", - help="e.g., softmax") - parser.add_argument( - "--data_type", - type=str, - default='DT_FLOAT', - help="e.g., DT_HALF/DT_FLOAT") - parser.add_argument( - "--output_type", - type=str, - default="pb", - help="output type: source/pb") - parser.add_argument( - "--template", - type=str, - default="", - help="template path") - parser.add_argument( - "--obfuscate", - type=str2bool, - nargs='?', - const=False, - default=False, - help="obfuscate model names") - parser.add_argument( - "--model_tag", - type=str, - default="", - help="model tag for generated function and namespace") - parser.add_argument( - "--winograd", - type=str2bool, - nargs='?', - const=False, - default=False, - help="open winograd convolution or not") - parser.add_argument( - "--dsp_mode", - type=int, - default=0, - help="dsp run mode, defalut=0") - parser.add_argument( - "--input_shape", - type=str, - default="", - help="input shape.") - parser.add_argument( - "--embed_model_data", - type=str2bool, - default=True, - help="input shape.") - return parser.parse_known_args() - - -if __name__ == '__main__': - FLAGS, unparsed = parse_args() - main(unused_args=[sys.argv[0]] + unparsed) -- GitLab