提交 3aab734b 编写于 作者: L liuqi

Add model name confusion strategy.

上级 4a9ea4a3
......@@ -13,12 +13,24 @@ py_library(
],
)
py_library(
name = "source_converter_lib",
srcs = [
"source_converter_lib.py",
],
srcs_version = "PY2AND3",
deps = [
"//mace/proto:mace_py",
],
)
py_binary(
name = "tf_converter",
srcs = ["tf_converter.py"],
srcs_version = "PY2AND3",
deps = [
":tf_converter_lib",
":source_converter_lib",
"@six_archive//:six",
],
)
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
// Generated by the mace converter. DO NOT EDIT!
//
{% if mode == 0 %}
namespace mace {
namespace {{tag}}{
alignas(4) unsigned char {{ "_" + tensor.name[:-2].replace("/", "_") }}[] = {
alignas(4) unsigned char {{ tensor.name }}[] = {
{% for d in tensor.data %}{{"0x%02X, " % d }}{%endfor%}
};
} // namespace mace
} // namespace {{tag}}
{% else %}
#include <vector>
#include <string>
#include "mace/core/mace.h"
namespace mace {
namespace {{tag}} {
{% for tensor in tensors %}
extern unsigned char {{ "_" + tensor.name[:-2].replace("/", "_") }}[];
extern unsigned char {{ tensor.name }}[];
{% endfor %}
} // namespace {{ tag }}
namespace {
{% if net.arg|length != 0 %}
static void CreateNetArg(NetDef &net_def) {
static void CreateNetArg(mace::NetDef &net_def) {
net_def.mutable_arg().reserve({{ net.arg|length }});
Argument *arg = nullptr;
mace::Argument *arg = nullptr;
{% for arg in net.arg %}
arg = net_def.add_arg();
......@@ -57,13 +63,13 @@ static void CreateNetArg(NetDef &net_def) {
}
{% endif %}
static void UpdateOp(OperatorDef &op,
static void UpdateOp(mace::OperatorDef &op,
const std::string &name,
const std::string &type,
const int mem_id,
const std::vector<std::string> &inputs,
const std::vector<std::string> &outputs,
const std::vector<DataType> &output_types) {
const std::vector<mace::DataType> &output_types) {
op.set_name(name);
op.set_type(type);
op.set_input(inputs);
......@@ -72,9 +78,9 @@ static void UpdateOp(OperatorDef &op,
op.set_output_type(output_types);
}
static void CreateOperators(std::vector<OperatorDef> &ops) {
static void CreateOperators(std::vector<mace::OperatorDef> &ops) {
ops.resize({{ net.op|length }});
Argument *arg = nullptr;
mace::Argument *arg = nullptr;
{% for i in range(net.op|length) %}
{% for arg in net.op[i].arg %}
......@@ -103,7 +109,7 @@ static void CreateOperators(std::vector<OperatorDef> &ops) {
{% endfor %}
{% for shape in net.op[i].output_shape %}
ops[{{i}}].add_output_shape(OutputShape({ {{ shape.dims|join(', ') }} }));
ops[{{i}}].add_output_shape(mace::OutputShape({ {{ shape.dims|join(', ') }} }));
{% endfor %}
UpdateOp(ops[{{i}}], {{ net.op[i].name|tojson }}, {{ net.op[i].type|tojson}}, {{ net.op[i].mem_id }},
......@@ -115,13 +121,13 @@ static void CreateOperators(std::vector<OperatorDef> &ops) {
}
static void CreateTensors(std::vector<TensorProto> &tensors) {
static void CreateTensors(std::vector<mace::TensorProto> &tensors) {
tensors.reserve({{ net.tensors|length }});
{% for tensor in net.tensors %}
tensors.emplace_back(TensorProto(
{{ tensor.name|tojson }}, {{ "_" + tensor.name[:-2].replace("/", "_") }},
tensors.emplace_back(mace::TensorProto(
{{ tensor.name|tojson }}, {{ tag + '::' + tensor.name }},
{ {{ tensor.dims|join(', ') }} }, {{ tensor.data_type }},
{{ tensor.node_id }}
));
......@@ -132,20 +138,24 @@ static void CreateTensors(std::vector<TensorProto> &tensors) {
{% if net.mem_arena.mem_block|length != 0 %}
static void CreateMemoryArena(MemoryArena &mem_arena) {
static void CreateMemoryArena(mace::MemoryArena &mem_arena) {
auto mem_block = mem_arena.mutable_mem_block();
mem_block.reserve({{ net.mem_arena.mem_block|length }});
{% for mem_blk in net.mem_arena.mem_block %}
mem_block.emplace_back(MemoryBlock({{ mem_blk.mem_id }},
{{mem_blk.x}},
{{mem_blk.y}}));
mem_block.emplace_back(mace::MemoryBlock({{ mem_blk.mem_id }},
{{mem_blk.x}},
{{mem_blk.y}}));
{% endfor %}
}
{% endif %}
NetDef CreateNet() {
}
namespace mace {
NetDef {{'Create' + tag}}() {
NetDef net_def;
net_def.set_name("{{ net.name}}");
net_def.set_version("{{ net.version }}");
......
import struct
import os
import uuid
from tensorflow import gfile
from mace.proto import mace_pb2
from jinja2 import Environment, FileSystemLoader
GENERATED_NAME = set()
def generate_random_name():
name = '_' + uuid.uuid4().hex[:7].upper()
while name in GENERATED_NAME:
name = '_' + uuid.uuid4().hex[:7].upper()
GENERATED_NAME.add(name)
return name
def generate_tensor_map(tensors):
tensor_map = {}
for t in tensors:
if not tensor_map.has_key(t.name):
tensor_map[t.name] = generate_random_name()
return tensor_map
def generate_in_out_map(ops, tensor_map):
in_out_map = {}
for op in ops:
op.name = generate_random_name()
for input_name in op.input:
if not in_out_map.has_key(input_name):
if tensor_map.has_key(input_name):
in_out_map[input_name] = tensor_map[input_name]
else:
in_out_map[input_name] = generate_random_name()
for output_name in op.output:
if not in_out_map.has_key(output_name):
if tensor_map.has_key(output_name):
in_out_map[output_name] = tensor_map[output_name]
else:
in_out_map[output_name] = generate_random_name()
return in_out_map
def confuse_name(net_def):
input_node = "mace_input_node"
output_node = "mace_output_node"
tensor_map = generate_tensor_map(net_def.tensors)
in_out_map = generate_in_out_map(net_def.op, tensor_map)
for t in net_def.tensors:
if input_node not in t.name and output_node not in t.name:
t.name = tensor_map[t.name]
for op in net_def.op:
for i in range(len(op.input)):
if input_node not in op.input[i]:
op.input[i] = in_out_map[op.input[i]]
for i in range(len(op.output)):
if output_node not in op.output[i]:
op.output[i] = in_out_map[op.output[i]]
def rename_tensor(net_def):
tensor_map = {}
for t in net_def.tensors:
if not tensor_map.has_key(t.name):
tensor_map[t.name] = "_" + t.name[:-2].replace("/", "_")
t.name = tensor_map[t.name]
for op in net_def.op:
for i in range(len(op.input)):
if tensor_map.has_key(op.input[i]):
op.input[i] = tensor_map[op.input[i]]
for i in range(len(op.output)):
if tensor_map.has_key(op.output[i]):
op.output[i] = tensor_map[op.output[i]]
class TensorInfo:
def __init__(self, t):
self.name = t.name
if t.data_type == mace_pb2.DT_FLOAT:
self.data = bytearray(struct.pack('%sf' % len(t.float_data), *t.float_data))
elif t.data_type == mace_pb2.DT_INT32:
self.data = bytearray(struct.pack('%si' % len(t.int32_data), *t.int32_data))
def stringfy(value):
return ', '.join('"{0}"'.format(w) for w in value)
def convert_to_source(net_def, template, confuse, model_tag, output):
if confuse:
confuse_name(net_def)
else:
rename_tensor(net_def)
# Capture our current directory
template_dir = os.path.dirname(template)
template_name = os.path.basename(template)
print template_dir
# Create the jinja2 environment.
j2_env = Environment(loader=FileSystemLoader(template_dir),
trim_blocks=True)
j2_env.filters['stringfy'] = stringfy
counter = 0
output_dir = os.path.dirname(output) + '/'
# generate tensor source files
for t in net_def.tensors:
source = j2_env.get_template(template_name).render(
tensor = TensorInfo(t),
tag = model_tag,
mode = 0,
)
with gfile.GFile(output_dir + str(counter) + '.cc', "wb") as f:
f.write(source)
counter += 1
# generate model source files
tensors = [TensorInfo(t) for t in net_def.tensors]
source = j2_env.get_template(template_name).render(
tensors = tensors,
net = net_def,
tag = model_tag,
mode = 1
)
with gfile.GFile(output, "wb") as f:
f.write(source)
......@@ -5,57 +5,12 @@ from tensorflow import gfile
from mace.proto import mace_pb2
from mace.python.tools import tf_converter_lib
from mace.python.tools import tf_dsp_converter_lib
import struct
from jinja2 import Environment, FileSystemLoader
import os
from mace.python.tools import source_converter_lib
# ./bazel-bin/mace/python/tools/tf_converter --input quantized_test.pb --output quantized_test_dsp.pb --runtime dsp --input_dim input_node,1,28,28,3
FLAGS = None
class TensorInfo:
def __init__(self, t):
self.name = t.name
if t.data_type == mace_pb2.DT_FLOAT:
self.data = bytearray(struct.pack('%sf' % len(t.float_data), *t.float_data))
elif t.data_type == mace_pb2.DT_INT32:
self.data = bytearray(struct.pack('%si' % len(t.int32_data), *t.int32_data))
def stringfy(value):
return ', '.join('"{0}"'.format(w) for w in value)
def convert_to_source(net_def):
# Capture our current directory
template_dir = os.path.dirname(FLAGS.template)
template_name = os.path.basename(FLAGS.template)
print template_dir
# Create the jinja2 environment.
# Notice the use of trim_blocks, which greatly helps control whitespace.
j2_env = Environment(loader=FileSystemLoader(template_dir),
trim_blocks=True)
j2_env.filters['stringfy'] = stringfy
counter = 0
output_dir = os.path.dirname(FLAGS.output) + '/'
for t in net_def.tensors:
source = j2_env.get_template(template_name).render(
tensor = TensorInfo(t),
mode = 0,
)
with gfile.GFile(output_dir + str(counter) + '.cc', "wb") as f:
f.write(source)
counter += 1
tensors = [TensorInfo(t) for t in net_def.tensors]
source = j2_env.get_template(template_name).render(
tensors = tensors,
net = net_def,
mode = 1
)
with gfile.GFile(FLAGS.output, "wb") as f:
f.write(source)
def main(unused_args):
if not gfile.Exists(FLAGS.input):
print("Input graph file '" + FLAGS.input + "' does not exist!")
......@@ -75,7 +30,8 @@ def main(unused_args):
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime)
if FLAGS.output_type == 'source':
convert_to_source(output_graph_def)
source_converter_lib.convert_to_source(output_graph_def, FLAGS.template, FLAGS.confuse,
FLAGS.model_tag, FLAGS.output)
else:
with gfile.GFile(FLAGS.output, "wb") as f:
f.write(output_graph_def.SerializeToString())
......@@ -133,6 +89,16 @@ def parse_args():
type=str,
default="",
help="template path")
parser.add_argument(
"--confuse",
type=bool,
default=False,
help="confuse model names")
parser.add_argument(
"--model_tag",
type=str,
default="",
help="model tag for generated function and namespace")
return parser.parse_known_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册