From 007c4380adaa9fa559b635c2142fed53a6a88572 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=8F=B6=E5=89=91=E6=AD=A6?=
Date: Fri, 1 Nov 2019 09:55:08 +0800
Subject: [PATCH] fix fp16 support format code

---
 tools/python/convert.py             |  9 ++++++---
 tools/python/utils/config_parser.py | 12 ++++++------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/tools/python/convert.py b/tools/python/convert.py
index 24ad303b..3e2a82cb 100644
--- a/tools/python/convert.py
+++ b/tools/python/convert.py
@@ -82,7 +82,8 @@ def convert(conf, output):
         except:  # noqa
             print("Failed to visualize model:", sys.exc_info())
 
-        model, params = merge_params(mace_model)
+        model, params = merge_params(mace_model,
+                                     model_conf[ModelKeys.data_type])
         output_model_file = model_output + "/" + model_name + ".pb"
         output_params_file = model_output + "/" + model_name + ".data"
 
@@ -120,7 +121,7 @@ def convert_model(conf):
 
     # used by `base_converter`
     option.device = option.device.value
-    option.data_type = conf[ModelKeys.data_types]
+    option.data_type = conf[ModelKeys.data_type]
 
     for i in range(len(conf[ModelKeys.input_tensors])):
         input_node = cvt.NodeInfo()
@@ -200,7 +201,7 @@ def convert_model(conf):
     return output_graph_def
 
 
-def merge_params(net_def):
+def merge_params(net_def, data_type):
     def tensor_to_bytes(tensor):
         if tensor.data_type == mace_pb2.DT_HALF:
             data = bytearray(
@@ -230,6 +231,8 @@ def merge_params(net_def):
     model_data = []
     offset = 0
     for tensor in net_def.tensors:
+        if tensor.data_type == mace_pb2.DT_FLOAT:
+            tensor.data_type = data_type
         raw_data = tensor_to_bytes(tensor)
         if tensor.data_type != mace_pb2.DT_UINT8 and offset % 4 != 0:
             padding = 4 - offset % 4
diff --git a/tools/python/utils/config_parser.py b/tools/python/utils/config_parser.py
index 35de4bd9..56b7a341 100644
--- a/tools/python/utils/config_parser.py
+++ b/tools/python/utils/config_parser.py
@@ -95,7 +95,7 @@ class ModelKeys(object):
     change_concat_ranges = "change_concat_ranges"
     winograd = "winograd"
     cl_mem_type = "cl_mem_type"
-    data_types = "data_types"
+    data_type = "data_type"
     subgraphs = "subgraphs"
 
     validation_inputs_data = "validation_inputs_data"
@@ -205,13 +205,13 @@ def normalize_model_config(conf):
         conf[ModelKeys.runtime] = parse_device_type(conf[ModelKeys.runtime])
 
     if ModelKeys.quantize in conf:
-        conf[ModelKeys.data_types] = mace_pb2.DT_FLOAT
+        conf[ModelKeys.data_type] = mace_pb2.DT_FLOAT
     else:
-        if ModelKeys.data_types in conf:
-            conf[ModelKeys.data_types] = parse_internal_data_type(
-                conf[ModelKeys.data_types])
+        if ModelKeys.data_type in conf:
+            conf[ModelKeys.data_type] = parse_internal_data_type(
+                conf[ModelKeys.data_type])
         else:
-            conf[ModelKeys.data_types] = mace_pb2.DT_HALF
+            conf[ModelKeys.data_type] = mace_pb2.DT_HALF
 
     # parse input
     conf[ModelKeys.input_tensors] = to_list(conf[ModelKeys.input_tensors])
-- 
GitLab
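
Note (not part of the patch): the behavioral core of this change is in merge_params, which now re-tags tensors still marked DT_FLOAT with the model's configured data type before serializing them, so an fp16 model actually stores its weights as 16-bit values. The sketch below illustrates that idea in isolation; the Tensor class and the DT_* constants are stand-ins for the real mace_pb2 messages and enums, not MACE's actual API.

import numpy as np

# Stand-ins for mace_pb2 enum values (assumption, for illustration only).
DT_FLOAT, DT_HALF = 1, 2


class Tensor(object):
    # Minimal stand-in for a mace_pb2 tensor holding float weights.
    def __init__(self, float_data, data_type=DT_FLOAT):
        self.float_data = float_data
        self.data_type = data_type


def tensor_to_bytes(tensor, data_type):
    # The fix: a tensor left as DT_FLOAT inherits the configured data type.
    if tensor.data_type == DT_FLOAT:
        tensor.data_type = data_type
    # fp16 weights are packed as 16-bit floats, everything else as 32-bit.
    dtype = np.float16 if tensor.data_type == DT_HALF else np.float32
    return bytearray(np.array(tensor.float_data, dtype=dtype).tobytes())


weights = Tensor([0.1, 0.2, 0.3])
print(len(tensor_to_bytes(weights, DT_HALF)))  # 6 bytes: three fp16 values

In the real converter the serialization itself is unchanged; the patch only makes sure the tensor's data_type field agrees with the data_type from the model configuration before tensor_to_bytes packs it into the .data file.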