提交 81d83a95 编写于 作者: L liuqi

Remove unused code.

上级 7523a01b
...@@ -8,19 +8,22 @@ from lib.python.tools import source_converter_lib ...@@ -8,19 +8,22 @@ from lib.python.tools import source_converter_lib
FLAGS = None FLAGS = None
def md5(fname): def file_checksum(fname):
hash_md5 = hashlib.md5() hash_func = hashlib.sha256()
with open(fname, "rb") as f: with open(fname, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""): for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk) hash_func.update(chunk)
return hash_md5.hexdigest() return hash_func.hexdigest()
def main(unused_args): def main(unused_args):
if not os.path.isfile(FLAGS.model_file): if not os.path.isfile(FLAGS.model_file):
print("Input graph file '" + FLAGS.model_file + "' does not exist!") print("Input graph file '" + FLAGS.model_file + "' does not exist!")
return -1 return -1
mode_pb_checksum = md5(FLAGS.model_file) model_checksum = file_checksum(FLAGS.input)
if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum:
print("Model checksum mismatch: %s != %s" % (model_checksum, FLAGS.model_checksum))
return -1
if FLAGS.runtime == 'dsp': if FLAGS.runtime == 'dsp':
from lib.python.tools import tf_dsp_converter_lib from lib.python.tools import tf_dsp_converter_lib
...@@ -42,7 +45,7 @@ def main(unused_args): ...@@ -42,7 +45,7 @@ def main(unused_args):
FLAGS.data_type, FLAGS.runtime, FLAGS.winograd) FLAGS.data_type, FLAGS.runtime, FLAGS.winograd)
if FLAGS.output_type == 'source': if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate, source_converter_lib.convert_to_source(output_graph_def, model_checksum, FLAGS.template, FLAGS.obfuscate,
FLAGS.model_tag, FLAGS.output, FLAGS.runtime, FLAGS.embed_model_data) FLAGS.model_tag, FLAGS.output, FLAGS.runtime, FLAGS.embed_model_data)
else: else:
with open(FLAGS.output, "wb") as f: with open(FLAGS.output, "wb") as f:
...@@ -74,6 +77,11 @@ def parse_args(): ...@@ -74,6 +77,11 @@ def parse_args():
type=str, type=str,
default="", default="",
help="Caffe data file to load.") help="Caffe data file to load.")
parser.add_argument(
"--model_checksum",
type=str,
default="",
help="Model file sha256 checksum")
parser.add_argument( parser.add_argument(
"--output", "--output",
type=str, type=str,
......
import argparse
import sys
import hashlib
import tensorflow as tf
from tensorflow import gfile
from lib.proto import mace_pb2
from lib.python.tools import tf_converter_lib
from lib.python.tools import tf_dsp_converter_lib
from lib.python.tools import source_converter_lib
# ./bazel-bin/mace/python/tools/tf_converter --input quantized_test.pb --output quantized_test_dsp.pb --runtime dsp --input_dim input_node,1,28,28,3
FLAGS = None
def file_checksum(fname):
hash_func = hashlib.sha256()
with open(fname, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_func.update(chunk)
return hash_func.hexdigest()
def main(unused_args):
if not gfile.Exists(FLAGS.input):
print("Input graph file '" + FLAGS.input + "' does not exist!")
return -1
model_checksum = file_checksum(FLAGS.input)
if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum:
print("Model checksum mismatch: %s != %s" % (model_checksum, FLAGS.model_checksum))
return -1
input_graph_def = tf.GraphDef()
with gfile.Open(FLAGS.input, "rb") as f:
data = f.read()
input_graph_def.ParseFromString(data)
if FLAGS.runtime == 'dsp':
output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode)
else:
input_shape = []
if FLAGS.input_shape != "":
input_shape.extend([int(x) for x in FLAGS.input_shape.split(',')])
output_graph_def = tf_converter_lib.convert_to_mace_pb(
input_graph_def, FLAGS.input_node, input_shape, FLAGS.output_node,
FLAGS.data_type, FLAGS.runtime, FLAGS.winograd)
if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source(output_graph_def, model_checksum, FLAGS.template, FLAGS.obfuscate,
FLAGS.model_tag, FLAGS.output, FLAGS.runtime, FLAGS.embed_model_data)
else:
with gfile.GFile(FLAGS.output, "wb") as f:
f.write(output_graph_def.SerializeToString())
with gfile.GFile(FLAGS.output + '_txt', "wb") as f:
# output_graph_def.ClearField('tensors')
f.write(str(output_graph_def))
print("Model conversion is completed.")
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def parse_args():
"""Parses command line arguments."""
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--input",
type=str,
default="",
help="TensorFlow \'GraphDef\' file to load.")
parser.add_argument(
"--model_checksum",
type=str,
default="",
help="Model file sha256 checksum")
parser.add_argument(
"--output",
type=str,
default="",
help="File to save the output graph to.")
parser.add_argument(
"--runtime",
type=str,
default="cpu",
help="Runtime: cpu/gpu/dsp")
parser.add_argument(
"--input_node",
type=str,
default="input_node",
help="e.g., input_node")
parser.add_argument(
"--output_node",
type=str,
default="softmax",
help="e.g., softmax")
parser.add_argument(
"--data_type",
type=str,
default='DT_FLOAT',
help="e.g., DT_HALF/DT_FLOAT")
parser.add_argument(
"--output_type",
type=str,
default="pb",
help="output type: source/pb")
parser.add_argument(
"--template",
type=str,
default="",
help="template path")
parser.add_argument(
"--obfuscate",
type=str2bool,
nargs='?',
const=False,
default=False,
help="obfuscate model names")
parser.add_argument(
"--model_tag",
type=str,
default="",
help="model tag for generated function and namespace")
parser.add_argument(
"--winograd",
type=str2bool,
nargs='?',
const=False,
default=False,
help="open winograd convolution or not")
parser.add_argument(
"--dsp_mode",
type=int,
default=0,
help="dsp run mode, defalut=0")
parser.add_argument(
"--input_shape",
type=str,
default="",
help="input shape.")
parser.add_argument(
"--embed_model_data",
type=str2bool,
default=True,
help="input shape.")
return parser.parse_known_args()
if __name__ == '__main__':
FLAGS, unparsed = parse_args()
main(unused_args=[sys.argv[0]] + unparsed)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册