提交 f7fe515b 编写于 作者: Y yejianwu

format python code by yapf

上级 da3784cd
......@@ -17,6 +17,7 @@ import yaml
from ConfigParser import ConfigParser
def run_command(command):
print("Run command: {}".format(command))
result = subprocess.Popen(
......@@ -80,22 +81,29 @@ def build_mace_run(production_mode, model_output_dir, hexagon_mode):
run_command(command)
def tuning_run(target_soc,
               model_output_dir,
               running_round,
               tuning,
               production_mode,
               restart_round,
               option_args=''):
    """Invoke tools/tuning_run.sh for one built model.

    Args:
        target_soc: SoC identifier forwarded to the shell script.
        model_output_dir: directory holding the generated model artifacts.
        running_round: number of rounds to run the model.
        tuning: truthy to tune OpenCL parameters (forwarded as 0/1).
        production_mode: truthy for production builds (forwarded as 0/1).
        restart_round: number of process restarts.
        option_args: extra options forwarded verbatim, quoted as one arg.
    """
    # Boolean flags are cast to int so the shell script receives 0/1.
    command = "bash tools/tuning_run.sh {} {} {} {} {} {} \"{}\"".format(
        target_soc, model_output_dir, running_round, int(tuning),
        int(production_mode), restart_round, option_args)
    run_command(command)
def benchmark_model(model_output_dir, option_args=''):
    """Benchmark a built model via tools/benchmark.sh.

    Args:
        model_output_dir: directory holding the generated model artifacts.
        option_args: extra options forwarded verbatim, quoted as one arg.
    """
    command = "bash tools/benchmark.sh {} \"{}\"".format(model_output_dir,
                                                         option_args)
    run_command(command)
def run_model(target_soc, model_output_dir, running_round, restart_round,
              option_args):
    """Run a model plainly: tuning_run with tuning and production off.

    Args:
        target_soc: SoC identifier forwarded to tuning_run.
        model_output_dir: directory holding the generated model artifacts.
        running_round: number of rounds to run the model.
        restart_round: number of process restarts.
        option_args: extra options forwarded verbatim.
    """
    tuning_run(target_soc, model_output_dir, running_round, False, False,
               restart_round, option_args)
def generate_production_code(target_soc, model_output_dirs, pull_or_not):
......@@ -132,7 +140,7 @@ def build_mace_run_prod(target_soc, model_output_dir, tuning, global_runtime):
def build_run_throughput_test(run_seconds, merged_lib_file, model_input_dir):
    """Build and run the throughput test via tools/build_run_throughput_test.sh.

    Args:
        run_seconds: how long the throughput test should run, in seconds.
        merged_lib_file: path to the merged static library to link against.
        model_input_dir: directory containing the model input data.
    """
    command = "bash tools/build_run_throughput_test.sh {} {} {}".format(
        run_seconds, merged_lib_file, model_input_dir)
    run_command(command)
......@@ -155,7 +163,7 @@ def merge_libs_and_tuning_results(target_soc, output_dir, model_output_dirs):
model_output_dirs_str = ",".join(model_output_dirs)
command = "bash tools/merge_libs.sh {} {} {}".format(target_soc, output_dir,
model_output_dirs_str)
model_output_dirs_str)
run_command(command)
......@@ -178,14 +186,20 @@ def parse_args():
"--output_dir", type=str, default="build", help="The output dir.")
parser.add_argument(
"--round", type=int, default=1, help="The model running round.")
parser.add_argument("--run_seconds", type=int, default=10,
help="The model throughput test running seconds.")
parser.add_argument(
"--restart_round", type=int, default=1, help="The model restart round.")
"--run_seconds",
type=int,
default=10,
help="The model throughput test running seconds.")
parser.add_argument(
"--restart_round", type=int, default=1, help="The model restart round.")
parser.add_argument(
"--tuning", type="bool", default="true", help="Tune opencl params.")
parser.add_argument("--mode", type=str, default="all",
help="[build|run|validate|merge|all|throughput_test].")
parser.add_argument(
"--mode",
type=str,
default="all",
help="[build|run|validate|merge|all|throughput_test].")
return parser.parse_known_args()
......@@ -198,7 +212,8 @@ def main(unused_args):
os.environ["EMBED_MODEL_DATA"] = str(configs["embed_model_data"])
os.environ["VLOG_LEVEL"] = str(configs["vlog_level"])
os.environ["PROJECT_NAME"] = os.path.splitext(os.path.basename(FLAGS.config))[0]
os.environ["PROJECT_NAME"] = os.path.splitext(os.path.basename(
FLAGS.config))[0]
if FLAGS.mode == "build" or FLAGS.mode == "all":
# Remove previous output dirs
......@@ -223,17 +238,21 @@ def main(unused_args):
print '=======================', model_name, '======================='
model_config = configs["models"][model_name]
for key in model_config:
if key in ['input_nodes', 'output_nodes'] and isinstance(model_config[key], list):
os.environ[key.upper()] = ",".join(model_config[key])
elif key in ['input_shapes', 'output_shapes'] and isinstance(model_config[key], list):
os.environ[key.upper()] = ":".join(model_config[key])
if key in ['input_nodes', 'output_nodes'] and isinstance(
model_config[key], list):
os.environ[key.upper()] = ",".join(model_config[key])
elif key in ['input_shapes', 'output_shapes'] and isinstance(
model_config[key], list):
os.environ[key.upper()] = ":".join(model_config[key])
else:
os.environ[key.upper()] = str(model_config[key])
md5 = hashlib.md5()
md5.update(model_config["model_file_path"])
model_path_digest = md5.hexdigest()
model_output_dir = "%s/%s/%s/%s/%s" % (FLAGS.output_dir, model_name, model_path_digest, target_soc, target_abi)
model_output_dir = "%s/%s/%s/%s/%s" % (FLAGS.output_dir, model_name,
model_path_digest, target_soc,
target_abi)
model_output_dirs.append(model_output_dir)
if FLAGS.mode == "build" or FLAGS.mode == "all":
......@@ -244,14 +263,19 @@ def main(unused_args):
# Support http:// and https://
if model_config["model_file_path"].startswith(
"http://") or model_config["model_file_path"].startswith("https://"):
"http://") or model_config["model_file_path"].startswith(
"https://"):
os.environ["MODEL_FILE_PATH"] = model_output_dir + "/model.pb"
urllib.urlretrieve(model_config["model_file_path"], os.environ["MODEL_FILE_PATH"])
urllib.urlretrieve(model_config["model_file_path"],
os.environ["MODEL_FILE_PATH"])
if model_config["platform"] == "caffe" and (model_config["weight_file_path"].startswith(
"http://") or model_config["weight_file_path"].startswith("https://")):
os.environ["WEIGHT_FILE_PATH"] = model_output_dir + "/model.caffemodel"
urllib.urlretrieve(model_config["weight_file_path"], os.environ["WEIGHT_FILE_PATH"])
if model_config["platform"] == "caffe" and (
model_config["weight_file_path"].startswith("http://") or
model_config["weight_file_path"].startswith("https://")):
os.environ[
"WEIGHT_FILE_PATH"] = model_output_dir + "/model.caffemodel"
urllib.urlretrieve(model_config["weight_file_path"],
os.environ["WEIGHT_FILE_PATH"])
if FLAGS.mode == "build" or FLAGS.mode == "run" or FLAGS.mode == "validate"\
or FLAGS.mode == "benchmark" or FLAGS.mode == "all":
......@@ -259,10 +283,12 @@ def main(unused_args):
if FLAGS.mode == "build" or FLAGS.mode == "all":
generate_model_code()
build_mace_run_prod(target_soc, model_output_dir, FLAGS.tuning, global_runtime)
build_mace_run_prod(target_soc, model_output_dir, FLAGS.tuning,
global_runtime)
if FLAGS.mode == "run" or FLAGS.mode == "validate" or FLAGS.mode == "all":
run_model(target_soc, model_output_dir, FLAGS.round, FLAGS.restart_round, option_args)
run_model(target_soc, model_output_dir, FLAGS.round,
FLAGS.restart_round, option_args)
if FLAGS.mode == "benchmark":
benchmark_model(model_output_dir, option_args)
......@@ -271,8 +297,9 @@ def main(unused_args):
validate_model(target_soc, model_output_dir)
if FLAGS.mode == "build" or FLAGS.mode == "merge" or FLAGS.mode == "all":
merge_libs_and_tuning_results(target_soc, FLAGS.output_dir + "/" + os.environ["PROJECT_NAME"],
model_output_dirs)
merge_libs_and_tuning_results(
target_soc, FLAGS.output_dir + "/" + os.environ["PROJECT_NAME"],
model_output_dirs)
if FLAGS.mode == "throughput_test":
merged_lib_file = FLAGS.output_dir + "/%s/libmace/lib/libmace_%s.a" % \
......@@ -281,7 +308,8 @@ def main(unused_args):
for model_name in configs["models"]:
runtime = configs["models"][model_name]["runtime"]
os.environ["%s_MODEL_TAG" % runtime.upper()] = model_name
build_run_throughput_test(FLAGS.run_seconds, merged_lib_file, FLAGS.output_dir)
build_run_throughput_test(FLAGS.run_seconds, merged_lib_file,
FLAGS.output_dir)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册