提交 1e24c9ad 编写于 作者: L liuqi

Support target_soc == 'all' for building all SOCs plugged in.

上级 a86482d8
......@@ -47,6 +47,7 @@ CL_PLATFORM_INFO_FILE_NAME = "mace_cl_platform_info.txt"
CODEGEN_BASE_DIR = 'mace/codegen'
MODEL_CODEGEN_DIR = CODEGEN_BASE_DIR + '/models'
MACE_RUN_TARGET = "//mace/tools/validation:mace_run"
ALL_SOC_TAG = 'all'
ABITypeStrs = [
'armeabi-v7a',
......@@ -233,11 +234,19 @@ def format_model_config(config_file_path):
elif not isinstance(target_socs, list):
configs[YAMLKeyword.target_socs] = [target_socs]
configs[YAMLKeyword.target_socs] = \
[soc.lower() for soc in configs[YAMLKeyword.target_socs]]
if ABIType.armeabi_v7a in target_abis \
or ABIType.arm64_v8a in target_abis:
available_socs = sh_commands.adb_get_all_socs()
if YAMLKeyword.target_socs in configs:
target_socs = set(configs[YAMLKeyword.target_socs])
target_socs = configs[YAMLKeyword.target_socs]
if ALL_SOC_TAG in target_socs:
mace_check(available_socs,
ModuleName.YAML_CONFIG,
"Build for all SOCs plugged in computer, "
"you at least plug in one phone")
else:
for soc in target_socs:
mace_check(soc in available_socs,
ModuleName.YAML_CONFIG,
......@@ -670,6 +679,7 @@ def build_specific_lib(target_abi, target_soc, serial_num,
pull_opencl_binary_and_tuning_param(target_abi, serial_num,
[model_output_dir])
sh_commands.touch_tuned_file_flag(build_tmp_binary_dir)
binary_changed = True
if binary_changed:
......@@ -732,6 +742,8 @@ def generate_library(configs, tuning, enable_openmp, address_sanitizer):
build_specific_lib(target_abi, None, None, configs,
tuning, enable_openmp, address_sanitizer)
else:
if ALL_SOC_TAG in target_socs:
target_socs = sh_commands.adb_get_all_socs()
for target_soc in target_socs:
serial_nums = \
sh_commands.get_target_socs_serialnos([target_soc])
......@@ -778,7 +790,8 @@ def report_run_statistics(stdout,
serialno,
model_name,
device_type,
output_dir):
output_dir,
tuned):
metrics = [0] * 3
for line in stdout.split('\n'):
line = line.strip()
......@@ -800,10 +813,10 @@ def report_run_statistics(stdout,
if not os.path.exists(report_filename):
with open(report_filename, 'w') as f:
f.write("model_name,device_name,soc,abi,runtime,"
"init,warmup,run_avg\n")
"init,warmup,run_avg,tuned\n")
data_str = "{model_name},{device_name},{soc},{abi},{device_type}," \
"{init},{warmup},{run_avg}\n" \
"{init},{warmup},{run_avg},{tuned}\n" \
.format(model_name=model_name,
device_name=device_name,
soc=target_soc,
......@@ -811,7 +824,8 @@ def report_run_statistics(stdout,
device_type=device_type,
init=metrics[0],
warmup=metrics[1],
run_avg=metrics[2]
run_avg=metrics[2],
tuned=tuned,
)
with open(report_filename, 'a') as f:
f.write(data_str)
......@@ -929,14 +943,15 @@ def run_specific_target(flags, configs, target_abi,
if flags.report and flags.round > 0:
report_run_statistics(
run_output, target_abi, serial_num,
model_name, device_type, flags.report_dir)
model_name, device_type, flags.report_dir,
sh_commands.is_binary_tuned(build_tmp_binary_dir))
def run_mace(flags):
configs = format_model_config(flags.config)
target_socs = configs[YAMLKeyword.target_socs]
if not target_socs:
if not target_socs or ALL_SOC_TAG in target_socs:
target_socs = sh_commands.adb_get_all_socs()
for target_abi in configs[YAMLKeyword.target_abis]:
......@@ -1043,7 +1058,7 @@ def benchmark_model(flags):
configs = format_model_config(flags.config)
target_socs = configs[YAMLKeyword.target_socs]
if not target_socs:
if not target_socs or ALL_SOC_TAG in target_socs:
target_socs = sh_commands.adb_get_all_socs()
for target_abi in configs[YAMLKeyword.target_abis]:
......
......@@ -520,11 +520,20 @@ def gen_random_input(model_output_dir,
sh.cp("-f", input_file_list[i], dst_input_file)
def update_mace_run_lib(model_output_dir):
mace_run_filepath = model_output_dir + "/mace_run"
def update_mace_run_lib(build_tmp_binary_dir):
mace_run_filepath = build_tmp_binary_dir + "/mace_run"
if os.path.exists(mace_run_filepath):
sh.rm("-rf", mace_run_filepath)
sh.cp("-f", "bazel-bin/mace/tools/validation/mace_run", model_output_dir)
sh.cp("-f", "bazel-bin/mace/tools/validation/mace_run",
build_tmp_binary_dir)
def touch_tuned_file_flag(build_tmp_binary_dir):
sh.touch(build_tmp_binary_dir + '/tuned')
def is_binary_tuned(build_tmp_binary_dir):
return os.path.exists(build_tmp_binary_dir + '/tuned')
def mv_model_file_to_output_dir(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册