// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

{#
  Jinja2 template that renders a C++ source file populating
  mace::OperatorDef messages. For every index i in range(start, end) it
  emits `void CreateOperator<i>(mace::OperatorDef *op)` inside namespace
  mace::<tag>, copying the fields of net.op[i] into `op`.

  Template context read here: net, start, end, tag, device.
  The `stringfy` filter must be registered by the caller. NOTE(review): it
  is assumed to render a comma-separated list of C++ string literals.
  Names and string arguments are emitted through `tojson`, so the rendered
  text relies on JSON string escapes also being accepted inside C++ string
  literals.
#}
// This is a generated file. DO NOT EDIT!

#include <string>
#include <vector>

#include "mace/proto/mace.pb.h"
#include "mace/public/mace.h"
#include "mace/port/env.h"
#include "mace/utils/logging.h"

namespace mace {
namespace {

// Sets name, type and node id on `op`, then appends the given inputs,
// outputs, output types and memory ids to the corresponding repeated
// fields. Each repeated field is reserved up front; existing entries are
// not cleared.
void UpdateOp(mace::OperatorDef *op,
              const std::string &name,
              const std::string &type,
              const std::vector<std::string> &inputs,
              const std::vector<std::string> &outputs,
              const std::vector<mace::DataType> &output_types,
              uint32_t node_id,
              const std::vector<int> &mem_ids) {
  op->set_name(name);
  op->set_type(type);
  op->set_node_id(node_id);

  op->mutable_input()->Reserve(inputs.size());
  for (const auto &input : inputs) {
    op->add_input(input);
  }
  op->mutable_output()->Reserve(outputs.size());
  for (const auto &output : outputs) {
    op->add_output(output);
  }
  op->mutable_output_type()->Reserve(output_types.size());
  for (const auto output_type : output_types) {
    op->add_output_type(output_type);
  }
  op->mutable_mem_id()->Reserve(mem_ids.size());
  for (const auto mem_id : mem_ids) {
    op->add_mem_id(mem_id);
  }
}

}  // namespace
}  // namespace mace

namespace mace {
namespace {{tag}} {

{% for i in range(start, end) %}

void CreateOperator{{i}}(mace::OperatorDef *op) {
  {# Adjacent string literals are concatenated by the compiler; tojson
     escapes quotes/backslashes in the op name. #}
  MACE_LATENCY_LOGGER(2, "Create operator " {{ net.op[i].name|tojson }});

  mace::Argument *arg = nullptr;
  op->mutable_arg()->Reserve({{ net.op[i].arg|length }});
  {% for arg in net.op[i].arg %}

  arg = op->add_arg();
  arg->set_name({{ arg.name|tojson }});
  {%- if arg.HasField('f') %}
  arg->set_f({{ arg.f }});
  {%- endif %}
  {%- if arg.HasField('i') %}
  arg->set_i({{ arg.i }});
  {%- endif %}
  {%- if arg.HasField('s') %}
  arg->set_s({{ arg.s.decode('utf-8')|tojson }});
  {%- endif %}

  {% if arg.floats|length > 0 %}
  arg->mutable_floats()->Reserve({{ arg.floats|length }});
  {% for float_value in arg.floats %}
  arg->add_floats({{ float_value }});
  {% endfor %}
  {% endif %}
  {% if arg.ints|length > 0 %}
  arg->mutable_ints()->Reserve({{ arg.ints|length }});
  {% for int_value in arg.ints %}
  arg->add_ints({{ int_value }});
  {% endfor %}
  {% endif %}
  {% endfor %}

  {# NOTE(review): shapes with no dims are skipped below, so output_shape
     entries may not line up index-for-index with outputs; confirm this is
     intended. #}
  {% if net.op[i].output_shape|length > 0 %}
  op->mutable_output_shape()->Reserve({{ net.op[i].output_shape|length }});
  {% for shape in net.op[i].output_shape %}
  {% if shape.dims|length > 0 %}
  {
    mace::OutputShape *output_shape = op->add_output_shape();
    output_shape->mutable_dims()->Reserve({{ shape.dims|length }});
    {% for dim in shape.dims %}
    output_shape->add_dims({{ dim }});
    {% endfor %}
  }
  {% endif %}
  {% endfor %}
  {% endif %}

  {# Output types are stored as integers in the net; convert each to the
     DataType enum at the C++ level. #}
  const std::vector<mace::DataType> output_types = {
    {% for output_type in net.op[i].output_type %}
    static_cast<mace::DataType>({{ output_type }}),
    {% endfor %}
  };
  UpdateOp(op, {{ net.op[i].name|tojson }}, {{ net.op[i].type|tojson }},
           { {{ net.op[i].input|stringfy }} },
           { {{ net.op[i].output|stringfy }} },
           output_types,
           {{ net.op[i].node_id }},
           { {{ net.op[i].mem_id | join(', ') }} });

  op->mutable_quantize_info()->Reserve({{ net.op[i].quantize_info | length }});
  {% for j in range(net.op[i].quantize_info|length) %}
  auto quantize_info{{j}} = op->add_quantize_info();
  quantize_info{{j}}->set_scale({{ net.op[i].quantize_info[j].scale }});
  quantize_info{{j}}->set_zero_point({{ net.op[i].quantize_info[j].zero_point }});
  quantize_info{{j}}->set_minval({{ net.op[i].quantize_info[j].minval }});
  quantize_info{{j}}->set_maxval({{ net.op[i].quantize_info[j].maxval }});
  {% endfor %}

  {# Fields emitted only when device == 3 (NOTE(review): presumably a DSP
     runtime; confirm against the DeviceType enum): padding, node inputs
     with output ports, and per-output max byte sizes. #}
  {% if device == 3 %}
  op->set_padding({{ net.op[i].padding }});
  {% if net.op[i].node_input | length > 0 %}
  std::vector<int> input_node_ids({ {{ net.op[i].node_input | map(attribute='node_id') | join(', ') }} });
  std::vector<int> input_output_ports({ {{ net.op[i].node_input | map(attribute='output_port') | join(', ') }} });
  mace::NodeInput *node_input = nullptr;
  op->mutable_node_input()->Reserve({{ net.op[i].node_input|length }});
  for (size_t k = 0; k < input_node_ids.size(); ++k) {
    node_input = op->add_node_input();
    node_input->set_node_id(input_node_ids[k]);
    node_input->set_output_port(input_output_ports[k]);
  }
  {% endif %}
  {% if net.op[i].out_max_byte_size | length > 0 %}
  std::vector<int> out_max_byte_sizes({ {{ net.op[i].out_max_byte_size | join(', ') }} });
  op->mutable_out_max_byte_size()->Reserve({{ net.op[i].out_max_byte_size|length }});
  for (size_t k = 0; k < out_max_byte_sizes.size(); ++k) {
    op->add_out_max_byte_size(out_max_byte_sizes[k]);
  }
  {% endif %}
  {% endif %}

  {# Fields emitted only when device == 5 (NOTE(review): confirm which
     runtime this is against the DeviceType enum): node input ids only. #}
  {% if device == 5 %}
  {% if net.op[i].node_input | length > 0 %}
  std::vector<int> input_node_ids({ {{ net.op[i].node_input | map(attribute='node_id') | join(', ') }} });
  mace::NodeInput *node_input = nullptr;
  op->mutable_node_input()->Reserve({{ net.op[i].node_input|length }});
  for (size_t k = 0; k < input_node_ids.size(); ++k) {
    node_input = op->add_node_input();
    node_input->set_node_id(input_node_ids[k]);
  }
  {% endif %}
  {% endif %}
}

{% endfor %}

}  // namespace {{tag}}
}  // namespace mace