提交 6073bca5 编写于 作者: 李滨

Merge branch 'summary' into 'master'

Fix check tensors func

See merge request !1173
...@@ -140,8 +140,6 @@ quantization_tests: ...@@ -140,8 +140,6 @@ quantization_tests:
python tools/converter.py run --config=${CONF_FILE} --target_socs=$TARGET_SOCS --device_yml=${DEVICE_CONF_FILE} --round=1 --validate --model_graph_format=file --model_data_format=file || exit 1; python tools/converter.py run --config=${CONF_FILE} --target_socs=$TARGET_SOCS --device_yml=${DEVICE_CONF_FILE} --round=1 --validate --model_graph_format=file --model_data_format=file || exit 1;
done done
- rm -rf mace-models - rm -rf mace-models
only:
- triggers
dynamic_linking_test: dynamic_linking_test:
stage: extra stage: extra
......
...@@ -117,6 +117,12 @@ def convert(conf, output): ...@@ -117,6 +117,12 @@ def convert(conf, output):
model_conf["weight_file_path"], model_conf["weight_file_path"],
model_conf["weight_sha256_checksum"], model_output) model_conf["weight_sha256_checksum"], model_output)
model_conf["weight_file_path"] = weight_file model_conf["weight_file_path"] = weight_file
# TODO: remove the following after quantize tool is made
if "quantize_range_file" in model_conf:
range_file = util.download_or_get_file(
model_conf["quantize_range_file"],
"", model_output)
model_conf["quantize_range_file"] = range_file
mace_model = convert_model(model_conf) mace_model = convert_model(model_conf)
...@@ -215,16 +221,16 @@ def convert_model(conf): ...@@ -215,16 +221,16 @@ def convert_model(conf):
output_node.data_format = cvt.DataFormat.NHWC output_node.data_format = cvt.DataFormat.NHWC
option.add_output_node(output_node) option.add_output_node(output_node)
if "check_node" in conf: if "check_tensors" in conf:
check_node_names = to_list(conf["check_node"]) check_tensors = to_list(conf["check_tensors"])
check_node_shapes = [parse_int_array_from_str(shape) for shape in check_tensors_shapes = [parse_int_array_from_str(shape) for shape in
to_list(conf["check_shape"])] to_list(conf["check_shapes"])]
mace_check(len(check_node_names) == len(check_node_shapes), mace_check(len(check_tensors) == len(check_tensors_shapes),
"check node count and shape count do not match.") "check tensors count and shape count do not match.")
for i in range(len(check_node_names)): for i in range(len(check_tensors)):
check_node = cvt.NodeInfo() check_node = cvt.NodeInfo()
check_node.name = check_node_names[i] check_node.name = check_tensors[i]
check_node.shape = check_node_shapes[i] check_node.shape = check_tensors_shapes[i]
option.add_check_node(check_node) option.add_check_node(check_node)
else: else:
option.check_nodes = option.output_nodes option.check_nodes = option.output_nodes
...@@ -276,18 +282,23 @@ def merge_params(net_def): ...@@ -276,18 +282,23 @@ def merge_params(net_def):
if tensor.data_type == mace_pb2.DT_HALF: if tensor.data_type == mace_pb2.DT_HALF:
data = bytearray( data = bytearray(
np.array(tensor.float_data).astype(np.float16).tobytes()) np.array(tensor.float_data).astype(np.float16).tobytes())
tensor.data_size = len(tensor.float_data)
elif tensor.data_type == mace_pb2.DT_FLOAT: elif tensor.data_type == mace_pb2.DT_FLOAT:
data = bytearray( data = bytearray(
np.array(tensor.float_data).astype(np.float32).tobytes()) np.array(tensor.float_data).astype(np.float32).tobytes())
tensor.data_size = len(tensor.float_data)
elif tensor.data_type == mace_pb2.DT_INT32: elif tensor.data_type == mace_pb2.DT_INT32:
data = bytearray( data = bytearray(
np.array(tensor.int32_data).astype(np.int32).tobytes()) np.array(tensor.int32_data).astype(np.int32).tobytes())
tensor.data_size = len(tensor.int32_data)
elif tensor.data_type == mace_pb2.DT_UINT8: elif tensor.data_type == mace_pb2.DT_UINT8:
data = bytearray( data = bytearray(
np.array(tensor.int32_data).astype(np.uint8).tolist()) np.array(tensor.int32_data).astype(np.uint8).tolist())
tensor.data_size = len(tensor.int32_data)
elif tensor.data_type == mace_pb2.DT_FLOAT16: elif tensor.data_type == mace_pb2.DT_FLOAT16:
data = bytearray( data = bytearray(
np.array(tensor.float_data).astype(np.float16).tobytes()) np.array(tensor.float_data).astype(np.float16).tobytes())
tensor.data_size = len(tensor.float_data)
else: else:
raise Exception('Tensor data type %s not supported' % raise Exception('Tensor data type %s not supported' %
tensor.data_type) tensor.data_type)
......
...@@ -1314,19 +1314,6 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1314,19 +1314,6 @@ class Transformer(base_converter.ConverterInterface):
data_type = self._option.data_type data_type = self._option.data_type
net.data_type = data_type net.data_type = data_type
for tensor in net.tensors:
if tensor.data_type == mace_pb2.DT_FLOAT:
tensor.data_type = data_type
if tensor.data_type == mace_pb2.DT_FLOAT \
or tensor.data_type == mace_pb2.DT_HALF \
or tensor.data_type == mace_pb2.DT_FLOAT16:
tensor.data_size = len(tensor.float_data)
elif tensor.data_type == mace_pb2.DT_INT32:
tensor.data_size = len(tensor.int32_data)
elif tensor.data_type == mace_pb2.DT_UINT8:
tensor.data_size = len(tensor.int32_data)
if self._option.quantize: if self._option.quantize:
return return
......
...@@ -87,15 +87,16 @@ def file_checksum(fname): ...@@ -87,15 +87,16 @@ def file_checksum(fname):
def download_or_get_file(file, def download_or_get_file(file,
sha256_checksum, sha256_checksum,
output_dir): output_dir):
model_file = output_dir + "/" + sha256_checksum + ".pb" filename = os.path.basename(file)
output_file = "%s/%s-%s.pb" % (output_dir, filename, sha256_checksum)
if file.startswith("http://") or file.startswith("https://"): if file.startswith("http://") or file.startswith("https://"):
if not os.path.exists(model_file) or file_checksum( if not os.path.exists(output_file) or file_checksum(
model_file) != sha256_checksum: output_file) != sha256_checksum:
MaceLogger.info("Downloading file %s, please wait ..." % file) MaceLogger.info("Downloading file %s, please wait ..." % file)
urllib.urlretrieve(file, model_file) urllib.urlretrieve(file, output_file)
MaceLogger.info("Model downloaded successfully.") MaceLogger.info("Model downloaded successfully.")
else: else:
device.execute("cp %s %s" % (file, model_file)) device.execute("cp %s %s" % (file, output_file))
return model_file return output_file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册