提交 fd3eeb9e 编写于 作者: 叶剑武

Merge branch 'style' into 'master'

Misc code improvements

See merge request !325
......@@ -112,7 +112,8 @@ RUN pip install -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
RUN pip install -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com tensorflow==1.4.0 \
scipy \
jinja2 \
pyyaml
pyyaml \
sh
# Download tensorflow tools
RUN wget http://cnbj1-inner-fds.api.xiaomi.net/mace/tool/transform_graph && \
......
......@@ -64,6 +64,7 @@ std::unique_ptr<OperatorBase> OperatorRegistry::CreateOperator(
namespace ops {
// Keep in lexicographical order
extern void Register_Activation(OperatorRegistry *op_registry);
extern void Register_AddN(OperatorRegistry *op_registry);
extern void Register_BatchNorm(OperatorRegistry *op_registry);
......@@ -74,27 +75,29 @@ extern void Register_ChannelShuffle(OperatorRegistry *op_registry);
extern void Register_Concat(OperatorRegistry *op_registry);
extern void Register_Conv2D(OperatorRegistry *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry);
extern void Register_Eltwise(OperatorRegistry *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
extern void Register_FullyConnected(OperatorRegistry *op_registry);
extern void Register_FusedConv2D(OperatorRegistry *op_registry);
extern void Register_GlobalAvgPooling(OperatorRegistry *op_registry);
extern void Register_ImageToBuffer(OperatorRegistry *op_registry);
extern void Register_MatMul(OperatorRegistry *op_registry);
extern void Register_Pooling(OperatorRegistry *op_registry);
extern void Register_Proposal(OperatorRegistry *op_registry);
extern void Register_PSROIAlign(OperatorRegistry *op_registry);
extern void Register_Reshape(OperatorRegistry *op_registry);
extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
extern void Register_Slice(OperatorRegistry *op_registry);
extern void Register_Softmax(OperatorRegistry *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
extern void Register_MatMul(OperatorRegistry *op_registry);
extern void Register_WinogradTransform(OperatorRegistry *op_registry);
extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry);
extern void Register_Reshape(OperatorRegistry *op_registry);
extern void Register_Eltwise(OperatorRegistry *op_registry);
extern void Register_FullyConnected(OperatorRegistry *op_registry);
extern void Register_Slice(OperatorRegistry *op_registry);
extern void Register_Proposal(OperatorRegistry *op_registry);
extern void Register_PSROIAlign(OperatorRegistry *op_registry);
extern void Register_WinogradTransform(OperatorRegistry *op_registry);
} // namespace ops
OperatorRegistry::OperatorRegistry() {
// Keep in lexicographical order
ops::Register_Activation(this);
ops::Register_AddN(this);
ops::Register_BatchNorm(this);
......@@ -105,23 +108,23 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_Concat(this);
ops::Register_Conv2D(this);
ops::Register_DepthwiseConv2d(this);
ops::Register_Eltwise(this);
ops::Register_FoldedBatchNorm(this);
ops::Register_FullyConnected(this);
ops::Register_FusedConv2D(this);
ops::Register_GlobalAvgPooling(this);
ops::Register_ImageToBuffer(this);
ops::Register_MatMul(this);
ops::Register_Pooling(this);
ops::Register_Proposal(this);
ops::Register_PSROIAlign(this);
ops::Register_Reshape(this);
ops::Register_ResizeBilinear(this);
ops::Register_Slice(this);
ops::Register_Softmax(this);
ops::Register_SpaceToBatchND(this);
ops::Register_MatMul(this);
ops::Register_WinogradTransform(this);
ops::Register_WinogradInverseTransform(this);
ops::Register_Reshape(this);
ops::Register_Eltwise(this);
ops::Register_FullyConnected(this);
ops::Register_Slice(this);
ops::Register_Proposal(this);
ops::Register_PSROIAlign(this);
ops::Register_WinogradTransform(this);
}
} // namespace mace
......@@ -44,7 +44,7 @@ def generate_cpp_source():
idx += params_size
env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0]))
return env.get_template('str2vec_maps.cc.tmpl').render(
return env.get_template('str2vec_maps.cc.jinja2').render(
maps = data_map,
data_type = 'unsigned int',
variable_name = FLAGS.variable_name
......
......@@ -45,7 +45,7 @@ def main(unused_args):
encrypted_code_maps[file_name[:-3]] = encrypted_code_arr
env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0]))
cpp_cl_encrypted_kernel = env.get_template('str2vec_maps.cc.tmpl').render(
cpp_cl_encrypted_kernel = env.get_template('str2vec_maps.cc.jinja2').render(
maps=encrypted_code_maps,
data_type='unsigned char',
variable_name='kEncryptedProgramMap')
......
......@@ -14,13 +14,13 @@ namespace mace {
namespace {{tag}} {
{% for tensor in tensors %}
extern void CreateTensor{{ tensor.id }}(std::vector<mace::ConstTensor> &tensors,
extern void CreateTensor{{ tensor.id }}(std::vector<mace::ConstTensor> *tensors,
const unsigned char *model_data);
{% endfor %}
{% for i in range(net.op|length) %}
extern void CreateOperator{{i}}(mace::OperatorDef &op);
extern void CreateOperator{{i}}(mace::OperatorDef *op);
{% endfor %}
} // namespace {{ tag }}
......@@ -79,31 +79,30 @@ void CreateOutputInfo(mace::NetDef &net_def) {
}
{% endif %}
void CreateOperators(std::vector<mace::OperatorDef> &ops) {
void CreateOperators(std::vector<mace::OperatorDef> *ops) {
MACE_LATENCY_LOGGER(1, "Create operators");
ops.resize({{ net.op|length }});
{% for i in range(net.op|length) %}
ops->resize({{ net.op|length }});
mace::{{tag}}::CreateOperator{{i}}(ops[{{i}}]);
{% for i in range(net.op|length) %}
mace::{{tag}}::CreateOperator{{i}}(&ops->at({{i}}));
{% endfor %}
}
void CreateTensors(std::vector<mace::ConstTensor> &tensors,
void CreateTensors(std::vector<mace::ConstTensor> *tensors,
const unsigned char *model_data) {
MACE_LATENCY_LOGGER(1, "Create tensors");
tensors.reserve({{ net.tensors|length }});
tensors->reserve({{ net.tensors|length }});
{% for tensor in tensors %}
mace::{{tag}}::CreateTensor{{tensor.id}}(tensors, model_data);
{% endfor %}
}
{% if net.mem_arena.mem_block|length != 0 %}
void CreateMemoryArena(mace::MemoryArena &mem_arena) {
std::vector<mace::MemoryBlock> &mem_block = mem_arena.mutable_mem_block();
void CreateMemoryArena(mace::MemoryArena *mem_arena) {
std::vector<mace::MemoryBlock> &mem_block = mem_arena->mutable_mem_block();
mem_block.reserve({{ net.mem_arena.mem_block|length }});
{% for mem_blk in net.mem_arena.mem_block %}
......@@ -129,12 +128,12 @@ NetDef CreateNet(const unsigned char *model_data) {
CreateNetArg(net_def);
{% endif %}
CreateOperators(net_def.mutable_op());
CreateOperators(&net_def.mutable_op());
CreateTensors(net_def.mutable_tensors(), model_data);
CreateTensors(&net_def.mutable_tensors(), model_data);
{% if net.mem_arena.mem_block|length != 0 %}
CreateMemoryArena(net_def.mutable_mem_arena());
CreateMemoryArena(&net_def.mutable_mem_arena());
{% endif %}
{% if net.output_info | length > 0 %}
......
......@@ -31,7 +31,7 @@ def generate_cpp_source():
maps[file_name[:-4]].append(hex(ele))
env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0]))
return env.get_template('str2vec_maps.cc.tmpl').render(
return env.get_template('str2vec_maps.cc.jinja2').render(
maps = maps,
data_type = 'unsigned char',
variable_name = 'kCompiledProgramMap'
......
......@@ -13,7 +13,7 @@
namespace mace {
namespace {
void UpdateOp(mace::OperatorDef &op,
void UpdateOp(mace::OperatorDef *op,
const std::string &name,
const std::string &type,
const std::vector<std::string> &inputs,
......@@ -21,13 +21,13 @@ void UpdateOp(mace::OperatorDef &op,
const std::vector<mace::DataType> &output_types,
uint32_t node_id,
const std::vector<int> &mem_ids) {
op.set_name(name);
op.set_type(type);
op.set_input(inputs);
op.set_output(outputs);
op.set_output_type(output_types);
op.set_node_id(node_id);
op.set_mem_id(mem_ids);
op->set_name(name);
op->set_type(type);
op->set_input(inputs);
op->set_output(outputs);
op->set_output_type(output_types);
op->set_node_id(node_id);
op->set_mem_id(mem_ids);
}
} // namespace
......@@ -38,13 +38,13 @@ namespace {{tag}} {
{% for i in range(start, end) %}
void CreateOperator{{i}}(mace::OperatorDef &op) {
void CreateOperator{{i}}(mace::OperatorDef *op) {
MACE_LATENCY_LOGGER(2, "Create operator {{ net.op[i].name }}");
mace::Argument *arg = nullptr;
{% for arg in net.op[i].arg %}
arg = op.add_arg();
arg = op->add_arg();
arg->set_name({{ arg.name|tojson }});
{%- if arg.HasField('f') %}
......@@ -70,7 +70,7 @@ void CreateOperator{{i}}(mace::OperatorDef &op) {
{% for shape in net.op[i].output_shape %}
{% if shape.dims | length > 0 %}
op.add_output_shape(mace::OutputShape({ {{ shape.dims|join(', ') }} }));
op->add_output_shape(mace::OutputShape({ {{ shape.dims|join(', ') }} }));
{% endif %}
{% endfor %}
......@@ -87,20 +87,20 @@ void CreateOperator{{i}}(mace::OperatorDef &op) {
{ {{ net.op[i].mem_id | join(', ') }} });
{% if runtime == 'dsp' %}
op.set_padding({{ net.op[i].padding }});
op->set_padding({{ net.op[i].padding }});
{% if net.op[i].node_input | length > 0 %}
std::vector<int> input_node_ids({ {{ net.op[i].node_input | map(attribute='node_id') | join(', ') }} });
std::vector<int> input_output_ports({ {{ net.op[i].node_input | map(attribute='output_port') | join(', ')}} });
for (size_t i = 0; i < {{ net.op[i].node_input | length }}; ++i) {
mace::NodeInput input(input_node_ids[i], input_output_ports[i]);
op.add_node_input(input);
op->add_node_input(input);
}
{% endif %}
{% if net.op[i].out_max_byte_size | length > 0 %}
std::vector<int> out_max_byte_sizes {{ net.op[i].out_max_byte_size | replace('[', '{') | replace(']', '}') }};
for (size_t i = 0; i < {{ net.op[i].out_max_byte_size | length }}; ++i) {
op.add_out_max_byte_size(out_max_byte_sizes[i]);
op->add_out_max_byte_size(out_max_byte_sizes[i]);
}
{% endif %}
{% endif %}
......
......@@ -110,7 +110,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_
j2_env.filters['stringfy'] = stringfy
output_dir = os.path.dirname(output) + '/'
# generate tensor source files
template_name = 'tensor_source.template'
template_name = 'tensor_source.jinja2'
model_data = []
offset = 0
counter = 0
......@@ -135,7 +135,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_
counter += 1
# generate tensor data
template_name = 'tensor_data.template'
template_name = 'tensor_data.jinja2'
source = j2_env.get_template(template_name).render(
tag = model_tag,
embed_model_data = embed_model_data,
......@@ -150,7 +150,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_
f.close()
# generate op source files
template_name = 'operator.template'
template_name = 'operator.jinja2'
counter = 0
op_size = len(net_def.op)
for start in range(0, op_size, 10):
......@@ -166,7 +166,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_
counter += 1
# generate model source files
template_name = 'model.template'
template_name = 'model.jinja2'
tensors = [TensorInfo(i, net_def.tensors[i], runtime) for i in range(len(net_def.tensors))]
source = j2_env.get_template(template_name).render(
tensors = tensors,
......@@ -179,7 +179,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_
f.write(source)
# generate model header file
template_name = 'model_header.template'
template_name = 'model_header.jinja2'
source = j2_env.get_template(template_name).render(
tag = model_tag,
)
......
......@@ -10,7 +10,7 @@
namespace mace {
extern const std::map<std::string, std::vector<{{data_type}}>> {{variable_name}}=
extern const std::map<std::string, std::vector<{{data_type}}>> {{variable_name}} =
{
{% for key, value in maps.iteritems() %}
{
......@@ -24,4 +24,4 @@ extern const std::map<std::string, std::vector<{{data_type}}>> {{variable_name}}
{% endfor %}
};
} // namespace
} // namespace mace
......@@ -13,10 +13,10 @@
namespace mace {
namespace {{tag}} {
void CreateTensor{{tensor_info.id}}(std::vector<mace::ConstTensor> &tensors,
void CreateTensor{{tensor_info.id}}(std::vector<mace::ConstTensor> *tensors,
const unsigned char *model_data) {
MACE_LATENCY_LOGGER(2, "Create tensor {{ tensor.name }}");
tensors.emplace_back(mace::ConstTensor(
tensors->emplace_back(mace::ConstTensor(
{{ tensor.name|tojson }}, model_data + {{ offset }},
{ {{ tensor.dims|join(', ') }} }, {{ tensor_info.data_type }}, {{ tensor.node_id }}));
}
......
import sh
import re
def adb_split_stdout(stdout_str):
# Filter out last empty line
return [l.strip() for l in stdout_str.split('\n') if len(l.strip()) > 0]
def adb_devices():
outputs = sh.grep(sh.adb("devices"), "^[A-Za-z0-9]\+[[:space:]]\+device$")
raw_lists = sh.cut(outputs, "-f1")
return adb_split_stdout(raw_lists)
def adb_getprop_by_serialno(serialno):
outputs = sh.adb("-s", serialno, "shell", "getprop")
raw_props = adb_split_stdout(outputs)
props = {}
p = re.compile("\[(.+)\]: \[(.+)\]")
for raw_prop in raw_props:
m = p.match(raw_prop)
if m:
props[m.group(1)] = m.group(2)
return props
def adb_get_all_socs():
socs = []
for d in adb_devices():
props = adb_getprop_by_serialno(d)
socs.append(props["ro.product.board"])
return set(socs)
......@@ -15,6 +15,8 @@ import sys
import urllib
import yaml
import adb_tools
from ConfigParser import ConfigParser
......@@ -201,6 +203,11 @@ def parse_args():
type=str,
default="all",
help="[build|run|validate|merge|all|throughput_test].")
parser.add_argument(
"--socs",
type=str,
default="all",
help="SoCs to build, comma seperated list (getprop ro.board.platform)")
return parser.parse_known_args()
......@@ -227,7 +234,21 @@ def main(unused_args):
generate_opencl_and_version_code()
option_args = ' '.join([arg for arg in unused_args if arg.startswith('--')])
for target_soc in configs["target_socs"]:
available_socs = adb_tools.adb_get_all_socs()
target_socs = available_socs
if hasattr(configs, "target_socs"):
target_socs = set(configs["target_socs"])
target_socs = target_socs & available_socs
if FLAGS.socs != "all":
socs = set(FLAGS.socs.split(','))
target_socs = target_socs & socs
missing_socs = socs.difference(target_socs)
if len(missing_socs) > 0:
print("Error: devices with SoCs are not connected %s" % missing_socs)
exit(1)
for target_soc in target_socs:
for target_abi in configs["target_abis"]:
global_runtime = get_global_runtime(configs)
# Transfer params by environment
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册