//
// Copyright (c) 2017 XiaoMi All rights reserved.
// Generated by the mace converter. DO NOT EDIT!
//
{#
  Template notes (these Jinja comments are not emitted into the generated file).

  Variables read by this template: tag, start, end, net, runtime.
  Custom filter used by this template: stringfy.

  For every index i in range(start, end) one function CreateOperator<i> is
  emitted; it fills a mace::OperatorDef from net.op[i].

  NOTE(review): the container element types below (std::string, int,
  mace::DataType) were reconstructed from how the values are used in this
  template; confirm them against mace/public/mace.h.
#}

#include <string>
#include <vector>

#include "mace/public/mace.h"
#include "mace/utils/env_time.h"
#include "mace/utils/logging.h"

namespace mace {
namespace {

// Copies the identifying fields of one operator (name, type, input and output
// names, output types and node id) onto `op` through its setters.
void UpdateOp(mace::OperatorDef &op,
              const std::string &name,
              const std::string &type,
              const std::vector<std::string> &inputs,
              const std::vector<std::string> &outputs,
              const std::vector<mace::DataType> &output_types,
              uint32_t node_id) {
  op.set_name(name);
  op.set_type(type);
  op.set_input(inputs);
  op.set_output(outputs);
  op.set_output_type(output_types);
  op.set_node_id(node_id);
}

}  // namespace
}  // namespace mace

namespace mace {
namespace {{tag}} {

{% for i in range(start, end) %}

void CreateOperator{{i}}(mace::OperatorDef &op) {
  MACE_LATENCY_LOGGER(2, "Create operator {{ net.op[i].name }}");

  mace::Argument *arg = nullptr;
  {# One add_arg() per argument; only the fields that are present are set. #}
  {% for arg in net.op[i].arg %}

  arg = op.add_arg();
  arg->set_name({{ arg.name|tojson }});

  {%- if arg.HasField('f') %}
  arg->set_f({{ arg.f }});
  {%- endif %}

  {%- if arg.HasField('i') %}
  arg->set_i({{ arg.i }});
  {%- endif %}

  {%- if arg.HasField('s') %}
  arg->set_s({{ arg.s|tojson }});
  {%- endif %}

  {% if arg.floats|length != 0 %}
  arg->set_floats({ {{ arg.floats|join(', ') }} });
  {% endif %}
  {% if arg.ints|length != 0 %}
  arg->set_ints({ {{ arg.ints|join(', ') }} });
  {% endif %}
  {% if arg.strings|length != 0 %}
  arg->set_strings({ {{ arg.strings|stringfy() }} });
  {% endif %}
  {% endfor %}

  {% if net.op[i].HasField('mem_id') %}
  op.set_mem_id({{net.op[i].mem_id}});
  {% endif %}

  {# Output shapes without any dimension are skipped. #}
  {% for shape in net.op[i].output_shape %}
  {% if shape.dims | length > 0 %}
  op.add_output_shape(mace::OutputShape({ {{ shape.dims|join(', ') }} }));
  {% endif %}
  {% endfor %}

  {# Output types are rendered as integers and cast back to the enum in C++. #}
  std::vector<int> output_types_int({ {{ net.op[i].output_type | join(', ') }} });
  std::vector<mace::DataType> output_types({{ net.op[i].output_type | length }});
  for (int k = 0; k < {{ net.op[i].output_type | length }}; ++k) {
    output_types[k] = static_cast<mace::DataType>(output_types_int[k]);
  }
  UpdateOp(op, {{ net.op[i].name|tojson }}, {{ net.op[i].type|tojson}},
           { {{ net.op[i].input|stringfy }} },
           { {{ net.op[i].output|stringfy }} },
           output_types,
           {{ net.op[i].node_id }});

  {# The remaining fields are only emitted when runtime is 'dsp'. #}
  {% if runtime == 'dsp' %}
  op.set_padding({{ net.op[i].padding }});
  {% if net.op[i].node_input | length > 0 %}
  std::vector<int> input_node_ids({ {{ net.op[i].node_input | map(attribute='node_id') | join(', ') }} });
  std::vector<int> input_output_ports({ {{ net.op[i].node_input | map(attribute='output_port') | join(', ')}} });

  for (size_t k = 0; k < {{ net.op[i].node_input | length }}; ++k) {
    mace::NodeInput input(input_node_ids[k], input_output_ports[k]);
    op.add_node_input(input);
  }
  {% endif %}
  {% if net.op[i].out_max_byte_size | length > 0 %}
  std::vector<int> out_max_byte_sizes({ {{ net.op[i].out_max_byte_size | join(', ') }} });
  for (size_t k = 0; k < {{ net.op[i].out_max_byte_size | length }}; ++k) {
    op.add_out_max_byte_size(out_max_byte_sizes[k]);
  }
  {% endif %}

  {% endif %}
}

{% endfor %}

}  // namespace {{tag}}
}  // namespace mace