diff --git a/python/tools/model.template b/python/tools/model.template index 068328775b3a16057f6ee8f5f7750cf6ede59a8f..032730531c2a955be0053399124d357eb2f4c155 100644 --- a/python/tools/model.template +++ b/python/tools/model.template @@ -263,6 +263,10 @@ NetDef CreateNet() { return net_def; } +const std::string ModelChecksum() { + return {{ model_pb_checksum|tojson }}; +} + } // namespace {{tag}} } // namespace mace {% endif %} diff --git a/python/tools/source_converter_lib.py b/python/tools/source_converter_lib.py index dcc8e5add559c73913e76c58b20acba9ce9daa8c..5a93d6afe92a54de9bd4f304777c2ab46eb955ee 100644 --- a/python/tools/source_converter_lib.py +++ b/python/tools/source_converter_lib.py @@ -86,7 +86,7 @@ class TensorInfo: def stringfy(value): return ', '.join('"{0}"'.format(w) for w in value) -def convert_to_source(net_def, template, obfuscate, model_tag, output, runtime): +def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, output, runtime): if obfuscate: obfuscate_name(net_def) else: @@ -140,6 +140,7 @@ def convert_to_source(net_def, template, obfuscate, model_tag, output, runtime): tag = model_tag, mode = 2, runtime = runtime, + model_pb_checksum = mode_pb_checksum, ) with gfile.GFile(output, "wb") as f: f.write(source) diff --git a/python/tools/tf_converter.py b/python/tools/tf_converter.py index 3f486ac2f3e3cfc8792d0daab6c3d85993ecc232..650dcd4c627f0ad287d4cfaf73c6136505ab405a 100644 --- a/python/tools/tf_converter.py +++ b/python/tools/tf_converter.py @@ -1,5 +1,6 @@ import argparse import sys +import hashlib import tensorflow as tf from tensorflow import gfile from lib.proto import mace_pb2 @@ -11,11 +12,19 @@ from lib.python.tools import source_converter_lib FLAGS = None +def md5(fname): + hash_md5 = hashlib.md5() + with open(fname, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + def main(unused_args): if not gfile.Exists(FLAGS.input): print("Input graph file '" + FLAGS.input + "' does not exist!") return -1 + mode_pb_checksum = md5(FLAGS.input) input_graph_def = tf.GraphDef() with gfile.Open(FLAGS.input, "rb") as f: data = f.read() @@ -29,7 +38,7 @@ def main(unused_args): input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime) if FLAGS.output_type == 'source': - source_converter_lib.convert_to_source(output_graph_def, FLAGS.template, FLAGS.obfuscate, + source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate, FLAGS.model_tag, FLAGS.output, FLAGS.runtime) else: with gfile.GFile(FLAGS.output, "wb") as f: