提交 4d43c38c 编写于 作者: L liuqi

Add model checksum api for track original model.

上级 1b871466
......@@ -32,6 +32,8 @@ namespace MACE_MODEL_TAG {
extern NetDef CreateNet();
extern const std::string ModelChecksum();
}
}
......@@ -140,6 +142,7 @@ int main(int argc, char **argv) {
VLOG(0) << "mace version: " << MaceVersion() << std::endl
<< "mace git version: " << MaceGitVersion() << std::endl
<< "model checksum: " << mace::MACE_MODEL_TAG::ModelChecksum() << std::endl
<< "input_shape: " << input_shape << std::endl
<< "output_shape: " << output_shape << std::endl
<< "input_file: " << input_file << std::endl
......
......@@ -263,6 +263,10 @@ NetDef CreateNet() {
return net_def;
}
const std::string ModelChecksum() {
return {{ model_pb_checksum|tojson }};
}
} // namespace {{tag}}
} // namespace mace
{% endif %}
......@@ -86,7 +86,7 @@ class TensorInfo:
def stringfy(value):
return ', '.join('"{0}"'.format(w) for w in value)
def convert_to_source(net_def, template, obfuscate, model_tag, output, runtime):
def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, output, runtime):
if obfuscate:
obfuscate_name(net_def)
else:
......@@ -140,6 +140,7 @@ def convert_to_source(net_def, template, obfuscate, model_tag, output, runtime):
tag = model_tag,
mode = 2,
runtime = runtime,
model_pb_checksum = mode_pb_checksum,
)
with gfile.GFile(output, "wb") as f:
f.write(source)
import argparse
import sys
import hashlib
import tensorflow as tf
from tensorflow import gfile
from mace.proto import mace_pb2
......@@ -11,11 +12,19 @@ from mace.python.tools import source_converter_lib
FLAGS = None
def md5(fname):
hash_md5 = hashlib.md5()
with open(fname, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def main(unused_args):
if not gfile.Exists(FLAGS.input):
print("Input graph file '" + FLAGS.input + "' does not exist!")
return -1
mode_pb_checksum = md5(FLAGS.input)
input_graph_def = tf.GraphDef()
with gfile.Open(FLAGS.input, "rb") as f:
data = f.read()
......@@ -29,7 +38,7 @@ def main(unused_args):
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime)
if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source(output_graph_def, FLAGS.template, FLAGS.obfuscate,
source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate,
FLAGS.model_tag, FLAGS.output, FLAGS.runtime)
else:
with gfile.GFile(FLAGS.output, "wb") as f:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册