From a1c6ba9257c0b8abea09959f0d6f832a13b526c5 Mon Sep 17 00:00:00 2001 From: liuqi Date: Wed, 28 Mar 2018 14:06:30 +0800 Subject: [PATCH] Fix input_files config bug. --- mace/kernels/reorganize.h | 1 - mace/python/tools/caffe_converter_lib.py | 1 - tools/mace_tools.py | 32 +++++++++++++----------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/mace/kernels/reorganize.h b/mace/kernels/reorganize.h index 68c77209..a64d55b9 100644 --- a/mace/kernels/reorganize.h +++ b/mace/kernels/reorganize.h @@ -74,7 +74,6 @@ struct ReOrganizeFunctor { } } } - } }; diff --git a/mace/python/tools/caffe_converter_lib.py b/mace/python/tools/caffe_converter_lib.py index 7c7cd9ab..2cd2107a 100644 --- a/mace/python/tools/caffe_converter_lib.py +++ b/mace/python/tools/caffe_converter_lib.py @@ -789,7 +789,6 @@ class CaffeConverter(object): input_shape = op.parents[0].output_shape_map[op.layer.bottom[0]] output_shape = input_shape shape_param = np.asarray(op.layer.reshape_param.shape.dim)[[0, 3, 2, 1]] - print shape_param for i in range(len(shape_param)): if shape_param[i] != 0: output_shape[i] = shape_param[i] diff --git a/tools/mace_tools.py b/tools/mace_tools.py index c9a22f64..4f2b209a 100644 --- a/tools/mace_tools.py +++ b/tools/mace_tools.py @@ -76,26 +76,28 @@ def generate_random_input(target_soc, model_output_dir, target_soc, model_output_dir, int(generate_data_or_not)) run_command(command) - input_name_list = [] input_file_list = [] - if isinstance(input_names, list): - input_name_list.extend(input_names) - else: - input_name_list.append(input_names) if isinstance(input_files, list): input_file_list.extend(input_files) else: input_file_list.append(input_files) - assert len(input_file_list) == len(input_name_list) - for i in range(len(input_file_list)): - if input_file_list[i] is not None: - dst_input_file = model_output_dir + '/' + input_file_name(input_name_list[i]) - if input_file_list[i].startswith("http://") or \ - input_file_list[i].startswith("https://"): - urllib.urlretrieve(input_file_list[i], dst_input_file) - else: - print 'Copy input data:', dst_input_file - shutil.copy(input_file_list[i], dst_input_file) + if len(input_file_list) != 0: + input_name_list = [] + if isinstance(input_names, list): + input_name_list.extend(input_names) + else: + input_name_list.append(input_names) + if len(input_file_list) != len(input_name_list): + raise Exception('If input_files set, the input files should match the input names.') + for i in range(len(input_file_list)): + if input_file_list[i] is not None: + dst_input_file = model_output_dir + '/' + input_file_name(input_name_list[i]) + if input_file_list[i].startswith("http://") or \ + input_file_list[i].startswith("https://"): + urllib.urlretrieve(input_file_list[i], dst_input_file) + else: + print 'Copy input data:', dst_input_file + shutil.copy(input_file_list[i], dst_input_file) def generate_model_code(): command = "bash tools/generate_model_code.sh" -- GitLab