提交 6073bca5 编写于 作者: 李滨

Merge branch 'summary' into 'master'

Fix check tensors func

See merge request !1173
......@@ -140,8 +140,6 @@ quantization_tests:
python tools/converter.py run --config=${CONF_FILE} --target_socs=$TARGET_SOCS --device_yml=${DEVICE_CONF_FILE} --round=1 --validate --model_graph_format=file --model_data_format=file || exit 1;
done
- rm -rf mace-models
only:
- triggers
dynamic_linking_test:
stage: extra
......
......@@ -117,6 +117,12 @@ def convert(conf, output):
model_conf["weight_file_path"],
model_conf["weight_sha256_checksum"], model_output)
model_conf["weight_file_path"] = weight_file
# TODO: remove the following after quantize tool is made
if "quantize_range_file" in model_conf:
range_file = util.download_or_get_file(
model_conf["quantize_range_file"],
"", model_output)
model_conf["quantize_range_file"] = range_file
mace_model = convert_model(model_conf)
......@@ -215,16 +221,16 @@ def convert_model(conf):
output_node.data_format = cvt.DataFormat.NHWC
option.add_output_node(output_node)
if "check_node" in conf:
check_node_names = to_list(conf["check_node"])
check_node_shapes = [parse_int_array_from_str(shape) for shape in
to_list(conf["check_shape"])]
mace_check(len(check_node_names) == len(check_node_shapes),
"check node count and shape count do not match.")
for i in range(len(check_node_names)):
if "check_tensors" in conf:
check_tensors = to_list(conf["check_tensors"])
check_tensors_shapes = [parse_int_array_from_str(shape) for shape in
to_list(conf["check_shapes"])]
mace_check(len(check_tensors) == len(check_tensors_shapes),
"check tensors count and shape count do not match.")
for i in range(len(check_tensors)):
check_node = cvt.NodeInfo()
check_node.name = check_node_names[i]
check_node.shape = check_node_shapes[i]
check_node.name = check_tensors[i]
check_node.shape = check_tensors_shapes[i]
option.add_check_node(check_node)
else:
option.check_nodes = option.output_nodes
......@@ -276,18 +282,23 @@ def merge_params(net_def):
if tensor.data_type == mace_pb2.DT_HALF:
data = bytearray(
np.array(tensor.float_data).astype(np.float16).tobytes())
tensor.data_size = len(tensor.float_data)
elif tensor.data_type == mace_pb2.DT_FLOAT:
data = bytearray(
np.array(tensor.float_data).astype(np.float32).tobytes())
tensor.data_size = len(tensor.float_data)
elif tensor.data_type == mace_pb2.DT_INT32:
data = bytearray(
np.array(tensor.int32_data).astype(np.int32).tobytes())
tensor.data_size = len(tensor.int32_data)
elif tensor.data_type == mace_pb2.DT_UINT8:
data = bytearray(
np.array(tensor.int32_data).astype(np.uint8).tolist())
tensor.data_size = len(tensor.int32_data)
elif tensor.data_type == mace_pb2.DT_FLOAT16:
data = bytearray(
np.array(tensor.float_data).astype(np.float16).tobytes())
tensor.data_size = len(tensor.float_data)
else:
raise Exception('Tensor data type %s not supported' %
tensor.data_type)
......
......@@ -1314,19 +1314,6 @@ class Transformer(base_converter.ConverterInterface):
data_type = self._option.data_type
net.data_type = data_type
for tensor in net.tensors:
if tensor.data_type == mace_pb2.DT_FLOAT:
tensor.data_type = data_type
if tensor.data_type == mace_pb2.DT_FLOAT \
or tensor.data_type == mace_pb2.DT_HALF \
or tensor.data_type == mace_pb2.DT_FLOAT16:
tensor.data_size = len(tensor.float_data)
elif tensor.data_type == mace_pb2.DT_INT32:
tensor.data_size = len(tensor.int32_data)
elif tensor.data_type == mace_pb2.DT_UINT8:
tensor.data_size = len(tensor.int32_data)
if self._option.quantize:
return
......
......@@ -87,15 +87,16 @@ def file_checksum(fname):
def download_or_get_file(file,
sha256_checksum,
output_dir):
model_file = output_dir + "/" + sha256_checksum + ".pb"
filename = os.path.basename(file)
output_file = "%s/%s-%s.pb" % (output_dir, filename, sha256_checksum)
if file.startswith("http://") or file.startswith("https://"):
if not os.path.exists(model_file) or file_checksum(
model_file) != sha256_checksum:
if not os.path.exists(output_file) or file_checksum(
output_file) != sha256_checksum:
MaceLogger.info("Downloading file %s, please wait ..." % file)
urllib.urlretrieve(file, model_file)
urllib.urlretrieve(file, output_file)
MaceLogger.info("Model downloaded successfully.")
else:
device.execute("cp %s %s" % (file, model_file))
device.execute("cp %s %s" % (file, output_file))
return model_file
return output_file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册