mace_tools.py 22.0 KB
Newer Older
Y
yejianwu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright 2018 Xiaomi, Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Usage example:
# python tools/mace_tools.py \
#     --config=tools/example.yaml \
#     --round=100 \
#     --mode=all

import argparse
Y
yejianwu 已提交
21
import filelock
22
import hashlib
23
import os
L
Liangliang He 已提交
24
import sh
25 26
import subprocess
import sys
27
import urllib
Y
yejianwu 已提交
28
import yaml
L
liuqi 已提交
29
import re
30

31
import sh_commands
L
Liangliang He 已提交
32

33 34
from ConfigParser import ConfigParser

Y
yejianwu 已提交
35

Y
yejianwu 已提交
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
def get_target_socs(configs):
    """Return the SoCs to run on, exiting the process if none is usable.

    The candidates start as the SoCs reported by
    ``sh_commands.adb_get_all_socs()``.  They are narrowed to
    ``configs["target_socs"]`` when the config provides that key, and then to
    the comma separated ``--target_socs`` flag unless the flag is ``"all"``.

    Args:
        configs: mapping parsed from the model config file.

    Returns:
        The non-empty collection of selected SoC names.

    Exits with status 1 when a SoC requested through ``--target_socs`` is not
    among the candidates, or when no candidate is left.
    """
    available_socs = sh_commands.adb_get_all_socs()
    target_socs = available_socs

    # `configs` is a mapping, so the key must be looked up with .get()/`in`;
    # hasattr() never sees dictionary keys and always skipped this filter.
    configured_socs = configs.get("target_socs")
    if configured_socs:
        target_socs = set(configured_socs) & available_socs

    if FLAGS.target_socs != "all":
        requested_socs = set(FLAGS.target_socs.split(','))
        target_socs = target_socs & requested_socs
        missing_socs = requested_socs.difference(target_socs)
        if missing_socs:
            print(
                "Error: devices with SoCs are not connected %s" % missing_socs)
            sys.exit(1)

    if not target_socs:
        print("Error: no device to run")
        sys.exit(1)

    return target_socs

58

Y
yejianwu 已提交
59 60 61
def get_data_and_device_type(runtime):
    """Map a runtime name to its ``(data_type, device_type)`` pair.

    Args:
        runtime: one of "dsp", "gpu", "cpu" or "neon" (matched exactly,
            case-sensitively).

    Returns:
        A ``(data_type, device_type)`` tuple of strings; both are empty
        strings when the runtime is not one of the known names.
    """
    runtime_table = {
        "dsp": ("DT_UINT8", "HEXAGON"),
        "gpu": ("DT_HALF", "OPENCL"),
        "cpu": ("DT_FLOAT", "CPU"),
        "neon": ("DT_FLOAT", "NEON"),
    }
    return runtime_table.get(runtime, ("", ""))
77

Y
yejianwu 已提交
78 79

def get_hexagon_mode(configs):
    """Tell whether any configured model uses the "dsp" runtime.

    Args:
        configs: mapping with a "models" entry that maps each model name to
            a mapping containing a "runtime" string.

    Returns:
        True if at least one model's runtime is "dsp" (compared
        case-insensitively), False otherwise.
    """
    # Collect every runtime first so a model without a "runtime" key still
    # raises KeyError, regardless of the order models are visited in.
    runtime_list = [
        configs["models"][model_name]["runtime"].lower()
        for model_name in configs["models"]
    ]
    return "dsp" in runtime_list


W
wuchenghui 已提交
91 92
def gen_opencl_and_tuning_code(target_abi,
                               serialno,
                               model_output_dirs,
                               pull_or_not):
    """Generate the OpenCL binary code and the tuning parameter code.

    Args:
        target_abi: ABI name forwarded to ``sh_commands.pull_binaries``.
        serialno: device serial number forwarded to
            ``sh_commands.pull_binaries``.
        model_output_dirs: list of model output directories handed to the
            code generators (and to the pull step).
        pull_or_not: when true, ``sh_commands.pull_binaries`` is called
            before the code is generated.
    """
    if pull_or_not:
        sh_commands.pull_binaries(target_abi, serialno, model_output_dirs)

    # generate opencl binary code
    sh_commands.gen_opencl_binary_code(model_output_dirs)

    sh_commands.gen_tuning_param_code(model_output_dirs)
Y
yejianwu 已提交
104 105


106 107
def model_benchmark_stdout_processor(stdout,
                                     abi,
                                     serialno,
                                     model_name,
                                     runtime):
    """Append one benchmark row for a model run to ``build/report.csv``.

    The first line of ``stdout`` that has exactly six whitespace separated
    fields and whose first field starts with "time" supplies the five metric
    columns (create_net, engine_ctor, init, warmup, run_avg).  If no such
    line exists, all five metrics are reported as 0.

    Args:
        stdout: captured text output of the run.
        abi: ABI name written to the report.
        serialno: device serial number used to look up device properties.
        model_name: model name written to the report.
        runtime: runtime name written to the report.

    Raises:
        ValueError: if a metric field of the matching line is not a number.
    """
    metrics = [0] * 5
    for line in stdout.split('\n'):
        parts = line.split()
        if len(parts) == 6 and parts[0].startswith("time"):
            # Round-trip through float to validate/normalize the numbers.
            metrics = [str(float(value)) for value in parts[1:]]
            break

    props = sh_commands.adb_getprop_by_serialno(serialno)
    device_type = props.get("ro.product.model", "")
    target_soc = props.get("ro.board.platform", "")

    report_filename = "build/report.csv"
    report_dir = os.path.dirname(report_filename)
    # The report may be requested before anything created the directory.
    if report_dir and not os.path.isdir(report_dir):
        os.makedirs(report_dir)

    data_str = "{model_name},{device_type},{soc},{abi},{runtime}," \
               "{create_net},{engine_ctor},{init},{warmup},{run_avg}\n" \
        .format(
            model_name=model_name,
            device_type=device_type,
            soc=target_soc,
            abi=abi,
            runtime=runtime,
            create_net=metrics[0],
            engine_ctor=metrics[1],
            init=metrics[2],
            warmup=metrics[3],
            run_avg=metrics[4]
        )

    # The header is only written when the report file does not exist yet.
    write_header = not os.path.exists(report_filename)
    with open(report_filename, 'a') as f:
        if write_header:
            f.write("model_name,device_type,soc,abi,runtime,create_net,"
                    "engine_ctor,init,warmup,run_avg\n")
        f.write(data_str)
L
Liangliang He 已提交
149

Y
yejianwu 已提交
150

Y
yejianwu 已提交
151 152
def tuning_run(runtime,
               target_abi,
               serialno,
               vlog_level,
               embed_model_data,
               model_output_dir,
               input_nodes,
               output_nodes,
               input_shapes,
               output_shapes,
               model_name,
               device_type,
               running_round,
               restart_round,
               out_of_range_check,
               phone_data_dir,
               tuning=False,
               limit_opencl_kernel_time=0,
               option_args=""):
    """Run a model through ``sh_commands.tuning_run`` and optionally report.

    All arguments except ``runtime`` are forwarded to
    ``sh_commands.tuning_run``.  Note that the forwarded positional order
    differs from this signature: ``limit_opencl_kernel_time`` and ``tuning``
    are passed before ``out_of_range_check``.

    When ``running_round`` is positive and the ``--collect_report`` flag is
    set, the captured output is handed to
    ``model_benchmark_stdout_processor`` together with ``runtime``.
    """
    stdout = sh_commands.tuning_run(
        target_abi,
        serialno,
        vlog_level,
        embed_model_data,
        model_output_dir,
        input_nodes,
        output_nodes,
        input_shapes,
        output_shapes,
        model_name,
        device_type,
        running_round,
        restart_round,
        limit_opencl_kernel_time,
        tuning,
        out_of_range_check,
        phone_data_dir,
        option_args)

    if running_round > 0 and FLAGS.collect_report:
        model_benchmark_stdout_processor(
            stdout, target_abi, serialno, model_name, runtime)
Y
yejianwu 已提交
193

W
wuchenghui 已提交
194 195 196 197 198 199 200 201

def build_mace_run_prod(hexagon_mode, runtime, target_abi,
                        serialno, vlog_level, embed_model_data,
                        model_output_dir, input_nodes, output_nodes,
                        input_shapes, output_shapes, model_name, device_type,
                        running_round, restart_round, tuning,
                        limit_opencl_kernel_time, phone_data_dir):
    """Build mace_run, tune it on the device, then rebuild for production.

    Steps, in order:
      1. regenerate OpenCL/tuning code with no model dirs and without
         pulling binaries, then build mace_run with production_mode=False;
      2. run once with the out-of-range check enabled and tuning disabled;
      3. run once more with the caller's ``tuning`` and
         ``limit_opencl_kernel_time`` settings;
      4. pull binaries, regenerate the code for ``model_output_dir`` and
         build mace_run again with production_mode=True.

    ``running_round`` and ``restart_round`` are accepted but not forwarded:
    both runs use running_round=0 and restart_round=1.
    """
    mace_run_target = "//mace/tools/validation:mace_run"

    gen_opencl_and_tuning_code(target_abi, serialno, [], False)
    sh_commands.bazel_build(
        mace_run_target,
        abi=target_abi,
        model_tag=model_name,
        production_mode=False,
        hexagon_mode=hexagon_mode)
    sh_commands.update_mace_run_lib(model_output_dir, target_abi, model_name,
                                    embed_model_data)

    tuning_run(runtime, target_abi, serialno, vlog_level,
               embed_model_data, model_output_dir, input_nodes, output_nodes,
               input_shapes, output_shapes, model_name, device_type,
               running_round=0, restart_round=1, out_of_range_check=True,
               phone_data_dir=phone_data_dir, tuning=False)

    tuning_run(runtime, target_abi, serialno, vlog_level,
               embed_model_data, model_output_dir, input_nodes, output_nodes,
               input_shapes, output_shapes, model_name, device_type,
               running_round=0, restart_round=1, out_of_range_check=False,
               phone_data_dir=phone_data_dir, tuning=tuning,
               limit_opencl_kernel_time=limit_opencl_kernel_time)

    gen_opencl_and_tuning_code(target_abi, serialno, [model_output_dir], True)
    sh_commands.bazel_build(
        mace_run_target,
        abi=target_abi,
        model_tag=model_name,
        production_mode=True,
        hexagon_mode=hexagon_mode)
    sh_commands.update_mace_run_lib(model_output_dir, target_abi, model_name,
                                    embed_model_data)


def merge_libs_and_tuning_results(target_soc,
                                  target_abi,
                                  serialno,
                                  project_name,
                                  output_dir,
                                  model_output_dirs,
                                  hexagon_mode,
                                  embed_model_data):
    """Regenerate code for all models, build production code and merge libs.

    The OpenCL/tuning code is regenerated from ``model_output_dirs`` without
    pulling binaries from the device, then
    ``sh_commands.build_production_code`` and ``sh_commands.merge_libs`` are
    called with the remaining arguments.
    """
    gen_opencl_and_tuning_code(
        target_abi, serialno, model_output_dirs, False)
    sh_commands.build_production_code(target_abi)

    sh_commands.merge_libs(target_soc,
                           target_abi,
                           project_name,
                           output_dir,
                           model_output_dirs,
                           hexagon_mode,
                           embed_model_data)
L
Liangliang He 已提交
257

Y
yejianwu 已提交
258

Y
yejianwu 已提交
259 260 261
def get_model_files(model_file_path,
                    model_output_dir,
                    weight_file_path=""):
    """Resolve model and weight paths, downloading them when they are URLs.

    Args:
        model_file_path: local path or http(s) URL of the model file.
        model_output_dir: directory that receives downloaded files.
        weight_file_path: local path or http(s) URL of the weight file; may
            be empty or None when there is no weight file.

    Returns:
        A ``(model_file, weight_file)`` tuple.  A URL is downloaded to
        ``<model_output_dir>/model.pb`` (model) or
        ``<model_output_dir>/model.caffemodel`` (weights) and that local path
        is returned; any other value is returned unchanged (None becomes "").
    """
    url_prefixes = ("http://", "https://")
    # A key present in the config with a null value arrives here as None.
    weight_file_path = weight_file_path or ""

    if model_file_path.startswith(url_prefixes):
        model_file = model_output_dir + "/model.pb"
        urllib.urlretrieve(model_file_path, model_file)
    else:
        model_file = model_file_path

    if weight_file_path.startswith(url_prefixes):
        weight_file = model_output_dir + "/model.caffemodel"
        urllib.urlretrieve(weight_file_path, weight_file)
    else:
        weight_file = weight_file_path

    return model_file, weight_file
L
Liangliang He 已提交
279

L
liuqi 已提交
280 281

def md5sum(str):
    """Return the hexadecimal MD5 digest of ``str``.

    Args:
        str: byte string or text to hash.  (The parameter name shadows the
            builtin; it is kept so existing callers are unaffected.)

    Returns:
        The 32 character lowercase hex digest.
    """
    data = str
    # Text has to be turned into bytes before hashing; encoding explicitly
    # as UTF-8 keeps ASCII digests unchanged and supports non-ASCII text.
    if not isinstance(data, bytes):
        data = data.encode("utf-8")
    return hashlib.md5(data).hexdigest()
L
liuqi 已提交
285

286 287

def parse_model_configs():
    """Load the YAML file named by the ``--config`` flag.

    Returns:
        The parsed YAML document.
    """
    with open(FLAGS.config) as f:
        # safe_load only builds plain YAML types, so a config file cannot
        # trigger arbitrary object construction; it also avoids calling
        # yaml.load() without an explicit Loader.
        return yaml.safe_load(f)
291 292 293


def parse_args():
    """Parses command line arguments.

    Returns:
        A ``(namespace, unparsed)`` pair as produced by
        ``ArgumentParser.parse_known_args``: the recognized flags and the
        list of arguments that were not recognized.
    """
    parser = argparse.ArgumentParser()
    # "bool" flags are true only for the (case-insensitive) text "true".
    parser.register("type", "bool", lambda v: v.lower() == "true")
    parser.add_argument(
        "--config",
        type=str,
        default="./tool/config",
        help="The global config file of models.")
    parser.add_argument(
        "--output_dir", type=str, default="build", help="The output dir.")
    parser.add_argument(
        "--round", type=int, default=1, help="The model running round.")
    parser.add_argument(
        "--run_seconds",
        type=int,
        default=10,
        help="The model throughput test running seconds.")
    parser.add_argument(
        "--restart_round",
        type=int,
        default=1,
        help="The model restart round.")
    parser.add_argument(
        "--tuning", type="bool", default="true", help="Tune opencl params.")
    parser.add_argument(
        "--mode",
        type=str,
        default="all",
        help="[build|run|validate|benchmark|merge|all|throughput_test].")
    parser.add_argument(
        "--target_socs",
        type=str,
        default="all",
        help="SoCs to build, comma seperated list (getprop ro.board.platform)")
    parser.add_argument(
        "--out_of_range_check",
        type="bool",
        default="false",
        help="Enable out of range check for opencl.")
    parser.add_argument(
        "--collect_report",
        type="bool",
        default="false",
        help="Collect report.")
    return parser.parse_known_args()

340

Y
yejianwu 已提交
341
def process_models(project_name, configs, embed_model_data, vlog_level,
                   target_soc, target_abi, serialno, phone_data_dir,
                   option_args):
    """Run the stages selected by ``--mode`` for every configured model.

    For each model in ``configs["models"]`` the output directory is derived
    from the project, model name, digest of the model path, device name, SoC
    and ABI; the stages enabled for the current mode are then run in order:
    directory reset, random input generation, code generation plus build,
    run, benchmark and validation.  After the loop, the merge stage and the
    throughput test run when the mode asks for them.

    Args:
        project_name: name used in output paths and for merging.
        configs: parsed model config mapping.
        embed_model_data: forwarded to the build/run helpers.
        vlog_level: forwarded to the build/run helpers.
        target_soc: SoC name used in output paths.
        target_abi: ABI name used in output paths and by the helpers.
        serialno: device serial number ('' is passed for the host ABI).
        phone_data_dir: on-device working directory.
        option_args: extra command line options; only forwarded to
            ``sh_commands.benchmark_model``.
    """
    mode = FLAGS.mode
    hexagon_mode = get_hexagon_mode(configs)
    model_output_dirs = []
    for model_name in configs["models"]:
        print("=================== %s ===================" % model_name)
        model_config = configs["models"][model_name]
        input_file_list = model_config.get("validation_inputs_data", [])
        data_type, device_type = get_data_and_device_type(
            model_config["runtime"])

        # Scalars are wrapped so later stages can always treat these as
        # lists; the config entry itself is updated in place.
        for key in ["input_nodes", "output_nodes", "input_shapes",
                    "output_shapes"]:
            if not isinstance(model_config[key], list):
                model_config[key] = [model_config[key]]

        # Create model build directory
        model_path_digest = md5sum(model_config["model_file_path"])
        device_name = sh_commands.adb_get_device_name_by_serialno(serialno)
        model_output_dir = "%s/%s/%s/%s/%s/%s_%s/%s" % (
            FLAGS.output_dir, project_name, "build",
            model_name, model_path_digest, device_name.replace(' ', ''),
            target_soc, target_abi)
        model_output_dirs.append(model_output_dir)

        if mode in ("build", "all"):
            if os.path.exists(model_output_dir):
                sh.rm("-rf", model_output_dir)
            os.makedirs(model_output_dir)
            sh_commands.clear_mace_run_data(
                target_abi, serialno, phone_data_dir)

        model_file_path, weight_file_path = get_model_files(
            model_config["model_file_path"],
            model_output_dir,
            model_config.get("weight_file_path", ""))

        if mode in ("build", "run", "validate", "benchmark", "all"):
            sh_commands.gen_random_input(model_output_dir,
                                         model_config["input_nodes"],
                                         model_config["input_shapes"],
                                         input_file_list)

        if mode in ("build", "all"):
            sh_commands.gen_model_code(
                "mace/codegen/models/%s" % model_name,
                model_config["platform"],
                model_file_path,
                weight_file_path,
                model_config["model_sha256_checksum"],
                ",".join(model_config["input_nodes"]),
                ",".join(model_config["output_nodes"]),
                data_type,
                model_config["runtime"],
                model_name,
                ":".join(model_config["input_shapes"]),
                model_config["dsp_mode"],
                embed_model_data,
                model_config["fast_conv"],
                model_config["obfuscate"])
            build_mace_run_prod(hexagon_mode,
                                model_config["runtime"],
                                target_abi,
                                serialno,
                                vlog_level,
                                embed_model_data,
                                model_output_dir,
                                model_config["input_nodes"],
                                model_config["output_nodes"],
                                model_config["input_shapes"],
                                model_config["output_shapes"],
                                model_name,
                                device_type,
                                FLAGS.round,
                                FLAGS.restart_round,
                                FLAGS.tuning,
                                model_config["limit_opencl_kernel_time"],
                                phone_data_dir)

        if mode in ("run", "validate", "all"):
            tuning_run(model_config["runtime"],
                       target_abi,
                       serialno,
                       vlog_level,
                       embed_model_data,
                       model_output_dir,
                       model_config["input_nodes"],
                       model_config["output_nodes"],
                       model_config["input_shapes"],
                       model_config["output_shapes"],
                       model_name,
                       device_type,
                       FLAGS.round,
                       FLAGS.restart_round,
                       FLAGS.out_of_range_check,
                       phone_data_dir)

        if mode == "benchmark":
            sh_commands.benchmark_model(target_abi,
                                        serialno,
                                        vlog_level,
                                        embed_model_data,
                                        model_output_dir,
                                        model_config["input_nodes"],
                                        model_config["output_nodes"],
                                        model_config["input_shapes"],
                                        model_config["output_shapes"],
                                        model_name,
                                        device_type,
                                        hexagon_mode,
                                        phone_data_dir,
                                        option_args)

        if mode in ("validate", "all"):
            sh_commands.validate_model(target_abi,
                                       serialno,
                                       model_file_path,
                                       weight_file_path,
                                       model_config["platform"],
                                       model_config["runtime"],
                                       model_config["input_nodes"],
                                       model_config["output_nodes"],
                                       model_config["input_shapes"],
                                       model_config["output_shapes"],
                                       model_output_dir,
                                       phone_data_dir)

    if mode in ("build", "merge", "all"):
        merge_libs_and_tuning_results(
            target_soc,
            target_abi,
            serialno,
            project_name,
            FLAGS.output_dir,
            model_output_dirs,
            hexagon_mode,
            embed_model_data)

    if mode == "throughput_test":
        merged_lib_file = "%s/%s/%s/libmace_%s.%s.a" % (
            FLAGS.output_dir, project_name, target_abi, project_name,
            target_soc)
        # The generated inputs and the nodes/shapes handed to the test must
        # describe the same model, so everything below reads `first_model`
        # rather than whichever model the loop above happened to visit last.
        # Its node/shape entries were already wrapped into lists by the loop.
        first_model = list(configs["models"].values())[0]
        throughput_test_output_dir = "%s/%s/%s/%s" % (
            FLAGS.output_dir, project_name, "build",
            "throughput_test")
        if os.path.exists(throughput_test_output_dir):
            sh.rm("-rf", throughput_test_output_dir)
        os.makedirs(throughput_test_output_dir)
        input_file_list = first_model.get("validation_inputs_data", [])
        sh_commands.gen_random_input(throughput_test_output_dir,
                                     first_model["input_nodes"],
                                     first_model["input_shapes"],
                                     input_file_list)
        # One model tag per runtime; a later model with the same runtime
        # replaces an earlier one.
        model_tag_dict = {}
        for model_name in configs["models"]:
            runtime = configs["models"][model_name]["runtime"]
            model_tag_dict[runtime] = model_name
        sh_commands.build_run_throughput_test(target_abi,
                                              serialno,
                                              vlog_level,
                                              FLAGS.run_seconds,
                                              merged_lib_file,
                                              throughput_test_output_dir,
                                              embed_model_data,
                                              first_model["input_nodes"],
                                              first_model["output_nodes"],
                                              first_model["input_shapes"],
                                              first_model["output_shapes"],
                                              model_tag_dict.get("cpu", ""),
                                              model_tag_dict.get("gpu", ""),
                                              model_tag_dict.get("dsp", ""),
                                              phone_data_dir)
L
Liangliang He 已提交
521

522 523

def main(unused_args):
    """Entry point: load the config and process models on every target.

    Args:
        unused_args: the program name followed by the command line arguments
            that ``parse_args`` did not recognize.  Those starting with "--"
            are joined with spaces and passed on as extra options.
    """
    configs = parse_model_configs()

    # Validation always uses a single round and a single restart.
    if FLAGS.mode == "validate":
        FLAGS.round = 1
        FLAGS.restart_round = 1

    project_name = os.path.splitext(os.path.basename(FLAGS.config))[0]
    building = FLAGS.mode in ("build", "all")
    if building:
        # Remove previous output dirs
        if not os.path.exists(FLAGS.output_dir):
            os.makedirs(FLAGS.output_dir)
        # NOTE(review): the check looks for "libmace" but the directory that
        # gets reset is `project_name` -- confirm which one is intended.
        elif os.path.exists(os.path.join(FLAGS.output_dir, "libmace")):
            sh.rm("-rf", os.path.join(FLAGS.output_dir, project_name))
            os.makedirs(os.path.join(FLAGS.output_dir, project_name))

        # generate source
        sh_commands.gen_mace_version()
        sh_commands.gen_encrypted_opencl_source()

    option_args = ' '.join(
        [arg for arg in unused_args if arg.startswith('--')])

    target_socs = get_target_socs(configs)

    embed_model_data = configs.get("embed_model_data", 1)
    vlog_level = configs.get("vlog_level", 0)
    phone_data_dir = "/data/local/tmp/mace_run/"
    for target_soc in target_socs:
        for target_abi in configs["target_abis"]:
            if target_abi != 'host':
                serialnos = sh_commands.get_target_socs_serialnos([target_soc])
                for serialno in serialnos:
                    props = sh_commands.adb_getprop_by_serialno(serialno)
                    print(
                        "===================================================="
                    )
                    # A single formatted argument prints the same text
                    # whether print is a statement or a function.
                    print("Trying to lock device %s" % serialno)
                    with sh_commands.device_lock(serialno):
                        print("Run on device: %s, %s, %s" % (
                            serialno,
                            props.get("ro.board.platform", ""),
                            props.get("ro.product.model", "")))
                        process_models(project_name, configs, embed_model_data,
                                       vlog_level, target_soc, target_abi,
                                       serialno, phone_data_dir, option_args)
            else:
                print("====================================================")
                print("Run on host")
                process_models(project_name, configs, embed_model_data,
                               vlog_level, target_soc, target_abi, '',
                               phone_data_dir, option_args)

    if building:
        sh_commands.packaging_lib(FLAGS.output_dir, project_name)
Y
yejianwu 已提交
577

578

Y
yejianwu 已提交
579
if __name__ == "__main__":
    # FLAGS is a module-level global that the functions above read directly.
    FLAGS, unparsed = parse_args()
    main(unused_args=[sys.argv[0]] + unparsed)