From 3aab734be475edf0a89b4bbf554d8d1883085ccc Mon Sep 17 00:00:00 2001 From: liuqi Date: Mon, 18 Dec 2017 13:41:17 +0800 Subject: [PATCH] Add model name confusion strategy. --- mace/python/tools/BUILD | 12 +++ mace/python/tools/model.template | 50 +++++---- mace/python/tools/source_converter_lib.py | 122 ++++++++++++++++++++++ mace/python/tools/tf_converter.py | 60 +++-------- 4 files changed, 177 insertions(+), 67 deletions(-) create mode 100644 mace/python/tools/source_converter_lib.py diff --git a/mace/python/tools/BUILD b/mace/python/tools/BUILD index 675f12ac..fbe406d3 100644 --- a/mace/python/tools/BUILD +++ b/mace/python/tools/BUILD @@ -13,12 +13,24 @@ py_library( ], ) +py_library( + name = "source_converter_lib", + srcs = [ + "source_converter_lib.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//mace/proto:mace_py", + ], +) + py_binary( name = "tf_converter", srcs = ["tf_converter.py"], srcs_version = "PY2AND3", deps = [ ":tf_converter_lib", + ":source_converter_lib", "@six_archive//:six", ], ) diff --git a/mace/python/tools/model.template b/mace/python/tools/model.template index 4f679087..0fcbcde4 100644 --- a/mace/python/tools/model.template +++ b/mace/python/tools/model.template @@ -1,30 +1,36 @@ // // Copyright (c) 2017 XiaoMi All rights reserved. +// Generated by the mace converter. DO NOT EDIT! // {% if mode == 0 %} -namespace mace { +namespace {{tag}}{ -alignas(4) unsigned char {{ "_" + tensor.name[:-2].replace("/", "_") }}[] = { +alignas(4) unsigned char {{ tensor.name }}[] = { {% for d in tensor.data %}{{"0x%02X, " % d }}{%endfor%} }; -} // namespace mace +} // namespace {{tag}} {% else %} #include #include #include "mace/core/mace.h" -namespace mace { + +namespace {{tag}} { {% for tensor in tensors %} -extern unsigned char {{ "_" + tensor.name[:-2].replace("/", "_") }}[]; +extern unsigned char {{ tensor.name }}[]; {% endfor %} +} // namespace {{ tag }} + +namespace { + {% if net.arg|length != 0 %} -static void CreateNetArg(NetDef &net_def) { +static void CreateNetArg(mace::NetDef &net_def) { net_def.mutable_arg().reserve({{ net.arg|length }}); - Argument *arg = nullptr; + mace::Argument *arg = nullptr; {% for arg in net.arg %} arg = net_def.add_arg(); @@ -57,13 +63,13 @@ static void CreateNetArg(NetDef &net_def) { } {% endif %} -static void UpdateOp(OperatorDef &op, +static void UpdateOp(mace::OperatorDef &op, const std::string &name, const std::string &type, const int mem_id, const std::vector &inputs, const std::vector &outputs, - const std::vector &output_types) { + const std::vector &output_types) { op.set_name(name); op.set_type(type); op.set_input(inputs); @@ -72,9 +78,9 @@ static void UpdateOp(OperatorDef &op, op.set_output_type(output_types); } -static void CreateOperators(std::vector &ops) { +static void CreateOperators(std::vector &ops) { ops.resize({{ net.op|length }}); - Argument *arg = nullptr; + mace::Argument *arg = nullptr; {% for i in range(net.op|length) %} {% for arg in net.op[i].arg %} @@ -103,7 +109,7 @@ static void CreateOperators(std::vector &ops) { {% endfor %} {% for shape in net.op[i].output_shape %} - ops[{{i}}].add_output_shape(OutputShape({ {{ shape.dims|join(', ') }} })); + ops[{{i}}].add_output_shape(mace::OutputShape({ {{ shape.dims|join(', ') }} })); {% endfor %} UpdateOp(ops[{{i}}], {{ net.op[i].name|tojson }}, {{ net.op[i].type|tojson}}, {{ net.op[i].mem_id }}, @@ -115,13 +121,13 @@ static void CreateOperators(std::vector &ops) { } -static void CreateTensors(std::vector &tensors) { +static void CreateTensors(std::vector &tensors) { tensors.reserve({{ net.tensors|length }}); {% for tensor in net.tensors %} - tensors.emplace_back(TensorProto( - {{ tensor.name|tojson }}, {{ "_" + tensor.name[:-2].replace("/", "_") }}, + tensors.emplace_back(mace::TensorProto( + {{ tensor.name|tojson }}, {{ tag + '::' + tensor.name }}, { {{ tensor.dims|join(', ') }} }, {{ tensor.data_type }}, {{ tensor.node_id }} )); @@ -132,20 +138,24 @@ static void CreateTensors(std::vector &tensors) { {% if net.mem_arena.mem_block|length != 0 %} -static void CreateMemoryArena(MemoryArena &mem_arena) { +static void CreateMemoryArena(mace::MemoryArena &mem_arena) { auto mem_block = mem_arena.mutable_mem_block(); mem_block.reserve({{ net.mem_arena.mem_block|length }}); {% for mem_blk in net.mem_arena.mem_block %} - mem_block.emplace_back(MemoryBlock({{ mem_blk.mem_id }}, - {{mem_blk.x}}, - {{mem_blk.y}})); + mem_block.emplace_back(mace::MemoryBlock({{ mem_blk.mem_id }}, + {{mem_blk.x}}, + {{mem_blk.y}})); {% endfor %} } {% endif %} -NetDef CreateNet() { +} + +namespace mace { + +NetDef {{'Create' + tag}}() { NetDef net_def; net_def.set_name("{{ net.name}}"); net_def.set_version("{{ net.version }}"); diff --git a/mace/python/tools/source_converter_lib.py b/mace/python/tools/source_converter_lib.py new file mode 100644 index 00000000..cd6495fd --- /dev/null +++ b/mace/python/tools/source_converter_lib.py @@ -0,0 +1,122 @@ +import struct +import os +import uuid + +from tensorflow import gfile +from mace.proto import mace_pb2 +from jinja2 import Environment, FileSystemLoader + + +GENERATED_NAME = set() + +def generate_random_name(): + name = '_' + uuid.uuid4().hex[:7].upper() + while name in GENERATED_NAME: + name = '_' + uuid.uuid4().hex[:7].upper() + GENERATED_NAME.add(name) + return name + +def generate_tensor_map(tensors): + tensor_map = {} + for t in tensors: + if not tensor_map.has_key(t.name): + tensor_map[t.name] = generate_random_name() + return tensor_map + +def generate_in_out_map(ops, tensor_map): + in_out_map = {} + for op in ops: + op.name = generate_random_name() + for input_name in op.input: + if not in_out_map.has_key(input_name): + if tensor_map.has_key(input_name): + in_out_map[input_name] = tensor_map[input_name] + else: + in_out_map[input_name] = generate_random_name() + for output_name in op.output: + if not in_out_map.has_key(output_name): + if tensor_map.has_key(output_name): + in_out_map[output_name] = tensor_map[output_name] + else: + in_out_map[output_name] = generate_random_name() + return in_out_map + +def confuse_name(net_def): + input_node = "mace_input_node" + output_node = "mace_output_node" + tensor_map = generate_tensor_map(net_def.tensors) + in_out_map = generate_in_out_map(net_def.op, tensor_map) + for t in net_def.tensors: + if input_node not in t.name and output_node not in t.name: + t.name = tensor_map[t.name] + for op in net_def.op: + for i in range(len(op.input)): + if input_node not in op.input[i]: + op.input[i] = in_out_map[op.input[i]] + for i in range(len(op.output)): + if output_node not in op.output[i]: + op.output[i] = in_out_map[op.output[i]] + +def rename_tensor(net_def): + tensor_map = {} + for t in net_def.tensors: + if not tensor_map.has_key(t.name): + tensor_map[t.name] = "_" + t.name[:-2].replace("/", "_") + t.name = tensor_map[t.name] + for op in net_def.op: + for i in range(len(op.input)): + if tensor_map.has_key(op.input[i]): + op.input[i] = tensor_map[op.input[i]] + for i in range(len(op.output)): + if tensor_map.has_key(op.output[i]): + op.output[i] = tensor_map[op.output[i]] + +class TensorInfo: + def __init__(self, t): + self.name = t.name + if t.data_type == mace_pb2.DT_FLOAT: + self.data = bytearray(struct.pack('%sf' % len(t.float_data), *t.float_data)) + elif t.data_type == mace_pb2.DT_INT32: + self.data = bytearray(struct.pack('%si' % len(t.int32_data), *t.int32_data)) + +def stringfy(value): + return ', '.join('"{0}"'.format(w) for w in value) + +def convert_to_source(net_def, template, confuse, model_tag, output): + if confuse: + confuse_name(net_def) + else: + rename_tensor(net_def) + + # Capture our current directory + template_dir = os.path.dirname(template) + template_name = os.path.basename(template) + print template_dir + + # Create the jinja2 environment. + j2_env = Environment(loader=FileSystemLoader(template_dir), + trim_blocks=True) + j2_env.filters['stringfy'] = stringfy + counter = 0 + output_dir = os.path.dirname(output) + '/' + # generate tensor source files + for t in net_def.tensors: + source = j2_env.get_template(template_name).render( + tensor = TensorInfo(t), + tag = model_tag, + mode = 0, + ) + with gfile.GFile(output_dir + str(counter) + '.cc', "wb") as f: + f.write(source) + counter += 1 + + # generate model source files + tensors = [TensorInfo(t) for t in net_def.tensors] + source = j2_env.get_template(template_name).render( + tensors = tensors, + net = net_def, + tag = model_tag, + mode = 1 + ) + with gfile.GFile(output, "wb") as f: + f.write(source) diff --git a/mace/python/tools/tf_converter.py b/mace/python/tools/tf_converter.py index 15ca54cf..1251bf55 100644 --- a/mace/python/tools/tf_converter.py +++ b/mace/python/tools/tf_converter.py @@ -5,57 +5,12 @@ from tensorflow import gfile from mace.proto import mace_pb2 from mace.python.tools import tf_converter_lib from mace.python.tools import tf_dsp_converter_lib -import struct -from jinja2 import Environment, FileSystemLoader -import os +from mace.python.tools import source_converter_lib # ./bazel-bin/mace/python/tools/tf_converter --input quantized_test.pb --output quantized_test_dsp.pb --runtime dsp --input_dim input_node,1,28,28,3 FLAGS = None -class TensorInfo: - def __init__(self, t): - self.name = t.name - if t.data_type == mace_pb2.DT_FLOAT: - self.data = bytearray(struct.pack('%sf' % len(t.float_data), *t.float_data)) - elif t.data_type == mace_pb2.DT_INT32: - self.data = bytearray(struct.pack('%si' % len(t.int32_data), *t.int32_data)) - -def stringfy(value): - return ', '.join('"{0}"'.format(w) for w in value) - -def convert_to_source(net_def): - # Capture our current directory - template_dir = os.path.dirname(FLAGS.template) - template_name = os.path.basename(FLAGS.template) - print template_dir - - # Create the jinja2 environment. - # Notice the use of trim_blocks, which greatly helps control whitespace. - j2_env = Environment(loader=FileSystemLoader(template_dir), - trim_blocks=True) - j2_env.filters['stringfy'] = stringfy - counter = 0 - output_dir = os.path.dirname(FLAGS.output) + '/' - for t in net_def.tensors: - source = j2_env.get_template(template_name).render( - tensor = TensorInfo(t), - mode = 0, - ) - with gfile.GFile(output_dir + str(counter) + '.cc', "wb") as f: - f.write(source) - counter += 1 - - - tensors = [TensorInfo(t) for t in net_def.tensors] - source = j2_env.get_template(template_name).render( - tensors = tensors, - net = net_def, - mode = 1 - ) - with gfile.GFile(FLAGS.output, "wb") as f: - f.write(source) - def main(unused_args): if not gfile.Exists(FLAGS.input): print("Input graph file '" + FLAGS.input + "' does not exist!") @@ -75,7 +30,8 @@ def main(unused_args): input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime) if FLAGS.output_type == 'source': - convert_to_source(output_graph_def) + source_converter_lib.convert_to_source(output_graph_def, FLAGS.template, FLAGS.confuse, + FLAGS.model_tag, FLAGS.output) else: with gfile.GFile(FLAGS.output, "wb") as f: f.write(output_graph_def.SerializeToString()) @@ -133,6 +89,16 @@ def parse_args(): type=str, default="", help="template path") + parser.add_argument( + "--confuse", + type=bool, + default=False, + help="confuse model names") + parser.add_argument( + "--model_tag", + type=str, + default="", + help="model tag for generated function and namespace") return parser.parse_known_args() -- GitLab