diff --git a/mace/python/tools/model.template b/mace/python/tools/model.template index 0fcbcde420dbb8be87727352ee5c63dc7c68f391..7fa60ae6b55ac8185db40d7c88494f71c686a231 100644 --- a/mace/python/tools/model.template +++ b/mace/python/tools/model.template @@ -66,7 +66,6 @@ static void CreateNetArg(mace::NetDef &net_def) { static void UpdateOp(mace::OperatorDef &op, const std::string &name, const std::string &type, - const int mem_id, const std::vector &inputs, const std::vector &outputs, const std::vector &output_types) { @@ -74,7 +73,6 @@ static void UpdateOp(mace::OperatorDef &op, op.set_type(type); op.set_input(inputs); op.set_output(outputs); - op.set_mem_id(mem_id); op.set_output_type(output_types); } @@ -108,11 +106,15 @@ static void CreateOperators(std::vector &ops) { {% endif %} {% endfor %} + {% if net.op[i].HasField('mem_id') %} + ops[{{i}}].set_mem_id({{net.op[i].mem_id}}); + {% endif %} + {% for shape in net.op[i].output_shape %} ops[{{i}}].add_output_shape(mace::OutputShape({ {{ shape.dims|join(', ') }} })); {% endfor %} - UpdateOp(ops[{{i}}], {{ net.op[i].name|tojson }}, {{ net.op[i].type|tojson}}, {{ net.op[i].mem_id }}, + UpdateOp(ops[{{i}}], {{ net.op[i].name|tojson }}, {{ net.op[i].type|tojson}}, { {{ net.op[i].input|stringfy }} }, { {{ net.op[i].output|stringfy }} }, { {{ net.op[i].output_type|join(', ') }} }); @@ -139,7 +141,7 @@ static void CreateTensors(std::vector &tensors) { {% if net.mem_arena.mem_block|length != 0 %} static void CreateMemoryArena(mace::MemoryArena &mem_arena) { - auto mem_block = mem_arena.mutable_mem_block(); + std::vector &mem_block = mem_arena.mutable_mem_block(); mem_block.reserve({{ net.mem_arena.mem_block|length }}); {% for mem_blk in net.mem_arena.mem_block %}