提交 c264d3b2 编写于 作者: Y yejianwu

merge commit in mace repo

上级 0e5ebc1f
...@@ -263,6 +263,10 @@ NetDef CreateNet() { ...@@ -263,6 +263,10 @@ NetDef CreateNet() {
return net_def; return net_def;
} }
const std::string ModelChecksum() {
return {{ model_pb_checksum|tojson }};
}
} // namespace {{tag}} } // namespace {{tag}}
} // namespace mace } // namespace mace
{% endif %} {% endif %}
...@@ -86,7 +86,7 @@ class TensorInfo: ...@@ -86,7 +86,7 @@ class TensorInfo:
def stringfy(value): def stringfy(value):
return ', '.join('"{0}"'.format(w) for w in value) return ', '.join('"{0}"'.format(w) for w in value)
def convert_to_source(net_def, template, obfuscate, model_tag, output, runtime): def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, output, runtime):
if obfuscate: if obfuscate:
obfuscate_name(net_def) obfuscate_name(net_def)
else: else:
...@@ -140,6 +140,7 @@ def convert_to_source(net_def, template, obfuscate, model_tag, output, runtime): ...@@ -140,6 +140,7 @@ def convert_to_source(net_def, template, obfuscate, model_tag, output, runtime):
tag = model_tag, tag = model_tag,
mode = 2, mode = 2,
runtime = runtime, runtime = runtime,
model_pb_checksum = mode_pb_checksum,
) )
with gfile.GFile(output, "wb") as f: with gfile.GFile(output, "wb") as f:
f.write(source) f.write(source)
import argparse import argparse
import sys import sys
import hashlib
import tensorflow as tf import tensorflow as tf
from tensorflow import gfile from tensorflow import gfile
from lib.proto import mace_pb2 from lib.proto import mace_pb2
...@@ -11,11 +12,19 @@ from lib.python.tools import source_converter_lib ...@@ -11,11 +12,19 @@ from lib.python.tools import source_converter_lib
FLAGS = None FLAGS = None
def md5(fname):
hash_md5 = hashlib.md5()
with open(fname, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def main(unused_args): def main(unused_args):
if not gfile.Exists(FLAGS.input): if not gfile.Exists(FLAGS.input):
print("Input graph file '" + FLAGS.input + "' does not exist!") print("Input graph file '" + FLAGS.input + "' does not exist!")
return -1 return -1
mode_pb_checksum = md5(FLAGS.input)
input_graph_def = tf.GraphDef() input_graph_def = tf.GraphDef()
with gfile.Open(FLAGS.input, "rb") as f: with gfile.Open(FLAGS.input, "rb") as f:
data = f.read() data = f.read()
...@@ -29,7 +38,7 @@ def main(unused_args): ...@@ -29,7 +38,7 @@ def main(unused_args):
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime) input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime)
if FLAGS.output_type == 'source': if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source(output_graph_def, FLAGS.template, FLAGS.obfuscate, source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate,
FLAGS.model_tag, FLAGS.output, FLAGS.runtime) FLAGS.model_tag, FLAGS.output, FLAGS.runtime)
else: else:
with gfile.GFile(FLAGS.output, "wb") as f: with gfile.GFile(FLAGS.output, "wb") as f:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册