diff --git a/mace/python/tools/model.jinja2 b/mace/python/tools/model.jinja2 index 0d1396c498988ac39f2d1509c8eff90c2deeccab..00093e89c63b3aed9003dc02e3e5e0d5ce32e14f 100644 --- a/mace/python/tools/model.jinja2 +++ b/mace/python/tools/model.jinja2 @@ -79,8 +79,11 @@ void CreateInputInfo(NetDef *net_def) { {% for idx in range(net.input_info|length) %} input_info = net_def->add_input_info(); input_info->set_name({{ net.input_info[idx].name|tojson }}); + input_info->set_node_id({{net.input_info[idx].node_id }}); input_info->set_data_type(static_cast({{ net.input_info[idx].data_type }})); input_info->set_data_format({{ net.input_info[idx].data_format }}); + input_info->set_scale({{ net.input_info[idx].scale }}); + input_info->set_zero_point({{ net.input_info[idx].zero_point }}); input_info->mutable_dims()->Reserve({{ net.input_info[idx].dims|length }}); {% for dim in net.input_info[idx].dims %} input_info->add_dims({{ dim }}); @@ -96,8 +99,11 @@ void CreateOutputInfo(NetDef *net_def) { {% for idx in range(net.output_info|length) %} output_info = net_def->add_output_info(); output_info->set_name({{ net.output_info[idx].name|tojson }}); + output_info->set_node_id({{ net.output_info[idx].node_id }}); output_info->set_data_type(static_cast({{ net.output_info[idx].data_type }})); output_info->set_data_format({{ net.output_info[idx].data_format }}); + output_info->set_scale({{ net.output_info[idx].scale }}); + output_info->set_zero_point({{ net.output_info[idx].zero_point }}); output_info->mutable_dims()->Reserve({{ net.output_info[idx].dims|length }}); {% for dim in net.output_info[idx].dims %} output_info->add_dims({{dim}}); diff --git a/mace/python/tools/operator.jinja2 b/mace/python/tools/operator.jinja2 index 0fc941db48b87da0c2bd683230b9e4f4ed9c4deb..8ef5bbfc54355724e4f79a301564af5cb3cec416 100644 --- a/mace/python/tools/operator.jinja2 +++ b/mace/python/tools/operator.jinja2 @@ -160,6 +160,20 @@ void CreateOperator{{i}}(mace::OperatorDef *op) { } {% endif %} {% endif %} + + {% if device == 5 %} + {% if net.op[i].node_input | length > 0 %} + std::vector input_node_ids({ {{ net.op[i].node_input | map(attribute='node_id') | join(', ') }} }); + + mace::NodeInput *node_input = nullptr; + op->mutable_node_input()->Reserve({{ net.op[i].node_input|length }}); + for (size_t i = 0; i < {{ net.op[i].node_input|length }}; ++i) { + node_input = op->add_node_input(); + node_input->set_node_id(input_node_ids[i]); + } + {% endif %} + + {% endif %} } {% endfor %}