提交 89d36182 编写于 作者: 李寅

Merge branch 'fix-converter-bugs' into 'master'

fix converter bugs

See merge request !1182
......@@ -176,13 +176,28 @@ def convert_model(conf):
input_count = len(input_tensors)
input_data_types = [data_type_map[dt] for dt in
to_list(conf.get("input_data_types",
["float32"] * input_count))]
["float32"]))]
if len(input_data_types) == 1 and input_count > 1:
input_data_types = [input_data_types[0]] * input_count
mace_check(len(input_data_types) == input_count,
"the number of input_data_types should be "
"the same as input tensors")
input_data_formats = [data_format_map[df] for df in
to_list(conf.get("input_data_formats",
["NHWC"] * input_count))]
["NHWC"]))]
if len(input_data_formats) == 1 and input_count > 1:
input_data_formats = [input_data_formats[0]] * input_count
mace_check(len(input_data_formats) == input_count,
"the number of input_data_formats should be "
"the same as input tensors")
input_ranges = [parse_float_array_from_str(r) for r in
to_list(conf.get("input_ranges",
["-1.0,1.0"] * input_count))]
["-1.0,1.0"]))]
if len(input_ranges) == 1 and input_count > 1:
input_ranges = [input_ranges[0]] * input_count
mace_check(len(input_ranges) == input_count,
"the number of input_ranges should be "
"the same as input tensors")
for i in range(len(input_tensors)):
input_node = cvt.NodeInfo()
input_node.name = input_tensors[i]
......@@ -204,10 +219,20 @@ def convert_model(conf):
output_count = len(output_tensors)
output_data_types = [data_type_map[dt] for dt in
to_list(conf.get("output_data_types",
["float32"] * output_count))]
["float32"]))]
if len(output_data_types) == 1 and output_count > 1:
output_data_types = [output_data_types[0]] * output_count
mace_check(len(output_data_types) == output_count,
"the number of output_data_types should be "
"the same as output tensors")
output_data_formats = [data_format_map[df] for df in
to_list(conf.get("output_data_formats",
["NHWC"] * output_count))]
["NHWC"]))]
if len(output_data_formats) == 1 and output_count > 1:
output_data_formats = [output_data_formats[0]] * output_count
mace_check(len(output_data_formats) == output_count,
"the number of output_data_formats should be "
"the same as output tensors")
for i in range(len(output_tensors)):
output_node = cvt.NodeInfo()
output_node.name = output_tensors[i]
......
......@@ -46,7 +46,7 @@ CPP_KEYWORDS = [
def sanitize_load(s):
# do not let yaml parse ON/OFF to boolean
for w in ["ON", "OFF", "on", "off"]:
s = re.sub(r":\s+" + w, r": '" + w + "'", s)
s = re.sub(r":\s+" + w + "$", r": '" + w + "'", s)
# sub ${} to env value
s = re.sub(r"\${(\w+)}", lambda x: os.environ[x.group(1)], s)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册