提交 a1c6ba92 编写于 作者: L liuqi

Fix input_files config bug.

上级 25a874f7
...@@ -74,7 +74,6 @@ struct ReOrganizeFunctor { ...@@ -74,7 +74,6 @@ struct ReOrganizeFunctor {
} }
} }
} }
} }
}; };
......
...@@ -789,7 +789,6 @@ class CaffeConverter(object): ...@@ -789,7 +789,6 @@ class CaffeConverter(object):
input_shape = op.parents[0].output_shape_map[op.layer.bottom[0]] input_shape = op.parents[0].output_shape_map[op.layer.bottom[0]]
output_shape = input_shape output_shape = input_shape
shape_param = np.asarray(op.layer.reshape_param.shape.dim)[[0, 3, 2, 1]] shape_param = np.asarray(op.layer.reshape_param.shape.dim)[[0, 3, 2, 1]]
print shape_param
for i in range(len(shape_param)): for i in range(len(shape_param)):
if shape_param[i] != 0: if shape_param[i] != 0:
output_shape[i] = shape_param[i] output_shape[i] = shape_param[i]
......
...@@ -76,26 +76,28 @@ def generate_random_input(target_soc, model_output_dir, ...@@ -76,26 +76,28 @@ def generate_random_input(target_soc, model_output_dir,
target_soc, model_output_dir, int(generate_data_or_not)) target_soc, model_output_dir, int(generate_data_or_not))
run_command(command) run_command(command)
input_name_list = []
input_file_list = [] input_file_list = []
if isinstance(input_names, list):
input_name_list.extend(input_names)
else:
input_name_list.append(input_names)
if isinstance(input_files, list): if isinstance(input_files, list):
input_file_list.extend(input_files) input_file_list.extend(input_files)
else: else:
input_file_list.append(input_files) input_file_list.append(input_files)
assert len(input_file_list) == len(input_name_list) if len(input_file_list) != 0:
for i in range(len(input_file_list)): input_name_list = []
if input_file_list[i] is not None: if isinstance(input_names, list):
dst_input_file = model_output_dir + '/' + input_file_name(input_name_list[i]) input_name_list.extend(input_names)
if input_file_list[i].startswith("http://") or \ else:
input_file_list[i].startswith("https://"): input_name_list.append(input_names)
urllib.urlretrieve(input_file_list[i], dst_input_file) if len(input_file_list) != len(input_name_list):
else: raise Exception('If input_files set, the input files should match the input names.')
print 'Copy input data:', dst_input_file for i in range(len(input_file_list)):
shutil.copy(input_file_list[i], dst_input_file) if input_file_list[i] is not None:
dst_input_file = model_output_dir + '/' + input_file_name(input_name_list[i])
if input_file_list[i].startswith("http://") or \
input_file_list[i].startswith("https://"):
urllib.urlretrieve(input_file_list[i], dst_input_file)
else:
print 'Copy input data:', dst_input_file
shutil.copy(input_file_list[i], dst_input_file)
def generate_model_code(): def generate_model_code():
command = "bash tools/generate_model_code.sh" command = "bash tools/generate_model_code.sh"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册