From d52758925a202940786b0d5b1cc94a3f3959b1eb Mon Sep 17 00:00:00 2001 From: liuqi Date: Tue, 13 Feb 2018 14:02:01 +0800 Subject: [PATCH] Fix source converter bug: use half type for cpu runtime. --- python/tools/source_converter_lib.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/tools/source_converter_lib.py b/python/tools/source_converter_lib.py index ebf65f04..d842ffab 100644 --- a/python/tools/source_converter_lib.py +++ b/python/tools/source_converter_lib.py @@ -73,12 +73,16 @@ def rename_tensor(net_def): op.output[i] = tensor_map[op.output[i]] class TensorInfo: - def __init__(self, t): + def __init__(self, t, runtime): self.name = t.name self.data_type = mace_pb2.DataType.Name(t.data_type) if t.data_type == mace_pb2.DT_FLOAT: - self.data_type = mace_pb2.DT_HALF - self.data = bytearray(np.array(t.float_data).astype(np.float16).tobytes()) + if runtime == 'gpu': + self.data_type = mace_pb2.DT_HALF + self.data = bytearray(np.array(t.float_data).astype(np.float16).tobytes()) + else: + self.data_type = mace_pb2.DT_FLOAT + self.data = bytearray(np.array(t.float_data).astype(np.float32).tobytes()) elif t.data_type == mace_pb2.DT_INT32: self.data = bytearray(np.array(t.int32_data).astype(np.int32).tobytes()) elif t.data_type == mace_pb2.DT_UINT8: @@ -107,7 +111,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, # generate tensor source files for t in net_def.tensors: source = j2_env.get_template(template_name).render( - tensor_info = TensorInfo(t), + tensor_info = TensorInfo(t, runtime), tensor = t, tag = model_tag, mode = 0, @@ -134,7 +138,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, counter += 1 # generate model source files - tensors = [TensorInfo(t) for t in net_def.tensors] + tensors = [TensorInfo(t, runtime) for t in net_def.tensors] source = j2_env.get_template(template_name).render( tensors = tensors, net = net_def, -- GitLab