diff --git a/mace/examples/mace_run.cc b/mace/examples/mace_run.cc index 2d57d7f794922ff001327fca7928768af25f8f11..6e2d2433bf0345468d91c7a16496efdcf315d99c 100644 --- a/mace/examples/mace_run.cc +++ b/mace/examples/mace_run.cc @@ -32,6 +32,8 @@ namespace MACE_MODEL_TAG { extern NetDef CreateNet(); +extern const std::string ModelChecksum(); + } } @@ -140,6 +142,7 @@ int main(int argc, char **argv) { VLOG(0) << "mace version: " << MaceVersion() << std::endl << "mace git version: " << MaceGitVersion() << std::endl + << "model checksum: " << mace::MACE_MODEL_TAG::ModelChecksum() << std::endl << "input_shape: " << input_shape << std::endl << "output_shape: " << output_shape << std::endl << "input_file: " << input_file << std::endl diff --git a/mace/python/tools/model.template b/mace/python/tools/model.template index 068328775b3a16057f6ee8f5f7750cf6ede59a8f..032730531c2a955be0053399124d357eb2f4c155 100644 --- a/mace/python/tools/model.template +++ b/mace/python/tools/model.template @@ -263,6 +263,10 @@ NetDef CreateNet() { return net_def; } +const std::string ModelChecksum() { + return {{ model_pb_checksum|tojson }}; +} + } // namespace {{tag}} } // namespace mace {% endif %} diff --git a/mace/python/tools/source_converter_lib.py b/mace/python/tools/source_converter_lib.py index c6be3a3e8ce6cabd8550d9e2e4095224be54c6bd..b9db39110956181864ff793722301756cc1f4190 100644 --- a/mace/python/tools/source_converter_lib.py +++ b/mace/python/tools/source_converter_lib.py @@ -86,7 +86,7 @@ class TensorInfo: def stringfy(value): return ', '.join('"{0}"'.format(w) for w in value) -def convert_to_source(net_def, template, obfuscate, model_tag, output, runtime): +def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, output, runtime): if obfuscate: obfuscate_name(net_def) else: @@ -140,6 +140,7 @@ def convert_to_source(net_def, template, obfuscate, model_tag, output, runtime): tag = model_tag, mode = 2, runtime = runtime, + model_pb_checksum = mode_pb_checksum, ) with gfile.GFile(output, "wb") as f: f.write(source) diff --git a/mace/python/tools/tf_converter.py b/mace/python/tools/tf_converter.py index 303fd143ba90d48557925939fe5dffe48d502c90..1f02fc28e9269f0480a3b53ef80d786df60f0134 100644 --- a/mace/python/tools/tf_converter.py +++ b/mace/python/tools/tf_converter.py @@ -1,5 +1,6 @@ import argparse import sys +import hashlib import tensorflow as tf from tensorflow import gfile from mace.proto import mace_pb2 @@ -11,11 +12,19 @@ from mace.python.tools import source_converter_lib FLAGS = None +def md5(fname): + hash_md5 = hashlib.md5() + with open(fname, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + def main(unused_args): if not gfile.Exists(FLAGS.input): print("Input graph file '" + FLAGS.input + "' does not exist!") return -1 + mode_pb_checksum = md5(FLAGS.input) input_graph_def = tf.GraphDef() with gfile.Open(FLAGS.input, "rb") as f: data = f.read() @@ -29,7 +38,7 @@ def main(unused_args): input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime) if FLAGS.output_type == 'source': - source_converter_lib.convert_to_source(output_graph_def, FLAGS.template, FLAGS.obfuscate, + source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate, FLAGS.model_tag, FLAGS.output, FLAGS.runtime) else: with gfile.GFile(FLAGS.output, "wb") as f: