From c264d3b2b6c26a0dd80b2811458033b312593dee Mon Sep 17 00:00:00 2001 From: yejianwu Date: Thu, 18 Jan 2018 20:18:06 +0800 Subject: [PATCH] merge commit in mace repo --- python/tools/model.template | 4 ++++ python/tools/source_converter_lib.py | 3 ++- python/tools/tf_converter.py | 11 ++++++++++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/python/tools/model.template b/python/tools/model.template index 06832877..03273053 100644 --- a/python/tools/model.template +++ b/python/tools/model.template @@ -263,6 +263,10 @@ NetDef CreateNet() { return net_def; } +const std::string ModelChecksum() { + return {{ model_pb_checksum|tojson }}; +} + } // namespace {{tag}} } // namespace mace {% endif %} diff --git a/python/tools/source_converter_lib.py b/python/tools/source_converter_lib.py index dcc8e5ad..5a93d6af 100644 --- a/python/tools/source_converter_lib.py +++ b/python/tools/source_converter_lib.py @@ -86,7 +86,7 @@ class TensorInfo: def stringfy(value): return ', '.join('"{0}"'.format(w) for w in value) -def convert_to_source(net_def, template, obfuscate, model_tag, output, runtime): +def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, output, runtime): if obfuscate: obfuscate_name(net_def) else: @@ -140,6 +140,7 @@ def convert_to_source(net_def, template, obfuscate, model_tag, output, runtime): tag = model_tag, mode = 2, runtime = runtime, + model_pb_checksum = mode_pb_checksum, ) with gfile.GFile(output, "wb") as f: f.write(source) diff --git a/python/tools/tf_converter.py b/python/tools/tf_converter.py index 3f486ac2..650dcd4c 100644 --- a/python/tools/tf_converter.py +++ b/python/tools/tf_converter.py @@ -1,5 +1,6 @@ import argparse import sys +import hashlib import tensorflow as tf from tensorflow import gfile from lib.proto import mace_pb2 @@ -11,11 +12,19 @@ from lib.python.tools import source_converter_lib FLAGS = None +def md5(fname): + hash_md5 = hashlib.md5() + with open(fname, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + def main(unused_args): if not gfile.Exists(FLAGS.input): print("Input graph file '" + FLAGS.input + "' does not exist!") return -1 + mode_pb_checksum = md5(FLAGS.input) input_graph_def = tf.GraphDef() with gfile.Open(FLAGS.input, "rb") as f: data = f.read() @@ -29,7 +38,7 @@ def main(unused_args): input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime) if FLAGS.output_type == 'source': - source_converter_lib.convert_to_source(output_graph_def, FLAGS.template, FLAGS.obfuscate, + source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate, FLAGS.model_tag, FLAGS.output, FLAGS.runtime) else: with gfile.GFile(FLAGS.output, "wb") as f: -- GitLab