提交 07e60339 编写于 作者: 卢旭辉

Merge branch 'fix_fp16' into 'master'

fix fp16 support

See merge request !1219
......@@ -82,7 +82,8 @@ def convert(conf, output):
except: # noqa
print("Failed to visualize model:", sys.exc_info())
model, params = merge_params(mace_model)
model, params = merge_params(mace_model,
model_conf[ModelKeys.data_type])
output_model_file = model_output + "/" + model_name + ".pb"
output_params_file = model_output + "/" + model_name + ".data"
......@@ -120,7 +121,7 @@ def convert_model(conf):
# used by `base_converter`
option.device = option.device.value
option.data_type = conf[ModelKeys.data_types]
option.data_type = conf[ModelKeys.data_type]
for i in range(len(conf[ModelKeys.input_tensors])):
input_node = cvt.NodeInfo()
......@@ -200,7 +201,7 @@ def convert_model(conf):
return output_graph_def
def merge_params(net_def):
def merge_params(net_def, data_type):
def tensor_to_bytes(tensor):
if tensor.data_type == mace_pb2.DT_HALF:
data = bytearray(
......@@ -230,6 +231,8 @@ def merge_params(net_def):
model_data = []
offset = 0
for tensor in net_def.tensors:
if tensor.data_type == mace_pb2.DT_FLOAT:
tensor.data_type = data_type
raw_data = tensor_to_bytes(tensor)
if tensor.data_type != mace_pb2.DT_UINT8 and offset % 4 != 0:
padding = 4 - offset % 4
......
......@@ -95,7 +95,7 @@ class ModelKeys(object):
change_concat_ranges = "change_concat_ranges"
winograd = "winograd"
cl_mem_type = "cl_mem_type"
data_types = "data_types"
data_type = "data_type"
subgraphs = "subgraphs"
validation_inputs_data = "validation_inputs_data"
......@@ -205,13 +205,13 @@ def normalize_model_config(conf):
conf[ModelKeys.runtime] = parse_device_type(conf[ModelKeys.runtime])
if ModelKeys.quantize in conf:
conf[ModelKeys.data_types] = mace_pb2.DT_FLOAT
conf[ModelKeys.data_type] = mace_pb2.DT_FLOAT
else:
if ModelKeys.data_types in conf:
conf[ModelKeys.data_types] = parse_internal_data_type(
conf[ModelKeys.data_types])
if ModelKeys.data_type in conf:
conf[ModelKeys.data_type] = parse_internal_data_type(
conf[ModelKeys.data_type])
else:
conf[ModelKeys.data_types] = mace_pb2.DT_HALF
conf[ModelKeys.data_type] = mace_pb2.DT_HALF
# parse input
conf[ModelKeys.input_tensors] = to_list(conf[ModelKeys.input_tensors])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册