提交 b7145c73 编写于 作者: L Liangliang He

Misc code improvements

上级 5efffb7c
...@@ -112,7 +112,8 @@ RUN pip install -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com ...@@ -112,7 +112,8 @@ RUN pip install -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com
RUN pip install -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com tensorflow==1.4.0 \ RUN pip install -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com tensorflow==1.4.0 \
scipy \ scipy \
jinja2 \ jinja2 \
pyyaml pyyaml \
sh
# Download tensorflow tools # Download tensorflow tools
RUN wget http://cnbj1-inner-fds.api.xiaomi.net/mace/tool/transform_graph && \ RUN wget http://cnbj1-inner-fds.api.xiaomi.net/mace/tool/transform_graph && \
......
...@@ -44,7 +44,7 @@ def generate_cpp_source(): ...@@ -44,7 +44,7 @@ def generate_cpp_source():
idx += params_size idx += params_size
env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0])) env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0]))
return env.get_template('str2vec_maps.cc.tmpl').render( return env.get_template('str2vec_maps.cc.jinja2').render(
maps = data_map, maps = data_map,
data_type = 'unsigned int', data_type = 'unsigned int',
variable_name = FLAGS.variable_name variable_name = FLAGS.variable_name
......
...@@ -45,7 +45,7 @@ def main(unused_args): ...@@ -45,7 +45,7 @@ def main(unused_args):
encrypted_code_maps[file_name[:-3]] = encrypted_code_arr encrypted_code_maps[file_name[:-3]] = encrypted_code_arr
env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0])) env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0]))
cpp_cl_encrypted_kernel = env.get_template('str2vec_maps.cc.tmpl').render( cpp_cl_encrypted_kernel = env.get_template('str2vec_maps.cc.jinja2').render(
maps=encrypted_code_maps, maps=encrypted_code_maps,
data_type='unsigned char', data_type='unsigned char',
variable_name='kEncryptedProgramMap') variable_name='kEncryptedProgramMap')
......
...@@ -14,13 +14,13 @@ namespace mace { ...@@ -14,13 +14,13 @@ namespace mace {
namespace {{tag}} { namespace {{tag}} {
{% for tensor in tensors %} {% for tensor in tensors %}
extern void CreateTensor{{ tensor.id }}(std::vector<mace::ConstTensor> &tensors, extern void CreateTensor{{ tensor.id }}(std::vector<mace::ConstTensor> *tensors,
const unsigned char *model_data); const unsigned char *model_data);
{% endfor %} {% endfor %}
{% for i in range(net.op|length) %} {% for i in range(net.op|length) %}
extern void CreateOperator{{i}}(mace::OperatorDef &op); extern void CreateOperator{{i}}(mace::OperatorDef *op);
{% endfor %} {% endfor %}
} // namespace {{ tag }} } // namespace {{ tag }}
...@@ -79,31 +79,30 @@ void CreateOutputInfo(mace::NetDef &net_def) { ...@@ -79,31 +79,30 @@ void CreateOutputInfo(mace::NetDef &net_def) {
} }
{% endif %} {% endif %}
void CreateOperators(std::vector<mace::OperatorDef> &ops) { void CreateOperators(std::vector<mace::OperatorDef> *ops) {
MACE_LATENCY_LOGGER(1, "Create operators"); MACE_LATENCY_LOGGER(1, "Create operators");
ops.resize({{ net.op|length }}); ops->resize({{ net.op|length }});
{% for i in range(net.op|length) %}
mace::{{tag}}::CreateOperator{{i}}(ops[{{i}}]); {% for i in range(net.op|length) %}
mace::{{tag}}::CreateOperator{{i}}(&ops->at({{i}}));
{% endfor %} {% endfor %}
} }
void CreateTensors(std::vector<mace::ConstTensor> &tensors, void CreateTensors(std::vector<mace::ConstTensor> *tensors,
const unsigned char *model_data) { const unsigned char *model_data) {
MACE_LATENCY_LOGGER(1, "Create tensors"); MACE_LATENCY_LOGGER(1, "Create tensors");
tensors.reserve({{ net.tensors|length }}); tensors->reserve({{ net.tensors|length }});
{% for tensor in tensors %} {% for tensor in tensors %}
mace::{{tag}}::CreateTensor{{tensor.id}}(tensors, model_data); mace::{{tag}}::CreateTensor{{tensor.id}}(tensors, model_data);
{% endfor %} {% endfor %}
} }
{% if net.mem_arena.mem_block|length != 0 %} {% if net.mem_arena.mem_block|length != 0 %}
void CreateMemoryArena(mace::MemoryArena &mem_arena) { void CreateMemoryArena(mace::MemoryArena *mem_arena) {
std::vector<mace::MemoryBlock> &mem_block = mem_arena.mutable_mem_block(); std::vector<mace::MemoryBlock> &mem_block = mem_arena->mutable_mem_block();
mem_block.reserve({{ net.mem_arena.mem_block|length }}); mem_block.reserve({{ net.mem_arena.mem_block|length }});
{% for mem_blk in net.mem_arena.mem_block %} {% for mem_blk in net.mem_arena.mem_block %}
...@@ -129,12 +128,12 @@ NetDef CreateNet(const unsigned char *model_data) { ...@@ -129,12 +128,12 @@ NetDef CreateNet(const unsigned char *model_data) {
CreateNetArg(net_def); CreateNetArg(net_def);
{% endif %} {% endif %}
CreateOperators(net_def.mutable_op()); CreateOperators(&net_def.mutable_op());
CreateTensors(net_def.mutable_tensors(), model_data); CreateTensors(&net_def.mutable_tensors(), model_data);
{% if net.mem_arena.mem_block|length != 0 %} {% if net.mem_arena.mem_block|length != 0 %}
CreateMemoryArena(net_def.mutable_mem_arena()); CreateMemoryArena(&net_def.mutable_mem_arena());
{% endif %} {% endif %}
{% if net.output_info | length > 0 %} {% if net.output_info | length > 0 %}
......
...@@ -31,7 +31,7 @@ def generate_cpp_source(): ...@@ -31,7 +31,7 @@ def generate_cpp_source():
maps[file_name[:-4]].append(hex(ele)) maps[file_name[:-4]].append(hex(ele))
env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0])) env = jinja2.Environment(loader=jinja2.FileSystemLoader(sys.path[0]))
return env.get_template('str2vec_maps.cc.tmpl').render( return env.get_template('str2vec_maps.cc.jinja2').render(
maps = maps, maps = maps,
data_type = 'unsigned char', data_type = 'unsigned char',
variable_name = 'kCompiledProgramMap' variable_name = 'kCompiledProgramMap'
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
namespace mace { namespace mace {
namespace { namespace {
void UpdateOp(mace::OperatorDef &op, void UpdateOp(mace::OperatorDef *op,
const std::string &name, const std::string &name,
const std::string &type, const std::string &type,
const std::vector<std::string> &inputs, const std::vector<std::string> &inputs,
...@@ -21,13 +21,13 @@ void UpdateOp(mace::OperatorDef &op, ...@@ -21,13 +21,13 @@ void UpdateOp(mace::OperatorDef &op,
const std::vector<mace::DataType> &output_types, const std::vector<mace::DataType> &output_types,
uint32_t node_id, uint32_t node_id,
const std::vector<int> &mem_ids) { const std::vector<int> &mem_ids) {
op.set_name(name); op->set_name(name);
op.set_type(type); op->set_type(type);
op.set_input(inputs); op->set_input(inputs);
op.set_output(outputs); op->set_output(outputs);
op.set_output_type(output_types); op->set_output_type(output_types);
op.set_node_id(node_id); op->set_node_id(node_id);
op.set_mem_id(mem_ids); op->set_mem_id(mem_ids);
} }
} // namespace } // namespace
...@@ -38,13 +38,13 @@ namespace {{tag}} { ...@@ -38,13 +38,13 @@ namespace {{tag}} {
{% for i in range(start, end) %} {% for i in range(start, end) %}
void CreateOperator{{i}}(mace::OperatorDef &op) { void CreateOperator{{i}}(mace::OperatorDef *op) {
MACE_LATENCY_LOGGER(2, "Create operator {{ net.op[i].name }}"); MACE_LATENCY_LOGGER(2, "Create operator {{ net.op[i].name }}");
mace::Argument *arg = nullptr; mace::Argument *arg = nullptr;
{% for arg in net.op[i].arg %} {% for arg in net.op[i].arg %}
arg = op.add_arg(); arg = op->add_arg();
arg->set_name({{ arg.name|tojson }}); arg->set_name({{ arg.name|tojson }});
{%- if arg.HasField('f') %} {%- if arg.HasField('f') %}
...@@ -70,7 +70,7 @@ void CreateOperator{{i}}(mace::OperatorDef &op) { ...@@ -70,7 +70,7 @@ void CreateOperator{{i}}(mace::OperatorDef &op) {
{% for shape in net.op[i].output_shape %} {% for shape in net.op[i].output_shape %}
{% if shape.dims | length > 0 %} {% if shape.dims | length > 0 %}
op.add_output_shape(mace::OutputShape({ {{ shape.dims|join(', ') }} })); op->add_output_shape(mace::OutputShape({ {{ shape.dims|join(', ') }} }));
{% endif %} {% endif %}
{% endfor %} {% endfor %}
...@@ -87,20 +87,20 @@ void CreateOperator{{i}}(mace::OperatorDef &op) { ...@@ -87,20 +87,20 @@ void CreateOperator{{i}}(mace::OperatorDef &op) {
{ {{ net.op[i].mem_id | join(', ') }} }); { {{ net.op[i].mem_id | join(', ') }} });
{% if runtime == 'dsp' %} {% if runtime == 'dsp' %}
op.set_padding({{ net.op[i].padding }}); op->set_padding({{ net.op[i].padding }});
{% if net.op[i].node_input | length > 0 %} {% if net.op[i].node_input | length > 0 %}
std::vector<int> input_node_ids({ {{ net.op[i].node_input | map(attribute='node_id') | join(', ') }} }); std::vector<int> input_node_ids({ {{ net.op[i].node_input | map(attribute='node_id') | join(', ') }} });
std::vector<int> input_output_ports({ {{ net.op[i].node_input | map(attribute='output_port') | join(', ')}} }); std::vector<int> input_output_ports({ {{ net.op[i].node_input | map(attribute='output_port') | join(', ')}} });
for (size_t i = 0; i < {{ net.op[i].node_input | length }}; ++i) { for (size_t i = 0; i < {{ net.op[i].node_input | length }}; ++i) {
mace::NodeInput input(input_node_ids[i], input_output_ports[i]); mace::NodeInput input(input_node_ids[i], input_output_ports[i]);
op.add_node_input(input); op->add_node_input(input);
} }
{% endif %} {% endif %}
{% if net.op[i].out_max_byte_size | length > 0 %} {% if net.op[i].out_max_byte_size | length > 0 %}
std::vector<int> out_max_byte_sizes {{ net.op[i].out_max_byte_size | replace('[', '{') | replace(']', '}') }}; std::vector<int> out_max_byte_sizes {{ net.op[i].out_max_byte_size | replace('[', '{') | replace(']', '}') }};
for (size_t i = 0; i < {{ net.op[i].out_max_byte_size | length }}; ++i) { for (size_t i = 0; i < {{ net.op[i].out_max_byte_size | length }}; ++i) {
op.add_out_max_byte_size(out_max_byte_sizes[i]); op->add_out_max_byte_size(out_max_byte_sizes[i]);
} }
{% endif %} {% endif %}
{% endif %} {% endif %}
......
...@@ -110,7 +110,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_ ...@@ -110,7 +110,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_
j2_env.filters['stringfy'] = stringfy j2_env.filters['stringfy'] = stringfy
output_dir = os.path.dirname(output) + '/' output_dir = os.path.dirname(output) + '/'
# generate tensor source files # generate tensor source files
template_name = 'tensor_source.template' template_name = 'tensor_source.jinja2'
model_data = [] model_data = []
offset = 0 offset = 0
counter = 0 counter = 0
...@@ -135,7 +135,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_ ...@@ -135,7 +135,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_
counter += 1 counter += 1
# generate tensor data # generate tensor data
template_name = 'tensor_data.template' template_name = 'tensor_data.jinja2'
source = j2_env.get_template(template_name).render( source = j2_env.get_template(template_name).render(
tag = model_tag, tag = model_tag,
embed_model_data = embed_model_data, embed_model_data = embed_model_data,
...@@ -150,7 +150,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_ ...@@ -150,7 +150,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_
f.close() f.close()
# generate op source files # generate op source files
template_name = 'operator.template' template_name = 'operator.jinja2'
counter = 0 counter = 0
op_size = len(net_def.op) op_size = len(net_def.op)
for start in range(0, op_size, 10): for start in range(0, op_size, 10):
...@@ -166,7 +166,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_ ...@@ -166,7 +166,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_
counter += 1 counter += 1
# generate model source files # generate model source files
template_name = 'model.template' template_name = 'model.jinja2'
tensors = [TensorInfo(i, net_def.tensors[i], runtime) for i in range(len(net_def.tensors))] tensors = [TensorInfo(i, net_def.tensors[i], runtime) for i in range(len(net_def.tensors))]
source = j2_env.get_template(template_name).render( source = j2_env.get_template(template_name).render(
tensors = tensors, tensors = tensors,
...@@ -179,7 +179,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_ ...@@ -179,7 +179,7 @@ def convert_to_source(net_def, mode_pb_checksum, template_dir, obfuscate, model_
f.write(source) f.write(source)
# generate model header file # generate model header file
template_name = 'model_header.template' template_name = 'model_header.jinja2'
source = j2_env.get_template(template_name).render( source = j2_env.get_template(template_name).render(
tag = model_tag, tag = model_tag,
) )
......
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
namespace mace { namespace mace {
namespace {{tag}} { namespace {{tag}} {
void CreateTensor{{tensor_info.id}}(std::vector<mace::ConstTensor> &tensors, void CreateTensor{{tensor_info.id}}(std::vector<mace::ConstTensor> *tensors,
const unsigned char *model_data) { const unsigned char *model_data) {
MACE_LATENCY_LOGGER(2, "Create tensor {{ tensor.name }}"); MACE_LATENCY_LOGGER(2, "Create tensor {{ tensor.name }}");
tensors.emplace_back(mace::ConstTensor( tensors->emplace_back(mace::ConstTensor(
{{ tensor.name|tojson }}, model_data + {{ offset }}, {{ tensor.name|tojson }}, model_data + {{ offset }},
{ {{ tensor.dims|join(', ') }} }, {{ tensor_info.data_type }}, {{ tensor.node_id }})); { {{ tensor.dims|join(', ') }} }, {{ tensor_info.data_type }}, {{ tensor.node_id }}));
} }
......
import sh
import re
def adb_split_stdout(stdout_str):
# Filter out last empty line
return [l.strip() for l in stdout_str.split('\n') if len(l.strip()) > 0]
def adb_devices():
outputs = sh.grep(sh.adb("devices"), "^[A-Za-z0-9]\+[[:space:]]\+device$")
raw_lists = sh.cut(outputs, "-f1")
return adb_split_stdout(raw_lists)
def adb_getprop_by_serialno(serialno):
outputs = sh.adb("-s", serialno, "shell", "getprop")
raw_props = adb_split_stdout(outputs)
props = {}
p = re.compile("\[(.+)\]: \[(.+)\]")
for raw_prop in raw_props:
m = p.match(raw_prop)
if m:
props[m.group(1)] = m.group(2)
return props
def adb_get_all_socs():
socs = []
for d in adb_devices():
props = adb_getprop_by_serialno(d)
socs.append(props["ro.product.board"])
return set(socs)
...@@ -15,6 +15,8 @@ import sys ...@@ -15,6 +15,8 @@ import sys
import urllib import urllib
import yaml import yaml
import adb_tools
from ConfigParser import ConfigParser from ConfigParser import ConfigParser
...@@ -201,6 +203,11 @@ def parse_args(): ...@@ -201,6 +203,11 @@ def parse_args():
type=str, type=str,
default="all", default="all",
help="[build|run|validate|merge|all|throughput_test].") help="[build|run|validate|merge|all|throughput_test].")
parser.add_argument(
"--socs",
type=str,
default="all",
help="SoCs to build, comma seperated list (getprop ro.board.platform)")
return parser.parse_known_args() return parser.parse_known_args()
...@@ -227,7 +234,21 @@ def main(unused_args): ...@@ -227,7 +234,21 @@ def main(unused_args):
generate_opencl_and_version_code() generate_opencl_and_version_code()
option_args = ' '.join([arg for arg in unused_args if arg.startswith('--')]) option_args = ' '.join([arg for arg in unused_args if arg.startswith('--')])
for target_soc in configs["target_socs"]: available_socs = adb_tools.adb_get_all_socs()
target_socs = available_socs
if hasattr(configs, "target_socs"):
target_socs = set(configs["target_socs"])
target_socs = target_socs & available_socs
if FLAGS.socs != "all":
socs = set(FLAGS.socs.split(','))
target_socs = target_socs & socs
missing_socs = socs.difference(target_socs)
if len(missing_socs) > 0:
print("Error: devices with SoCs are not connected %s" % missing_socs)
exit(1)
for target_soc in target_socs:
for target_abi in configs["target_abis"]: for target_abi in configs["target_abis"]:
global_runtime = get_global_runtime(configs) global_runtime = get_global_runtime(configs)
# Transfer params by environment # Transfer params by environment
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册