提交 682ffb58 编写于 作者: L liuqi

Remove skip_validation flag and add validation_inputs_data flag

上级 33b049e6
...@@ -20,6 +20,8 @@ models: ...@@ -20,6 +20,8 @@ models:
dsp_mode: 0 dsp_mode: 0
obfuscate: 1 obfuscate: 1
fast_conv: 0 fast_conv: 0
validation_inputs_data:
- path/to/input_files
capture_net: capture_net:
platform: caffe platform: caffe
model_file_path: path/to/model.prototxt model_file_path: path/to/model.prototxt
......
...@@ -96,7 +96,6 @@ def generate_random_input(target_soc, model_output_dir, ...@@ -96,7 +96,6 @@ def generate_random_input(target_soc, model_output_dir,
input_file_list[i].startswith("https://"): input_file_list[i].startswith("https://"):
urllib.urlretrieve(input_file_list[i], dst_input_file) urllib.urlretrieve(input_file_list[i], dst_input_file)
else: else:
print 'Copy input data:', dst_input_file
shutil.copy(input_file_list[i], dst_input_file) shutil.copy(input_file_list[i], dst_input_file)
def generate_model_code(): def generate_model_code():
...@@ -294,10 +293,8 @@ def main(unused_args): ...@@ -294,10 +293,8 @@ def main(unused_args):
# Transfer params by environment # Transfer params by environment
os.environ["MODEL_TAG"] = model_name os.environ["MODEL_TAG"] = model_name
print '=======================', model_name, '=======================' print '=======================', model_name, '======================='
skip_validation = configs["models"][model_name].get(
"skip_validation", 0)
model_config = configs["models"][model_name] model_config = configs["models"][model_name]
input_file_list = model_config.get("input_files", []) input_file_list = model_config.get("validation_inputs_data", [])
for key in model_config: for key in model_config:
if key in ['input_nodes', 'output_nodes'] and isinstance( if key in ['input_nodes', 'output_nodes'] and isinstance(
model_config[key], list): model_config[key], list):
...@@ -357,8 +354,7 @@ def main(unused_args): ...@@ -357,8 +354,7 @@ def main(unused_args):
if FLAGS.mode == "benchmark": if FLAGS.mode == "benchmark":
benchmark_model(target_soc, model_output_dir, option_args) benchmark_model(target_soc, model_output_dir, option_args)
if FLAGS.mode == "validate" or (FLAGS.mode == "all" and if FLAGS.mode == "validate" or FLAGS.mode == "all":
skip_validation == 0):
validate_model(target_soc, model_output_dir) validate_model(target_soc, model_output_dir)
if FLAGS.mode == "build" or FLAGS.mode == "merge" or FLAGS.mode == "all": if FLAGS.mode == "build" or FLAGS.mode == "merge" or FLAGS.mode == "all":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册