From aadc9ded2748b21e38c17893ed66ea83f3406da1 Mon Sep 17 00:00:00 2001 From: liyin Date: Mon, 12 Aug 2019 17:45:12 +0800 Subject: [PATCH] Fix check tensors func --- .gitlab-ci.yml | 2 -- tools/python/convert.py | 29 ++++++++++++++++++--------- tools/python/transform/transformer.py | 13 ------------ tools/python/utils/util.py | 13 ++++++------ 4 files changed, 27 insertions(+), 30 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a6f6b42f..ce199ed0 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -140,8 +140,6 @@ quantization_tests: python tools/converter.py run --config=${CONF_FILE} --target_socs=$TARGET_SOCS --device_yml=${DEVICE_CONF_FILE} --round=1 --validate --model_graph_format=file --model_data_format=file || exit 1; done - rm -rf mace-models - only: - - triggers dynamic_linking_test: stage: extra diff --git a/tools/python/convert.py b/tools/python/convert.py index 1ea115ee..66143a9d 100644 --- a/tools/python/convert.py +++ b/tools/python/convert.py @@ -117,6 +117,12 @@ def convert(conf, output): model_conf["weight_file_path"], model_conf["weight_sha256_checksum"], model_output) model_conf["weight_file_path"] = weight_file + # TODO: remove the following after quantize tool is made + if "quantize_range_file" in model_conf: + range_file = util.download_or_get_file( + model_conf["quantize_range_file"], + "", model_output) + model_conf["quantize_range_file"] = range_file mace_model = convert_model(model_conf) @@ -215,16 +221,16 @@ def convert_model(conf): output_node.data_format = cvt.DataFormat.NHWC option.add_output_node(output_node) - if "check_node" in conf: - check_node_names = to_list(conf["check_node"]) - check_node_shapes = [parse_int_array_from_str(shape) for shape in - to_list(conf["check_shape"])] - mace_check(len(check_node_names) == len(check_node_shapes), - "check node count and shape count do not match.") - for i in range(len(check_node_names)): + if "check_tensors" in conf: + check_tensors = to_list(conf["check_tensors"]) + check_tensors_shapes = [parse_int_array_from_str(shape) for shape in + to_list(conf["check_shapes"])] + mace_check(len(check_tensors) == len(check_tensors_shapes), + "check tensors count and shape count do not match.") + for i in range(len(check_tensors)): check_node = cvt.NodeInfo() - check_node.name = check_node_names[i] - check_node.shape = check_node_shapes[i] + check_node.name = check_tensors[i] + check_node.shape = check_tensors_shapes[i] option.add_check_node(check_node) else: option.check_nodes = option.output_nodes @@ -276,18 +282,23 @@ def merge_params(net_def): if tensor.data_type == mace_pb2.DT_HALF: data = bytearray( np.array(tensor.float_data).astype(np.float16).tobytes()) + tensor.data_size = len(tensor.float_data) elif tensor.data_type == mace_pb2.DT_FLOAT: data = bytearray( np.array(tensor.float_data).astype(np.float32).tobytes()) + tensor.data_size = len(tensor.float_data) elif tensor.data_type == mace_pb2.DT_INT32: data = bytearray( np.array(tensor.int32_data).astype(np.int32).tobytes()) + tensor.data_size = len(tensor.int32_data) elif tensor.data_type == mace_pb2.DT_UINT8: data = bytearray( np.array(tensor.int32_data).astype(np.uint8).tolist()) + tensor.data_size = len(tensor.int32_data) elif tensor.data_type == mace_pb2.DT_FLOAT16: data = bytearray( np.array(tensor.float_data).astype(np.float16).tobytes()) + tensor.data_size = len(tensor.float_data) else: raise Exception('Tensor data type %s not supported' % tensor.data_type) diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py index d815cd99..cc4b61f3 100644 --- a/tools/python/transform/transformer.py +++ b/tools/python/transform/transformer.py @@ -1314,19 +1314,6 @@ class Transformer(base_converter.ConverterInterface): data_type = self._option.data_type net.data_type = data_type - for tensor in net.tensors: - if tensor.data_type == mace_pb2.DT_FLOAT: - tensor.data_type = data_type - - if tensor.data_type == mace_pb2.DT_FLOAT \ - or tensor.data_type == mace_pb2.DT_HALF \ - or tensor.data_type == mace_pb2.DT_FLOAT16: - tensor.data_size = len(tensor.float_data) - elif tensor.data_type == mace_pb2.DT_INT32: - tensor.data_size = len(tensor.int32_data) - elif tensor.data_type == mace_pb2.DT_UINT8: - tensor.data_size = len(tensor.int32_data) - if self._option.quantize: return diff --git a/tools/python/utils/util.py b/tools/python/utils/util.py index 8cdb98f6..423a9ef6 100644 --- a/tools/python/utils/util.py +++ b/tools/python/utils/util.py @@ -87,15 +87,16 @@ def file_checksum(fname): def download_or_get_file(file, sha256_checksum, output_dir): - model_file = output_dir + "/" + sha256_checksum + ".pb" + filename = os.path.basename(file) + output_file = "%s/%s-%s.pb" % (output_dir, filename, sha256_checksum) if file.startswith("http://") or file.startswith("https://"): - if not os.path.exists(model_file) or file_checksum( - model_file) != sha256_checksum: + if not os.path.exists(output_file) or file_checksum( + output_file) != sha256_checksum: MaceLogger.info("Downloading file %s, please wait ..." % file) - urllib.urlretrieve(file, model_file) + urllib.urlretrieve(file, output_file) MaceLogger.info("Model downloaded successfully.") else: - device.execute("cp %s %s" % (file, model_file)) + device.execute("cp %s %s" % (file, output_file)) - return model_file + return output_file -- GitLab