mace_tools.py 25.4 KB
Newer Older
Y
yejianwu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright 2018 Xiaomi, Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15

# Example usage:
#
# python tools/mace_tools.py \
#     --config=tools/example.yaml \
#     --round=100 \
#     --mode=all

import argparse
L
liuqi 已提交
21
import enum
Y
yejianwu 已提交
22
import filelock
23
import hashlib
24
import os
L
Liangliang He 已提交
25
import sh
26 27
import subprocess
import sys
28
import urllib
Y
yejianwu 已提交
29
import yaml
L
liuqi 已提交
30
import re
31

L
liuqi 已提交
32
import common
33
import sh_commands
L
Liangliang He 已提交
34

35 36
from ConfigParser import ConfigParser

Y
yejianwu 已提交
37

Y
yejianwu 已提交
38
def get_target_socs(configs):
    """Return the SoCs the models should be built for and run on.

    For a host build a single empty SoC name is returned. Otherwise the SoCs
    of all adb-connected devices are intersected with the optional
    ``target_socs`` entry of the YAML config and with the ``--target_socs``
    command-line flag.

    Args:
        configs: the parsed YAML config dict.

    Returns:
        ``[""]`` when "host" is among the target ABIs, otherwise a non-empty
        collection of SoC names.

    Exits the process with status 1 when a SoC requested on the command line
    has no connected device, or when no device is left to run on.
    """
    if "host" in configs["target_abis"]:
        return [""]

    # NOTE(review): assumes adb_get_all_socs() returns a set, since the
    # result is combined with `&` below.
    target_socs = sh_commands.adb_get_all_socs()

    # `configs` is a dict loaded from YAML, so the optional entry must be
    # looked up as a key; the previous `hasattr` check tested for an
    # attribute, was always False, and silently ignored this setting.
    config_socs = configs.get("target_socs")
    if config_socs:
        if not isinstance(config_socs, list):
            # Accept a single SoC written as a scalar in the YAML.
            config_socs = [config_socs]
        target_socs = target_socs & set(config_socs)

    if FLAGS.target_socs != "all":
        requested_socs = set(FLAGS.target_socs.split(','))
        target_socs = target_socs & requested_socs
        missing_socs = requested_socs.difference(target_socs)
        if missing_socs:
            print("Error: devices with SoCs are not connected %s" %
                  missing_socs)
            sys.exit(1)

    if not target_socs:
        print("Error: no device to run")
        sys.exit(1)

    return target_socs
Y
yejianwu 已提交
63

64

Y
yejianwu 已提交
65 66 67
def get_data_and_device_type(runtime):
    """Map a model runtime name to its ``(data_type, device_type)`` pair.

    Only the exact names "dsp", "gpu" and "cpu" are recognised (the match is
    case-sensitive); any other value yields ``("", "")``.
    """
    runtime_table = (
        ("dsp", "DT_UINT8", "HEXAGON"),
        ("gpu", "DT_HALF", "GPU"),
        ("cpu", "DT_FLOAT", "CPU"),
    )
    for name, data_type, device_type in runtime_table:
        if runtime == name:
            return data_type, device_type
    return "", ""
80

Y
yejianwu 已提交
81 82

def get_hexagon_mode(configs):
    """Return True if any model in the config uses the "dsp" runtime.

    The runtime comparison is case-insensitive. Every model entry must have
    a "runtime" key; a missing one raises KeyError.
    """
    runtimes = [model_config["runtime"].lower()
                for model_config in configs["models"].values()]
    return "dsp" in runtimes


W
wuchenghui 已提交
94 95
def gen_opencl_and_tuning_code(target_abi,
                               serialno,
                               model_output_dirs,
                               pull_or_not):
    """Generate the OpenCL binary code and the tuning-parameter code.

    Args:
        target_abi: target ABI, forwarded to sh_commands.pull_binaries.
        serialno: adb serial number, forwarded to sh_commands.pull_binaries.
        model_output_dirs: list of model output directories (may be empty);
            passed to every sh_commands call below.
        pull_or_not: when true, call sh_commands.pull_binaries first
            (presumably fetching binaries from the device into
            model_output_dirs -- confirm in sh_commands).
    """
    if pull_or_not:
        sh_commands.pull_binaries(target_abi, serialno, model_output_dirs)

    # Generate the OpenCL binary code, then the tuning parameter code.
    sh_commands.gen_opencl_binary_code(model_output_dirs)
    sh_commands.gen_tuning_param_code(model_output_dirs)
Y
yejianwu 已提交
107 108


109 110
def model_benchmark_stdout_processor(stdout,
                                     abi,
                                     serialno,
                                     model_name,
                                     runtime):
    """Append one benchmark row for a model run to <output_dir>/report.csv.

    The first line of ``stdout`` that has exactly six whitespace-separated
    fields and whose first field starts with "time" supplies the five timing
    columns, each normalised through ``float``. If no such line exists, the
    timings are written as 0. For non-host ABIs the device model and board
    platform are read from the device properties. The CSV header is written
    first when the report file does not exist yet.
    """
    timings = [0, 0, 0, 0, 0]
    for raw_line in stdout.split('\n'):
        fields = raw_line.strip().split()
        if len(fields) != 6 or not fields[0].startswith("time"):
            continue
        timings = [str(float(value)) for value in fields[1:]]
        break

    device_name, target_soc = "", ""
    if abi != "host":
        props = sh_commands.adb_getprop_by_serialno(serialno)
        device_name = props.get("ro.product.model", "")
        target_soc = props.get("ro.board.platform", "")

    report_filename = FLAGS.output_dir + "/report.csv"
    if not os.path.exists(report_filename):
        with open(report_filename, 'w') as report:
            report.write("model_name,device_name,soc,abi,runtime,create_net,"
                         "engine_ctor,init,warmup,run_avg\n")

    create_net, engine_ctor, init, warmup, run_avg = timings
    row_template = ("{model_name},{device_name},{soc},{abi},{runtime},"
                    "{create_net},{engine_ctor},{init},{warmup},{run_avg}\n")
    row = row_template.format(model_name=model_name,
                              device_name=device_name,
                              soc=target_soc,
                              abi=abi,
                              runtime=runtime,
                              create_net=create_net,
                              engine_ctor=engine_ctor,
                              init=init,
                              warmup=warmup,
                              run_avg=run_avg)
    with open(report_filename, 'a') as report:
        report.write(row)
L
Liangliang He 已提交
155

Y
yejianwu 已提交
156

Y
yejianwu 已提交
157 158
def tuning_run(runtime,
               target_abi,
               serialno,
               vlog_level,
               embed_model_data,
               model_output_dir,
               input_nodes,
               output_nodes,
               input_shapes,
               output_shapes,
               model_name,
               device_type,
               running_round,
               restart_round,
               out_of_range_check,
               phone_data_dir,
               tuning=False,
               limit_opencl_kernel_time=0,
               omp_num_threads=-1,
               cpu_affinity_policy=1,
               gpu_perf_hint=3,
               gpu_priority_hint=3):
    """Run mace_run through sh_commands.tuning_run, optionally reporting.

    Every argument except ``runtime`` is forwarded positionally to
    sh_commands.tuning_run, in the order that function expects. When
    ``running_round`` is positive and ``--collect_report`` is set, the
    captured stdout is handed to model_benchmark_stdout_processor.
    """
    forwarded = (target_abi, serialno, vlog_level, embed_model_data,
                 model_output_dir, input_nodes, output_nodes, input_shapes,
                 output_shapes, model_name, device_type, running_round,
                 restart_round, limit_opencl_kernel_time, tuning,
                 out_of_range_check, phone_data_dir, omp_num_threads,
                 cpu_affinity_policy, gpu_perf_hint, gpu_priority_hint)
    stdout = sh_commands.tuning_run(*forwarded)

    # Nothing to report for zero-round runs or when reporting is disabled.
    if running_round <= 0 or not FLAGS.collect_report:
        return
    model_benchmark_stdout_processor(stdout, target_abi, serialno,
                                     model_name, runtime)
Y
yejianwu 已提交
206

W
wuchenghui 已提交
207 208 209 210 211 212

def build_mace_run_prod(hexagon_mode, runtime, target_abi,
                        serialno, vlog_level, embed_model_data,
                        model_output_dir, input_nodes, output_nodes,
                        input_shapes, output_shapes, model_name, device_type,
                        running_round, restart_round, tuning,
                        limit_opencl_kernel_time, phone_data_dir,
                        enable_openmp):
    """Build the mace_run binary for one model and update its library.

    Non-"gpu" runtimes get a single production build. For "gpu", a
    non-production mace_run is built and run twice (once with the requested
    ``tuning`` setting, once with out-of-range checking and tuning off). The
    OpenCL/tuning code is then regenerated with a pull from the device, and
    a production mace_run is built.

    Note: ``running_round`` and ``restart_round`` are accepted but unused;
    the GPU tuning runs always use running_round=0 and restart_round=1.
    """
    mace_run_target = "//mace/tools/validation:mace_run"

    def build_and_update(production_mode):
        # Build mace_run, then refresh the model's mace_run library.
        sh_commands.bazel_build(
            mace_run_target,
            abi=target_abi,
            model_tag=model_name,
            production_mode=production_mode,
            hexagon_mode=hexagon_mode,
            enable_openmp=enable_openmp
        )
        sh_commands.update_mace_run_lib(model_output_dir, target_abi,
                                        model_name, embed_model_data)

    # Both paths start from OpenCL/tuning code generated for no models and
    # without pulling from the device.
    gen_opencl_and_tuning_code(target_abi, serialno, [], False)

    if runtime != "gpu":
        build_and_update(production_mode=True)
        return

    build_and_update(production_mode=False)

    run_args = (runtime, target_abi, serialno, vlog_level, embed_model_data,
                model_output_dir, input_nodes, output_nodes, input_shapes,
                output_shapes, model_name, device_type)
    # First pass: the requested tuning setting, no range checking.
    tuning_run(*run_args, running_round=0, restart_round=1,
               out_of_range_check=False, phone_data_dir=phone_data_dir,
               tuning=tuning,
               limit_opencl_kernel_time=limit_opencl_kernel_time)
    # Second pass: range checking on, tuning off.
    tuning_run(*run_args, running_round=0, restart_round=1,
               out_of_range_check=True, phone_data_dir=phone_data_dir,
               tuning=False)

    gen_opencl_and_tuning_code(target_abi, serialno, [model_output_dir],
                               True)
    build_and_update(production_mode=True)
Y
yejianwu 已提交
266 267 268 269


def merge_libs_and_tuning_results(target_soc,
                                  target_abi,
                                  serialno,
                                  project_name,
                                  output_dir,
                                  model_output_dirs,
                                  hexagon_mode,
                                  embed_model_data):
    """Generate OpenCL/tuning code, build production code, then merge libs.

    The code generation uses ``model_output_dirs`` without pulling from the
    device (pull_or_not is False), so ``serialno`` is only forwarded.
    Afterwards sh_commands.build_production_code and sh_commands.merge_libs
    are invoked for ``target_abi``.
    """
    gen_opencl_and_tuning_code(target_abi, serialno, model_output_dirs,
                               False)
    sh_commands.build_production_code(target_abi)

    merge_args = (target_soc, target_abi, project_name, output_dir,
                  model_output_dirs, hexagon_mode, embed_model_data)
    sh_commands.merge_libs(*merge_args)
L
Liangliang He 已提交
287

Y
yejianwu 已提交
288

Y
yejianwu 已提交
289 290 291
def get_model_files(model_file_path,
                    model_output_dir,
                    weight_file_path=""):
Y
yejianwu 已提交
292 293
    model_file = ""
    weight_file = ""
L
Liangliang He 已提交
294 295
    if model_file_path.startswith("http://") or \
            model_file_path.startswith("https://"):
Y
yejianwu 已提交
296 297
        model_file = model_output_dir + "/model.pb"
        urllib.urlretrieve(model_file_path, model_file)
Y
yejianwu 已提交
298 299
    else:
        model_file = model_file_path
L
Liangliang He 已提交
300 301 302

    if weight_file_path.startswith("http://") or \
            weight_file_path.startswith("https://"):
Y
yejianwu 已提交
303 304
        weight_file = model_output_dir + "/model.caffemodel"
        urllib.urlretrieve(weight_file_path, weight_file)
Y
yejianwu 已提交
305 306
    else:
        weight_file = weight_file_path
Y
yejianwu 已提交
307 308

    return model_file, weight_file
L
Liangliang He 已提交
309

L
liuqi 已提交
310 311

def md5sum(str):
    """Return the hexadecimal MD5 digest of ``str``.

    Bytes input is hashed as-is. Text (unicode) input is UTF-8 encoded
    first, so non-ASCII model paths hash instead of raising. The parameter
    keeps its historical name (it shadows the builtin ``str`` only inside
    this function) so keyword callers keep working.
    """
    data = str
    if not isinstance(data, bytes):
        data = data.encode("utf-8")
    return hashlib.md5(data).hexdigest()
L
liuqi 已提交
315

316

L
liuqi 已提交
317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337
################################
# Parsing arguments
################################
def str2bool(v):
    """Convert a command-line string to a bool (argparse ``type=`` hook).

    Case-insensitive: yes/true/t/y/1 -> True, no/false/f/n/0 -> False.
    Any other value raises argparse.ArgumentTypeError.
    """
    normalized = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def str_to_caffe_env_type(v):
    """Convert "docker"/"local" (case-insensitive) to common.CaffeEnvType.

    Any other value raises argparse.ArgumentTypeError.
    """
    env_name = v.lower()
    if env_name == 'docker':
        return common.CaffeEnvType.DOCKER
    if env_name == 'local':
        return common.CaffeEnvType.LOCAL
    raise argparse.ArgumentTypeError('[docker | local] expected.')


338
def parse_model_configs():
    """Load and return the YAML config file named by ``--config``.

    Uses ``yaml.safe_load``. A bare ``yaml.load(f)`` without an explicit
    Loader is deprecated since PyYAML 5.1 and a TypeError in PyYAML 6. It
    can also construct arbitrary Python objects from tagged input. Plain
    config files (maps, lists, scalars) load identically.
    """
    with open(FLAGS.config) as f:
        return yaml.safe_load(f)
342 343 344


def parse_args():
    """Parse command-line arguments.

    Returns:
        The ``(namespace, unparsed_args)`` pair from
        ``ArgumentParser.parse_known_args``: unknown arguments are returned
        rather than rejected.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        # Required, so no default: the previous "./tool/config" default
        # could never take effect and only misled readers.
        required=True,
        help="The global config file of models.")
    parser.add_argument(
        "--output_dir", type=str, default="build", help="The output dir.")
    parser.add_argument(
        "--round", type=int, default=1, help="The model running round.")
    parser.add_argument(
        "--run_seconds",
        type=int,
        default=10,
        help="The model throughput test running seconds.")
    parser.add_argument(
        "--restart_round",
        type=int,
        default=1,
        help="The model restart round.")
    parser.add_argument(
        "--tuning",
        type=str2bool,
        default=True,
        help="Tune opencl params.")
    parser.add_argument(
        "--mode",
        type=str,
        default="all",
        help="[build|run|validate|benchmark|merge|all|throughput_test].")
    parser.add_argument(
        "--target_socs",
        type=str,
        default="all",
        help="SoCs to build, comma separated list (getprop ro.board.platform)")
    parser.add_argument(
        "--out_of_range_check",
        type=str2bool,
        default=False,
        help="Enable out of range check for opencl.")
    parser.add_argument(
        "--enable_openmp",
        type=str2bool,
        default=True,
        help="Enable openmp.")
    parser.add_argument(
        "--omp_num_threads",
        type=int,
        default=-1,
        help="num of openmp threads")
    parser.add_argument(
        "--cpu_affinity_policy",
        type=int,
        default=1,
        help="0:AFFINITY_NONE/1:AFFINITY_BIG_ONLY/2:AFFINITY_LITTLE_ONLY")
    parser.add_argument(
        "--gpu_perf_hint",
        type=int,
        default=3,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    parser.add_argument(
        "--gpu_priority_hint",
        type=int,
        default=3,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    parser.add_argument(
        "--collect_report",
        type=str2bool,
        default=False,
        help="Collect report.")
    parser.add_argument(
        "--vlog_level",
        type=int,
        default=0,
        help="VLOG level.")
    parser.add_argument(
        "--caffe_env",
        type=str_to_caffe_env_type,
        default='docker',
        help="[docker | local] caffe environment.")
    return parser.parse_known_args()

429

Y
yejianwu 已提交
430
def _get_model_output_dir(project_name, model_name, model_path_digest,
                          target_abi, target_soc, serialno):
    """Return the build directory for one model on one ABI (and device)."""
    if target_abi == "host":
        return "%s/%s/%s/%s/%s/%s" % (
            FLAGS.output_dir, project_name, "build",
            model_name, model_path_digest, target_abi)
    device_name = sh_commands.adb_get_device_name_by_serialno(serialno)
    return "%s/%s/%s/%s/%s/%s_%s/%s" % (
        FLAGS.output_dir, project_name, "build",
        model_name, model_path_digest, device_name.replace(' ', ''),
        target_soc, target_abi)


def _run_throughput_test(project_name, configs, embed_model_data, vlog_level,
                         target_abi, phone_data_dir, target_soc, serialno):
    """Generate inputs for, then build and run, the throughput test."""
    merged_lib_file = FLAGS.output_dir + "/%s/%s/libmace_%s.%s.a" % (
        project_name, target_abi, project_name, target_soc)
    # Inputs are generated from, and the run is configured with, the SAME
    # model entry. Previously inputs came from the first model while the
    # nodes/shapes passed to the run leaked from the last loop iteration.
    # NOTE(review): assumes all models in a throughput test share their
    # input/output nodes and shapes.
    first_model = list(configs["models"].values())[0]
    throughput_test_output_dir = "%s/%s/%s/%s" % (
        FLAGS.output_dir, project_name, "build", "throughput_test")
    if os.path.exists(throughput_test_output_dir):
        sh.rm("-rf", throughput_test_output_dir)
    os.makedirs(throughput_test_output_dir)
    input_file_list = first_model.get("validation_inputs_data", [])
    sh_commands.gen_random_input(throughput_test_output_dir,
                                 first_model["input_nodes"],
                                 first_model["input_shapes"],
                                 input_file_list)
    # Map each runtime to a model name using it (the last one wins).
    model_tag_dict = {}
    for model_name in configs["models"]:
        runtime = configs["models"][model_name]["runtime"]
        model_tag_dict[runtime] = model_name
    sh_commands.build_run_throughput_test(target_abi,
                                          serialno,
                                          vlog_level,
                                          FLAGS.run_seconds,
                                          merged_lib_file,
                                          throughput_test_output_dir,
                                          embed_model_data,
                                          first_model["input_nodes"],
                                          first_model["output_nodes"],
                                          first_model["input_shapes"],
                                          first_model["output_shapes"],
                                          model_tag_dict.get("cpu", ""),
                                          model_tag_dict.get("gpu", ""),
                                          model_tag_dict.get("dsp", ""),
                                          phone_data_dir)


def process_models(project_name, configs, embed_model_data, vlog_level,
                   target_abi, phone_data_dir, target_soc="", serialno=""):
    """Run the steps selected by ``--mode`` for every model in the config.

    Depending on FLAGS.mode, each model goes through these steps:
    recreate its output directory, clear on-device run data, fetch
    model/weight files, generate random inputs, generate model code, build
    mace_run, and run/benchmark/validate it. Afterwards the libraries are
    merged ("build"/"merge"/"all") or the throughput test is run
    ("throughput_test").

    Args:
        project_name: project name (derived from the config file name).
        configs: parsed YAML config. The input/output node and shape entries
            of each model are normalised in place to lists.
        embed_model_data: forwarded to code generation and builds.
        vlog_level: VLOG level for the runs.
        target_abi: ABI to build for; "host" builds/runs on the host.
        phone_data_dir: on-device working directory.
        target_soc: SoC of the target device ("" for host).
        serialno: adb serial number of the target device ("" for host).
    """
    hexagon_mode = get_hexagon_mode(configs)
    model_output_dirs = []
    mode = FLAGS.mode
    banner = '==================='
    for model_name in configs["models"]:
        print('%s %s %s' % (banner, model_name, banner))
        model_config = configs["models"][model_name]
        input_file_list = model_config.get("validation_inputs_data", [])
        data_type, device_type = get_data_and_device_type(
            model_config["runtime"])

        # Allow scalar values in the YAML for these list-valued keys.
        for key in ["input_nodes", "output_nodes", "input_shapes",
                    "output_shapes"]:
            if not isinstance(model_config[key], list):
                model_config[key] = [model_config[key]]

        # The build directory includes an MD5 digest of the model path.
        model_path_digest = md5sum(model_config["model_file_path"])
        model_output_dir = _get_model_output_dir(
            project_name, model_name, model_path_digest, target_abi,
            target_soc, serialno)
        model_output_dirs.append(model_output_dir)

        if mode in ("build", "all"):
            if os.path.exists(model_output_dir):
                sh.rm("-rf", model_output_dir)
            os.makedirs(model_output_dir)

        if mode in ("build", "benchmark", "all"):
            sh_commands.clear_mace_run_data(
                target_abi, serialno, phone_data_dir)

        model_file_path, weight_file_path = get_model_files(
            model_config["model_file_path"],
            model_output_dir,
            model_config.get("weight_file_path", ""))

        if mode in ("build", "run", "validate", "benchmark", "all"):
            sh_commands.gen_random_input(model_output_dir,
                                         model_config["input_nodes"],
                                         model_config["input_shapes"],
                                         input_file_list)

        if mode in ("build", "benchmark", "all"):
            sh_commands.gen_model_code(
                "mace/codegen/models/%s" % model_name,
                model_config["platform"],
                model_file_path,
                weight_file_path,
                model_config["model_sha256_checksum"],
                ",".join(model_config["input_nodes"]),
                ",".join(model_config["output_nodes"]),
                data_type,
                model_config["runtime"],
                model_name,
                ":".join(model_config["input_shapes"]),
                model_config["dsp_mode"],
                embed_model_data,
                model_config["fast_conv"],
                model_config["obfuscate"])

        if mode in ("build", "all"):
            build_mace_run_prod(hexagon_mode,
                                model_config["runtime"],
                                target_abi,
                                serialno,
                                vlog_level,
                                embed_model_data,
                                model_output_dir,
                                model_config["input_nodes"],
                                model_config["output_nodes"],
                                model_config["input_shapes"],
                                model_config["output_shapes"],
                                model_name,
                                device_type,
                                FLAGS.round,
                                FLAGS.restart_round,
                                FLAGS.tuning,
                                model_config["limit_opencl_kernel_time"],
                                phone_data_dir,
                                FLAGS.enable_openmp)

        if mode in ("run", "validate", "all"):
            tuning_run(model_config["runtime"],
                       target_abi,
                       serialno,
                       vlog_level,
                       embed_model_data,
                       model_output_dir,
                       model_config["input_nodes"],
                       model_config["output_nodes"],
                       model_config["input_shapes"],
                       model_config["output_shapes"],
                       model_name,
                       device_type,
                       FLAGS.round,
                       FLAGS.restart_round,
                       FLAGS.out_of_range_check,
                       phone_data_dir,
                       omp_num_threads=FLAGS.omp_num_threads,
                       cpu_affinity_policy=FLAGS.cpu_affinity_policy,
                       gpu_perf_hint=FLAGS.gpu_perf_hint,
                       gpu_priority_hint=FLAGS.gpu_priority_hint)

        if mode == "benchmark":
            gen_opencl_and_tuning_code(
                target_abi, serialno, [model_output_dir], False)
            sh_commands.benchmark_model(target_abi,
                                        serialno,
                                        vlog_level,
                                        embed_model_data,
                                        model_output_dir,
                                        model_config["input_nodes"],
                                        model_config["output_nodes"],
                                        model_config["input_shapes"],
                                        model_config["output_shapes"],
                                        model_name,
                                        device_type,
                                        hexagon_mode,
                                        phone_data_dir,
                                        FLAGS.omp_num_threads,
                                        FLAGS.cpu_affinity_policy,
                                        FLAGS.gpu_perf_hint,
                                        FLAGS.gpu_priority_hint)

        if mode in ("validate", "all"):
            sh_commands.validate_model(target_abi,
                                       serialno,
                                       model_file_path,
                                       weight_file_path,
                                       model_config["platform"],
                                       model_config["runtime"],
                                       model_config["input_nodes"],
                                       model_config["output_nodes"],
                                       model_config["input_shapes"],
                                       model_config["output_shapes"],
                                       model_output_dir,
                                       phone_data_dir,
                                       FLAGS.caffe_env)

    if mode in ("build", "merge", "all"):
        merge_libs_and_tuning_results(
            target_soc,
            target_abi,
            serialno,
            project_name,
            FLAGS.output_dir,
            model_output_dirs,
            hexagon_mode,
            embed_model_data)

    if mode == "throughput_test":
        _run_throughput_test(project_name, configs, embed_model_data,
                             vlog_level, target_abi, phone_data_dir,
                             target_soc, serialno)
L
Liangliang He 已提交
632

633 634

def main(unused_args):
    """Entry point: prepare outputs, then process the models per ABI/SoC.

    ``unused_args`` is ignored. For "build"/"all", the output directory and
    generated sources are prepared first and the library is packaged at the
    end. Non-host ABIs run on every connected device of each target SoC,
    each while holding that device's lock.
    """
    common.init_logging()
    configs = parse_model_configs()

    if FLAGS.mode == "validate":
        # Validation uses a single round and a single restart.
        FLAGS.round = 1
        FLAGS.restart_round = 1

    project_name = os.path.splitext(os.path.basename(FLAGS.config))[0]
    builds_library = FLAGS.mode in ("build", "all")
    if builds_library:
        # Remove previous output dirs.
        # NOTE(review): the project dir is only recreated when
        # <output_dir>/libmace already exists -- confirm this is intended.
        project_dir = os.path.join(FLAGS.output_dir, project_name)
        if not os.path.exists(FLAGS.output_dir):
            os.makedirs(FLAGS.output_dir)
        elif os.path.exists(os.path.join(FLAGS.output_dir, "libmace")):
            sh.rm("-rf", project_dir)
            os.makedirs(project_dir)

        # Generate the version and encrypted OpenCL sources.
        sh_commands.gen_mace_version()
        sh_commands.gen_encrypted_opencl_source()

    target_socs = get_target_socs(configs)

    embed_model_data = configs.get("embed_model_data", 1)
    vlog_level = FLAGS.vlog_level
    phone_data_dir = "/data/local/tmp/mace_run/"
    for target_abi in configs["target_abis"]:
        for target_soc in target_socs:
            if target_abi == 'host':
                print("====================================================")
                print("Run on host")
                process_models(project_name, configs, embed_model_data,
                               vlog_level, target_abi, phone_data_dir)
                continue

            for serialno in sh_commands.get_target_socs_serialnos(
                    [target_soc]):
                props = sh_commands.adb_getprop_by_serialno(serialno)
                print(
                    "===================================================="
                )
                print("Trying to lock device %s" % serialno)
                with sh_commands.device_lock(serialno):
                    platform = props["ro.board.platform"]
                    product_model = props["ro.product.model"]
                    print("Run on device: %s, %s, %s" % (
                        serialno, platform, product_model))
                    process_models(project_name, configs, embed_model_data,
                                   vlog_level, target_abi, phone_data_dir,
                                   target_soc, serialno)

    if builds_library:
        sh_commands.packaging_lib(FLAGS.output_dir, project_name)
Y
yejianwu 已提交
685

686

Y
yejianwu 已提交
687
if __name__ == "__main__":
    # FLAGS becomes a module-level global read by the functions above.
    FLAGS, unparsed = parse_args()
    program_args = [sys.argv[0]]
    program_args.extend(unparsed)
    main(unused_args=program_args)