diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py index 8b6997a3611a2e315359742a951195a5fa5267f5..ac507145dc08889682acf214bbffa4d4e0e0b546 100644 --- a/mace/python/tools/memory_optimizer.py +++ b/mace/python/tools/memory_optimizer.py @@ -48,7 +48,7 @@ class MemoryOptimizer(object): mem_id = self.idle_mem.pop() if not op.output_shape: - print('There is no output shape information to do memory optimization.') + print('WARNING: There is no output shape information to do memory optimization.') return op.mem_id = mem_id self.op_mem[self._op_to_tensor(op)] = mem_id diff --git a/mace/python/tools/model.template b/mace/python/tools/model.template index bd81d06a4c0868c2656b27868de3ff2626b0c677..3eaac6b9be7e21235b96c19f7e25f1eefe75c344 100644 --- a/mace/python/tools/model.template +++ b/mace/python/tools/model.template @@ -4,13 +4,21 @@ // {% if mode == 0 %} +#include +#include "mace/core/mace.h" namespace {{tag}}{ -alignas(4) unsigned char {{ tensor.name }}[] = { -{% for d in tensor.data %}{{"0x%02X, " % d }}{%endfor%} +alignas(4) unsigned char {{ tensor_info.name }}[] = { +{% for d in tensor_info.data %}{{"0x%02X, " % d }}{%endfor%} }; +void Create{{tensor.name}}(std::vector &tensors) { + tensors.emplace_back(mace::TensorProto( + {{ tensor.name|tojson }}, {{ tensor.name }}, + { {{ tensor.dims|join(', ') }} }, {{ tensor.data_type }}, {{ tensor.node_id }})); +} + } // namespace {{tag}} {% elif mode == 1 %} @@ -92,7 +100,7 @@ void CreateOperator{{i}}(mace::OperatorDef &op) { namespace {{tag}} { {% for tensor in tensors %} -extern unsigned char {{ tensor.name }}[]; +extern void Create{{ tensor.name }}(std::vector &tensors); {% endfor %} @@ -156,9 +164,7 @@ static void CreateTensors(std::vector &tensors) { {% for tensor in net.tensors %} - tensors.emplace_back(mace::TensorProto( - {{ tensor.name|tojson }}, {{ tag + '::' + tensor.name }}, - { {{ tensor.dims|join(', ') }} }, {{ tensor.data_type }}, {{ tensor.node_id }})); + {{tag}}::Create{{tensor.name}}(tensors); {% endfor %} } diff --git a/mace/python/tools/source_converter_lib.py b/mace/python/tools/source_converter_lib.py index a26d5c13cb7c585a7b99740d45b73604036763e7..c2a12e340bca16e72f108a3707be329dcf79b787 100644 --- a/mace/python/tools/source_converter_lib.py +++ b/mace/python/tools/source_converter_lib.py @@ -102,7 +102,8 @@ def convert_to_source(net_def, template, confuse, model_tag, output): # generate tensor source files for t in net_def.tensors: source = j2_env.get_template(template_name).render( - tensor = TensorInfo(t), + tensor_info = TensorInfo(t), + tensor = t, tag = model_tag, mode = 0, )