提交 d5275892 编写于 作者: L liuqi

Fix source converter bug: use half type for cpu runtime.

上级 cc2908e7
...@@ -73,12 +73,16 @@ def rename_tensor(net_def): ...@@ -73,12 +73,16 @@ def rename_tensor(net_def):
op.output[i] = tensor_map[op.output[i]] op.output[i] = tensor_map[op.output[i]]
class TensorInfo: class TensorInfo:
def __init__(self, t): def __init__(self, t, runtime):
self.name = t.name self.name = t.name
self.data_type = mace_pb2.DataType.Name(t.data_type) self.data_type = mace_pb2.DataType.Name(t.data_type)
if t.data_type == mace_pb2.DT_FLOAT: if t.data_type == mace_pb2.DT_FLOAT:
self.data_type = mace_pb2.DT_HALF if runtime == 'gpu':
self.data = bytearray(np.array(t.float_data).astype(np.float16).tobytes()) self.data_type = mace_pb2.DT_HALF
self.data = bytearray(np.array(t.float_data).astype(np.float16).tobytes())
else:
self.data_type = mace_pb2.DT_FLOAT
self.data = bytearray(np.array(t.float_data).astype(np.float32).tobytes())
elif t.data_type == mace_pb2.DT_INT32: elif t.data_type == mace_pb2.DT_INT32:
self.data = bytearray(np.array(t.int32_data).astype(np.int32).tobytes()) self.data = bytearray(np.array(t.int32_data).astype(np.int32).tobytes())
elif t.data_type == mace_pb2.DT_UINT8: elif t.data_type == mace_pb2.DT_UINT8:
...@@ -107,7 +111,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, ...@@ -107,7 +111,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
# generate tensor source files # generate tensor source files
for t in net_def.tensors: for t in net_def.tensors:
source = j2_env.get_template(template_name).render( source = j2_env.get_template(template_name).render(
tensor_info = TensorInfo(t), tensor_info = TensorInfo(t, runtime),
tensor = t, tensor = t,
tag = model_tag, tag = model_tag,
mode = 0, mode = 0,
...@@ -134,7 +138,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, ...@@ -134,7 +138,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
counter += 1 counter += 1
# generate model source files # generate model source files
tensors = [TensorInfo(t) for t in net_def.tensors] tensors = [TensorInfo(t, runtime) for t in net_def.tensors]
source = j2_env.get_template(template_name).render( source = j2_env.get_template(template_name).render(
tensors = tensors, tensors = tensors,
net = net_def, net = net_def,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册