提交 0e4a49a8 编写于 作者: L Liangliang He

Add model checksum and change build path format

上级 16f2f26c
...@@ -12,19 +12,23 @@ from lib.python.tools import source_converter_lib ...@@ -12,19 +12,23 @@ from lib.python.tools import source_converter_lib
FLAGS = None FLAGS = None
def md5(fname): def file_checksum(fname):
hash_md5 = hashlib.md5() hash_func = hashlib.sha256()
with open(fname, "rb") as f: with open(fname, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""): for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk) hash_func.update(chunk)
return hash_md5.hexdigest() return hash_func.hexdigest()
def main(unused_args): def main(unused_args):
if not gfile.Exists(FLAGS.input): if not gfile.Exists(FLAGS.input):
print("Input graph file '" + FLAGS.input + "' does not exist!") print("Input graph file '" + FLAGS.input + "' does not exist!")
return -1 return -1
mode_pb_checksum = md5(FLAGS.input) model_checksum = file_checksum(FLAGS.input)
if FLAGS.model_checksum != "" and FLAGS.model_checksum != model_checksum:
print("Model checksum mismatch: %s != %s" % (model_checksum, FLAGS.model_checksum))
return -1
input_graph_def = tf.GraphDef() input_graph_def = tf.GraphDef()
with gfile.Open(FLAGS.input, "rb") as f: with gfile.Open(FLAGS.input, "rb") as f:
data = f.read() data = f.read()
...@@ -42,7 +46,7 @@ def main(unused_args): ...@@ -42,7 +46,7 @@ def main(unused_args):
FLAGS.data_type, FLAGS.runtime, FLAGS.winograd) FLAGS.data_type, FLAGS.runtime, FLAGS.winograd)
if FLAGS.output_type == 'source': if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate, source_converter_lib.convert_to_source(output_graph_def, model_checksum, FLAGS.template, FLAGS.obfuscate,
FLAGS.model_tag, FLAGS.output, FLAGS.runtime, FLAGS.embed_model_data) FLAGS.model_tag, FLAGS.output, FLAGS.runtime, FLAGS.embed_model_data)
else: else:
with gfile.GFile(FLAGS.output, "wb") as f: with gfile.GFile(FLAGS.output, "wb") as f:
...@@ -69,6 +73,11 @@ def parse_args(): ...@@ -69,6 +73,11 @@ def parse_args():
type=str, type=str,
default="", default="",
help="TensorFlow \'GraphDef\' file to load.") help="TensorFlow \'GraphDef\' file to load.")
parser.add_argument(
"--model_checksum",
type=str,
default="",
help="Model file sha256 checksum")
parser.add_argument( parser.add_argument(
"--output", "--output",
type=str, type=str,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册