Commit 07e60339 authored by 卢旭辉

Merge branch 'fix_fp16' into 'master'

fix fp16 support

See merge request !1219
@@ -82,7 +82,8 @@ def convert(conf, output):
         except:  # noqa
             print("Failed to visualize model:", sys.exc_info())
-        model, params = merge_params(mace_model)
+        model, params = merge_params(mace_model,
+                                     model_conf[ModelKeys.data_type])
         output_model_file = model_output + "/" + model_name + ".pb"
         output_params_file = model_output + "/" + model_name + ".data"
@@ -120,7 +121,7 @@ def convert_model(conf):
     # used by `base_converter`
     option.device = option.device.value
-    option.data_type = conf[ModelKeys.data_types]
+    option.data_type = conf[ModelKeys.data_type]
     for i in range(len(conf[ModelKeys.input_tensors])):
         input_node = cvt.NodeInfo()
@@ -200,7 +201,7 @@ def convert_model(conf):
     return output_graph_def


-def merge_params(net_def):
+def merge_params(net_def, data_type):
     def tensor_to_bytes(tensor):
         if tensor.data_type == mace_pb2.DT_HALF:
             data = bytearray(
@@ -230,6 +231,8 @@ def merge_params(net_def):
     model_data = []
     offset = 0
     for tensor in net_def.tensors:
+        if tensor.data_type == mace_pb2.DT_FLOAT:
+            tensor.data_type = data_type
         raw_data = tensor_to_bytes(tensor)
         if tensor.data_type != mace_pb2.DT_UINT8 and offset % 4 != 0:
             padding = 4 - offset % 4
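
The diff above routes the model's configured data type into merge_params and retags plain DT_FLOAT tensors before serialization, so the existing DT_HALF branch in tensor_to_bytes actually fires when fp16 is requested. The sketch below only illustrates that effect; it is not MACE's tensor_to_bytes. The Tensor class and the DT_* constants are made up, and numpy stands in for the protobuf float_data field.

import numpy as np

# Hypothetical stand-ins for the mace_pb2 enum values referenced in the diff.
DT_FLOAT, DT_HALF = 1, 2


class Tensor:
    """Toy tensor holding fp32 weights plus a data_type tag."""
    def __init__(self, float_data, data_type=DT_FLOAT):
        self.float_data = list(float_data)
        self.data_type = data_type


def tensor_to_bytes(tensor):
    # Serialize as 2-byte halves only when the tensor is tagged DT_HALF;
    # otherwise fall back to 4-byte fp32 words.
    if tensor.data_type == DT_HALF:
        return np.asarray(tensor.float_data, dtype=np.float16).tobytes()
    return np.asarray(tensor.float_data, dtype=np.float32).tobytes()


weights = Tensor([0.1, 0.2, 0.3])
print(len(tensor_to_bytes(weights)))  # 12 bytes: serialized as fp32

# What the patched merge_params now does when data_type resolves to half:
if weights.data_type == DT_FLOAT:
    weights.data_type = DT_HALF
print(len(tensor_to_bytes(weights)))  # 6 bytes: packed as fp16

Without the retag, every weight tensor stayed DT_FLOAT and the .data file was written in full precision regardless of the configured data_type, which is presumably why fp16 support was broken before this change.
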
@@ -95,7 +95,7 @@ class ModelKeys(object):
     change_concat_ranges = "change_concat_ranges"
     winograd = "winograd"
     cl_mem_type = "cl_mem_type"
-    data_types = "data_types"
+    data_type = "data_type"
     subgraphs = "subgraphs"
     validation_inputs_data = "validation_inputs_data"
@@ -205,13 +205,13 @@ def normalize_model_config(conf):
     conf[ModelKeys.runtime] = parse_device_type(conf[ModelKeys.runtime])
     if ModelKeys.quantize in conf:
-        conf[ModelKeys.data_types] = mace_pb2.DT_FLOAT
+        conf[ModelKeys.data_type] = mace_pb2.DT_FLOAT
     else:
-        if ModelKeys.data_types in conf:
-            conf[ModelKeys.data_types] = parse_internal_data_type(
-                conf[ModelKeys.data_types])
+        if ModelKeys.data_type in conf:
+            conf[ModelKeys.data_type] = parse_internal_data_type(
+                conf[ModelKeys.data_type])
         else:
-            conf[ModelKeys.data_types] = mace_pb2.DT_HALF
+            conf[ModelKeys.data_type] = mace_pb2.DT_HALF

     # parse input
     conf[ModelKeys.input_tensors] = to_list(conf[ModelKeys.input_tensors])
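
The config-parser diff renames the key from data_types to data_type and keeps the resolution order: a quantized model is forced to DT_FLOAT, an explicit data_type entry is parsed, and everything else defaults to DT_HALF. Below is a hedged paraphrase of that order, with stub values standing in for mace_pb2 and a stub parser standing in for parse_internal_data_type; the YAML value strings are assumptions, not taken from the diff.

# Stand-ins for mace_pb2 enum values and for parse_internal_data_type.
DT_FLOAT, DT_HALF = "DT_FLOAT", "DT_HALF"

_NAME_TO_TYPE = {"fp16_fp32": DT_HALF, "fp32_fp32": DT_FLOAT}


def parse_internal_data_type(name):
    # Stub: map a config string to the internal enum value.
    return _NAME_TO_TYPE[name]


def resolve_data_type(conf):
    # Mirrors the patched normalize_model_config branch above.
    if "quantize" in conf:
        return DT_FLOAT  # a quantize entry in the config forces DT_FLOAT
    if "data_type" in conf:
        return parse_internal_data_type(conf["data_type"])
    return DT_HALF       # default when nothing is specified


print(resolve_data_type({}))                          # DT_HALF
print(resolve_data_type({"data_type": "fp32_fp32"}))  # DT_FLOAT
print(resolve_data_type({"quantize": 1}))             # DT_FLOAT
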