From 8fda71d4316354ae08543dd06479d69f754bc8cd Mon Sep 17 00:00:00 2001 From: lichao18 Date: Tue, 21 May 2019 16:20:46 +0800 Subject: [PATCH] Add 'code code' converter for apu --- mace/python/tools/model.jinja2 | 6 ++++++ mace/python/tools/operator.jinja2 | 14 ++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/mace/python/tools/model.jinja2 b/mace/python/tools/model.jinja2 index 0d1396c4..00093e89 100644 --- a/mace/python/tools/model.jinja2 +++ b/mace/python/tools/model.jinja2 @@ -79,8 +79,11 @@ void CreateInputInfo(NetDef *net_def) { {% for idx in range(net.input_info|length) %} input_info = net_def->add_input_info(); input_info->set_name({{ net.input_info[idx].name|tojson }}); + input_info->set_node_id({{net.input_info[idx].node_id }}); input_info->set_data_type(static_cast({{ net.input_info[idx].data_type }})); input_info->set_data_format({{ net.input_info[idx].data_format }}); + input_info->set_scale({{ net.input_info[idx].scale }}); + input_info->set_zero_point({{ net.input_info[idx].zero_point }}); input_info->mutable_dims()->Reserve({{ net.input_info[idx].dims|length }}); {% for dim in net.input_info[idx].dims %} input_info->add_dims({{ dim }}); @@ -96,8 +99,11 @@ void CreateOutputInfo(NetDef *net_def) { {% for idx in range(net.output_info|length) %} output_info = net_def->add_output_info(); output_info->set_name({{ net.output_info[idx].name|tojson }}); + output_info->set_node_id({{ net.output_info[idx].node_id }}); output_info->set_data_type(static_cast({{ net.output_info[idx].data_type }})); output_info->set_data_format({{ net.output_info[idx].data_format }}); + output_info->set_scale({{ net.output_info[idx].scale }}); + output_info->set_zero_point({{ net.output_info[idx].zero_point }}); output_info->mutable_dims()->Reserve({{ net.output_info[idx].dims|length }}); {% for dim in net.output_info[idx].dims %} output_info->add_dims({{dim}}); diff --git a/mace/python/tools/operator.jinja2 b/mace/python/tools/operator.jinja2 index 0fc941db..8ef5bbfc 100644 --- a/mace/python/tools/operator.jinja2 +++ b/mace/python/tools/operator.jinja2 @@ -160,6 +160,20 @@ void CreateOperator{{i}}(mace::OperatorDef *op) { } {% endif %} {% endif %} + + {% if device == 5 %} + {% if net.op[i].node_input | length > 0 %} + std::vector input_node_ids({ {{ net.op[i].node_input | map(attribute='node_id') | join(', ') }} }); + + mace::NodeInput *node_input = nullptr; + op->mutable_node_input()->Reserve({{ net.op[i].node_input|length }}); + for (size_t i = 0; i < {{ net.op[i].node_input|length }}; ++i) { + node_input = op->add_node_input(); + node_input->set_node_id(input_node_ids[i]); + } + {% endif %} + + {% endif %} } {% endfor %} -- GitLab