diff --git a/tools/python/convert.py b/tools/python/convert.py index cfc04d0e963b3056783ab7a311142bcc781e2a8a..1ef320c5110aa9f8095a4846a88486128455f104 100644 --- a/tools/python/convert.py +++ b/tools/python/convert.py @@ -176,13 +176,28 @@ def convert_model(conf): input_count = len(input_tensors) input_data_types = [data_type_map[dt] for dt in to_list(conf.get("input_data_types", - ["float32"] * input_count))] + ["float32"]))] + if len(input_data_types) == 1 and input_count > 1: + input_data_types = [input_data_types[0]] * input_count + mace_check(len(input_data_types) == input_count, + "the number of input_data_types should be " + "the same as input tensors") input_data_formats = [data_format_map[df] for df in to_list(conf.get("input_data_formats", - ["NHWC"] * input_count))] + ["NHWC"]))] + if len(input_data_formats) == 1 and input_count > 1: + input_data_formats = [input_data_formats[0]] * input_count + mace_check(len(input_data_formats) == input_count, + "the number of input_data_formats should be " + "the same as input tensors") input_ranges = [parse_float_array_from_str(r) for r in to_list(conf.get("input_ranges", - ["-1.0,1.0"] * input_count))] + ["-1.0,1.0"]))] + if len(input_ranges) == 1 and input_count > 1: + input_ranges = [input_ranges[0]] * input_count + mace_check(len(input_ranges) == input_count, + "the number of input_ranges should be " + "the same as input tensors") for i in range(len(input_tensors)): input_node = cvt.NodeInfo() input_node.name = input_tensors[i] @@ -204,10 +219,20 @@ def convert_model(conf): output_count = len(output_tensors) output_data_types = [data_type_map[dt] for dt in to_list(conf.get("output_data_types", - ["float32"] * output_count))] + ["float32"]))] + if len(output_data_types) == 1 and output_count > 1: + output_data_types = [output_data_types[0]] * output_count + mace_check(len(output_data_types) == output_count, + "the number of output_data_types should be " + "the same as output tensors") output_data_formats = [data_format_map[df] for df in to_list(conf.get("output_data_formats", - ["NHWC"] * output_count))] + ["NHWC"]))] + if len(output_data_formats) == 1 and output_count > 1: + output_data_formats = [output_data_formats[0]] * output_count + mace_check(len(output_data_formats) == output_count, + "the number of output_data_formats should be " + "the same as output tensors") for i in range(len(output_tensors)): output_node = cvt.NodeInfo() output_node.name = output_tensors[i] diff --git a/tools/python/utils/config_parser.py b/tools/python/utils/config_parser.py index cd5089170c79927434e555af476498df586a7925..e4e2a04c4c783947f2398adbaee71ec14707bdbb 100644 --- a/tools/python/utils/config_parser.py +++ b/tools/python/utils/config_parser.py @@ -46,7 +46,7 @@ CPP_KEYWORDS = [ def sanitize_load(s): # do not let yaml parse ON/OFF to boolean for w in ["ON", "OFF", "on", "off"]: - s = re.sub(r":\s+" + w, r": '" + w + "'", s) + s = re.sub(r":\s+" + w + "$", r": '" + w + "'", s) # sub ${} to env value s = re.sub(r"\${(\w+)}", lambda x: os.environ[x.group(1)], s)