From b6eb7b8a5b7f58acd6d4cf688b1fb6069807fdc5 Mon Sep 17 00:00:00 2001 From: liuqi Date: Thu, 1 Feb 2018 19:48:18 +0800 Subject: [PATCH] Change the float-type const tensor to half type. --- python/tools/model.template | 2 +- python/tools/source_converter_lib.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tools/model.template b/python/tools/model.template index 03273053..ff1b531f 100644 --- a/python/tools/model.template +++ b/python/tools/model.template @@ -17,7 +17,7 @@ namespace {{tag}} { void Create{{tensor.name}}(std::vector &tensors) { tensors.emplace_back(mace::ConstTensor( {{ tensor.name|tojson }}, {{ tensor.name }}, - { {{ tensor.dims|join(', ') }} }, {{ tensor.data_type }}, {{ tensor.node_id }})); + { {{ tensor.dims|join(', ') }} }, {{ tensor_info.data_type }}, {{ tensor.node_id }})); } } // namespace {{tag}} diff --git a/python/tools/source_converter_lib.py b/python/tools/source_converter_lib.py index 5a93d6af..ebf65f04 100644 --- a/python/tools/source_converter_lib.py +++ b/python/tools/source_converter_lib.py @@ -77,9 +77,10 @@ class TensorInfo: self.name = t.name self.data_type = mace_pb2.DataType.Name(t.data_type) if t.data_type == mace_pb2.DT_FLOAT: - self.data = bytearray(struct.pack('%sf' % len(t.float_data), *t.float_data)) + self.data_type = mace_pb2.DT_HALF + self.data = bytearray(np.array(t.float_data).astype(np.float16).tobytes()) elif t.data_type == mace_pb2.DT_INT32: - self.data = bytearray(struct.pack('%si' % len(t.int32_data), *t.int32_data)) + self.data = bytearray(np.array(t.int32_data).astype(np.int32).tobytes()) elif t.data_type == mace_pb2.DT_UINT8: self.data = bytearray(np.array(t.int32_data).astype(np.uint8).tolist()) -- GitLab