提交 68e3f169 编写于 作者: L liuqi

Fix validation_input_data bug.

上级 660dcff6
...@@ -371,6 +371,14 @@ def format_model_config(config_file_path): ...@@ -371,6 +371,14 @@ def format_model_config(config_file_path):
"'%s' is necessary in subgraph" % key) "'%s' is necessary in subgraph" % key)
if not isinstance(value, list): if not isinstance(value, list):
subgraph[key] = [value] subgraph[key] = [value]
validation_inputs_data = subgraph.get(
YAMLKeyword.validation_inputs_data, [])
if not isinstance(validation_inputs_data, list):
subgraph[YAMLKeyword.validation_inputs_data] = [
validation_inputs_data]
else:
subgraph[YAMLKeyword.validation_inputs_data] = \
validation_inputs_data
for key in [YAMLKeyword.limit_opencl_kernel_time, for key in [YAMLKeyword.limit_opencl_kernel_time,
YAMLKeyword.nnlib_graph_mode, YAMLKeyword.nnlib_graph_mode,
...@@ -380,15 +388,6 @@ def format_model_config(config_file_path): ...@@ -380,15 +388,6 @@ def format_model_config(config_file_path):
if value == "": if value == "":
model_config[key] = 0 model_config[key] = 0
validation_inputs_data = model_config.get(
YAMLKeyword.validation_inputs_data, [])
if not isinstance(validation_inputs_data, list):
model_config[YAMLKeyword.validation_inputs_data] = [
validation_inputs_data]
else:
model_config[YAMLKeyword.validation_inputs_data] = \
validation_inputs_data
weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "") weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "")
model_config[YAMLKeyword.weight_file_path] = weight_file_path model_config[YAMLKeyword.weight_file_path] = weight_file_path
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册