converter.py 51.4 KB
Newer Older
L
Liangliang He 已提交
1
# Copyright 2018 The MACE Authors. All Rights Reserved.
Y
yejianwu 已提交
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15

import argparse
L
liuqi 已提交
16
import glob
L
Liangliang He 已提交
17
import sh
18
import sys
19
import time
Y
yejianwu 已提交
20
import yaml
L
liuqi 已提交
21

22
from enum import Enum
23
import six
24

25
import sh_commands
L
Liangliang He 已提交
26

L
liuqi 已提交
27 28
from common import *
from device import DeviceWrapper, DeviceManager
29

Y
yejianwu 已提交
30 31 32 33
################################
# set environment
################################
# Set for this process (and inherited by child processes). Presumably meant
# to quiet TensorFlow's C++ logging during conversion -- per TF docs level 2
# filters INFO and WARNING messages; confirm this is still desired.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

################################
# common definitions
################################

# ABIs accepted by the ``target_abis`` config entry / --target_abis flag.
ABITypeStrs = [
    'armeabi-v7a',
    'arm64-v8a',
    'arm64',
    'armhf',
    'host',
]

# Accepted values for ``model_graph_format`` and ``model_data_format``.
ModelFormatStrs = [
    "file",
    "code",
]

# Source frameworks accepted by a model's ``platform`` entry.
PlatformTypeStrs = [
    "tensorflow",
    "caffe",
    "onnx",
]
# Each member's name equals its value; ``type=str`` makes members compare
# equal to the plain strings read from YAML.
PlatformType = Enum('PlatformType',
                    list(zip(PlatformTypeStrs, PlatformTypeStrs)),
                    type=str)

# Values accepted by a model's ``runtime`` entry.
RuntimeTypeStrs = [
    "cpu",
    "gpu",
    "dsp",
    "hta",
    "apu",
    "cpu+gpu"
]

# Allowed element types for subgraph input/output tensors.
InOutDataTypeStrs = [
    "int32",
    "float32",
]

# NOTE: the Enum class name ('InputDataType') differs from the variable name;
# it is left as-is because it appears in the members' str()/repr().
InOutDataType = Enum('InputDataType',
                     list(zip(InOutDataTypeStrs, InOutDataTypeStrs)),
                     type=str)

# Floating-point data types accepted for non-dsp/apu runtimes.
FPDataTypeStrs = [
    "fp16_fp32",
    "fp32_fp32",
]

# NOTE: Enum class name 'GPUDataType' kept for the same reason as above.
FPDataType = Enum('GPUDataType',
                  list(zip(FPDataTypeStrs, FPDataTypeStrs)),
                  type=str)

# Data types accepted for the dsp runtime.
DSPDataTypeStrs = [
    "uint8",
]

DSPDataType = Enum('DSPDataType',
                   list(zip(DSPDataTypeStrs, DSPDataTypeStrs)),
                   type=str)

# Data types accepted for the apu runtime.
APUDataTypeStrs = [
    "uint8",
]

APUDataType = Enum('APUDataType',
                   list(zip(APUDataTypeStrs, APUDataTypeStrs)),
                   type=str)

# Allowed ``winograd`` settings; 0 disables winograd convolution.
WinogradParameters = [0, 2, 4]

# Accepted subgraph input/output data formats.
DataFormatStrs = [
    "NONE",
    "NHWC",
    "NCHW",
    "OIHW",
]


110
class DefaultValues(object):
    """Fallback values for optional library/runtime settings.

    The exact meaning of the numeric hints depends on the MACE runtime that
    consumes them -- see the runtime's option documentation.
    """
    # Values are plain scalars on purpose: a trailing comma after any of them
    # would silently turn it into a 1-tuple such as ``(-1,)``.
    mace_lib_type = MACELibType.static
    omp_num_threads = -1
    cpu_affinity_policy = 1
    gpu_perf_hint = 3
    gpu_priority_hint = 3


118 119 120
class ValidationThreshold(object):
    """Default similarity thresholds for validating a converted model.

    ``format_model_config`` seeds every subgraph's ``validation_threshold``
    table with these values (cpu, gpu, and the quantized/HEXAGON/HTA
    entries), and per-device values from the config file override them.
    """
    # Plain floats: with a trailing comma each value would be a 1-tuple,
    # inconsistent with the float overrides read from YAML.
    cpu_threshold = 0.999
    gpu_threshold = 0.995
    quantize_threshold = 0.980
122 123


124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
# Words a model name may not be, because model names end up as C++
# identifiers in generated code (checked in format_model_config).
CPP_KEYWORDS = [
    # C++ keywords and identifiers with special meaning.
    'alignas', 'alignof', 'and', 'and_eq', 'asm',
    'atomic_cancel', 'atomic_commit', 'atomic_noexcept', 'auto', 'bitand',
    'bitor', 'bool', 'break', 'case', 'catch',
    'char', 'char16_t', 'char32_t', 'class', 'compl',
    'concept', 'const', 'constexpr', 'const_cast', 'continue',
    'co_await', 'co_return', 'co_yield', 'decltype', 'default',
    'delete', 'do', 'double', 'dynamic_cast', 'else',
    'enum', 'explicit', 'export', 'extern', 'false',
    'float', 'for', 'friend', 'goto', 'if',
    'import', 'inline', 'int', 'long', 'module',
    'mutable', 'namespace', 'new', 'noexcept', 'not',
    'not_eq', 'nullptr', 'operator', 'or', 'or_eq',
    'private', 'protected', 'public', 'register', 'reinterpret_cast',
    'requires', 'return', 'short', 'signed', 'sizeof',
    'static', 'static_assert', 'static_cast', 'struct', 'switch',
    'synchronized', 'template', 'this', 'thread_local', 'throw',
    'true', 'try', 'typedef', 'typeid', 'typename',
    'union', 'unsigned', 'using', 'virtual', 'void',
    'volatile', 'wchar_t', 'while', 'xor', 'xor_eq',
    'override', 'final', 'transaction_safe', 'transaction_safe_dynamic',
] + [
    # Preprocessor directive names / operators (some repeat words above).
    'if', 'elif', 'else', 'endif', 'defined', 'ifdef', 'ifndef',
    'define', 'undef', 'include', 'line', 'error', 'pragma',
]
Y
yejianwu 已提交
144

145

146 147 148
################################
# common functions
################################
149
def parse_device_type(runtime):
    """Translate a model ``runtime`` value into its ``DeviceType``.

    Only the single-device runtimes (dsp, hta, gpu, cpu, apu) are mapped;
    anything else -- including "cpu+gpu" -- yields an empty string.
    """
    if runtime == RuntimeType.dsp:
        return DeviceType.HEXAGON
    if runtime == RuntimeType.hta:
        return DeviceType.HTA
    if runtime == RuntimeType.gpu:
        return DeviceType.GPU
    if runtime == RuntimeType.cpu:
        return DeviceType.CPU
    if runtime == RuntimeType.apu:
        return DeviceType.APU
    return ""
164

Y
yejianwu 已提交
165 166

def _model_runtimes(configs):
    """Return the lower-cased ``runtime`` of every model in ``configs``.

    A model without a ``runtime`` entry contributes an empty string.
    """
    models = configs[YAMLKeyword.models]
    return [models[model_name].get(YAMLKeyword.runtime, "").lower()
            for model_name in models]


def get_hexagon_mode(configs):
    """Return True if any model in ``configs`` uses the dsp runtime."""
    return RuntimeType.dsp in _model_runtimes(configs)


def get_hta_mode(configs):
    """Return True if any model in ``configs`` uses the hta runtime."""
    return RuntimeType.hta in _model_runtimes(configs)


def get_opencl_mode(configs):
    """Return True if any model uses the gpu or cpu+gpu runtime."""
    runtimes = _model_runtimes(configs)
    return RuntimeType.gpu in runtimes or RuntimeType.cpu_gpu in runtimes


205 206 207 208 209 210 211 212 213 214 215
def get_quantize_mode(configs):
    """Return True if at least one model sets ``quantize`` to 1.

    Models without a ``quantize`` entry count as not quantized.
    """
    models = configs[YAMLKeyword.models]
    return any(models[name].get(YAMLKeyword.quantize, 0) == 1
               for name in models)


216 217 218 219 220 221 222 223 224
def get_symbol_hidden_mode(debug_mode, mace_lib_type=None):
    """Decide whether library symbols should be hidden.

    Symbols are hidden unless a library type is given AND either debug mode
    is on or the library is built as ``MACELibType.dynamic``.
    """
    if not mace_lib_type:
        return True
    keep_visible = debug_mode or mace_lib_type == MACELibType.dynamic
    return not keep_visible


225 226
def md5sum(str):
    """Return the hex MD5 digest of the text ``str`` encoded as UTF-8."""
    # The parameter shadows the builtin ``str``; the name is kept so that
    # existing keyword callers keep working.
    return hashlib.md5(str.encode('utf-8')).hexdigest()
229

Y
yejianwu 已提交
230

231 232 233 234 235 236
def sha256_checksum(fname):
    """Return the hex SHA-256 digest of the file at ``fname``.

    The file is read in 4096-byte chunks rather than all at once.
    """
    digest = hashlib.sha256()
    with open(fname, "rb") as stream:
        while True:
            block = stream.read(4096)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
Y
yejianwu 已提交
237

W
wuchenghui 已提交
238

239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309
def download_file(url, dst, num_retries=3):
    """Download ``url`` into the local file ``dst``.

    Transient urllib failures (ContentTooShortError, HTTPError, URLError)
    are logged and retried, for at most ``num_retries + 1`` attempts in
    total. Returns True on success, False once the retries are used up.
    """
    from six.moves import urllib

    retries_left = num_retries
    while True:
        try:
            urllib.request.urlretrieve(url, dst)
            MaceLogger.info('\nDownloaded successfully.')
        except (urllib.error.ContentTooShortError, urllib.error.HTTPError,
                urllib.error.URLError) as e:
            MaceLogger.warning('Download error:' + str(e))
            if retries_left <= 0:
                return False
            retries_left -= 1
        else:
            return True


def get_model_files(model_config, model_output_dir):
    """Make one model's files available locally and verify their checksums.

    Each of ``model_file_path``, ``weight_file_path`` and
    ``quantize_range_file`` that is an http(s) URL is downloaded into
    ``model_output_dir`` under the MD5 of its URL. Model and weight
    downloads are skipped when a cached copy already has the expected
    SHA-256; the range file is always fetched. The resolved local paths are
    written back into ``model_config``. Download and checksum failures are
    reported through ``MaceLogger.error``.
    """
    url_schemes = ("http://", "https://")

    if not os.path.exists(model_output_dir):
        os.makedirs(model_output_dir)

    model_src = model_config[YAMLKeyword.model_file_path]
    model_digest = model_config[YAMLKeyword.model_sha256_checksum]
    weight_src = model_config.get(YAMLKeyword.weight_file_path, "")
    weight_digest = model_config.get(YAMLKeyword.weight_sha256_checksum, "")
    range_src = model_config.get(YAMLKeyword.quantize_range_file, "")

    # Model graph file.
    model_file = model_src
    if model_src.startswith(url_schemes):
        model_file = model_output_dir + "/" + md5sum(model_src) + ".pb"
        cache_hit = os.path.exists(model_file) and \
            sha256_checksum(model_file) == model_digest
        if not cache_hit:
            MaceLogger.info("Downloading model, please wait ...")
            if not download_file(model_src, model_file):
                MaceLogger.error(ModuleName.MODEL_CONVERTER,
                                 "Model download failed.")
        model_config[YAMLKeyword.model_file_path] = model_file

    if sha256_checksum(model_file) != model_digest:
        MaceLogger.error(ModuleName.MODEL_CONVERTER,
                         "model file sha256checksum not match")

    # Optional weight file.
    weight_file = weight_src
    if weight_src.startswith(url_schemes):
        weight_file = \
            model_output_dir + "/" + md5sum(weight_src) + ".caffemodel"
        cache_hit = os.path.exists(weight_file) and \
            sha256_checksum(weight_file) == weight_digest
        if not cache_hit:
            MaceLogger.info("Downloading model weight, please wait ...")
            if not download_file(weight_src, weight_file):
                MaceLogger.error(ModuleName.MODEL_CONVERTER,
                                 "Model download failed.")
    model_config[YAMLKeyword.weight_file_path] = weight_file

    if weight_file and sha256_checksum(weight_file) != weight_digest:
        MaceLogger.error(ModuleName.MODEL_CONVERTER,
                         "weight file sha256checksum not match")

    # Optional quantization range file (no checksum, always re-fetched).
    range_file = range_src
    if range_src.startswith(url_schemes):
        range_file = model_output_dir + "/" + md5sum(range_src) + ".range"
        if not download_file(range_src, range_file):
            MaceLogger.error(ModuleName.MODEL_CONVERTER,
                             "Model range file download failed.")
    model_config[YAMLKeyword.quantize_range_file] = range_file


310 311
def _as_list(value):
    """Return ``value`` itself if it is a list, otherwise ``[value]``."""
    return value if isinstance(value, list) else [value]


def _format_model_data_type(model_config, runtime):
    """Validate a model's ``data_type`` or fill in the runtime default.

    dsp and apu accept only their own type lists (default uint8). Every
    other runtime accepts FPDataTypeStrs and defaults to fp32_fp32 on cpu
    and fp16_fp32 otherwise.
    """
    if runtime == RuntimeType.dsp:
        allowed, default = DSPDataTypeStrs, DSPDataType.uint8.value
    elif runtime == RuntimeType.apu:
        allowed, default = APUDataTypeStrs, APUDataType.uint8.value
    elif runtime == RuntimeType.cpu:
        allowed, default = FPDataTypeStrs, FPDataType.fp32_fp32.value
    else:
        allowed, default = FPDataTypeStrs, FPDataType.fp16_fp32.value

    data_type = model_config.get(YAMLKeyword.data_type, "")
    if len(data_type) > 0:
        # Name the runtime actually configured (gpu, hta, ... included).
        mace_check(data_type in allowed,
                   ModuleName.YAML_CONFIG,
                   "'data_type' must be in " + str(allowed)
                   + " for %s runtime" % runtime)
    else:
        model_config[YAMLKeyword.data_type] = default


def _format_subgraph(subgraph):
    """Validate one subgraph entry and normalize it in place.

    Scalars are wrapped in lists, tensor names/shapes/ranges are turned into
    strings, and missing optional entries receive their defaults.

    Raises:
        argparse.ArgumentTypeError: if ``validation_threshold`` is not a
            dict or names an unsupported device.
    """
    # Mandatory tensor descriptions.
    for key in [YAMLKeyword.input_tensors,
                YAMLKeyword.input_shapes,
                YAMLKeyword.output_tensors,
                YAMLKeyword.output_shapes]:
        value = subgraph.get(key, "")
        mace_check(value != "", ModuleName.YAML_CONFIG,
                   "'%s' is necessary in subgraph" % key)
        subgraph[key] = [str(v) for v in _as_list(value)]

    input_size = len(subgraph[YAMLKeyword.input_tensors])
    output_size = len(subgraph[YAMLKeyword.output_tensors])
    mace_check(len(subgraph[YAMLKeyword.input_shapes]) == input_size,
               ModuleName.YAML_CONFIG,
               "input shapes' size not equal inputs' size.")
    mace_check(len(subgraph[YAMLKeyword.output_shapes]) == output_size,
               ModuleName.YAML_CONFIG,
               "output shapes' size not equal outputs' size.")

    # Optional check tensors/shapes default to empty lists.
    for key in [YAMLKeyword.check_tensors,
                YAMLKeyword.check_shapes]:
        value = subgraph.get(key, "")
        if value == "":
            subgraph[key] = []
        else:
            subgraph[key] = [str(v) for v in _as_list(value)]

    # Data types: a scalar is broadcast to every tensor; default float32.
    for key, count in [(YAMLKeyword.input_data_types, input_size),
                       (YAMLKeyword.output_data_types, output_size)]:
        data_types = subgraph.get(key, "")
        if not data_types:
            subgraph[key] = [InOutDataType.float32] * count
            continue
        if not isinstance(data_types, list):
            subgraph[key] = [data_types] * count
        for data_type in subgraph[key]:
            mace_check(data_type in InOutDataTypeStrs,
                       ModuleName.YAML_CONFIG,
                       key + " must be in "
                       + str(InOutDataTypeStrs))

    # Data formats: a scalar is broadcast; a list must match the tensor
    # count; default NHWC. "%s" formatting keeps the error message usable
    # even when the YAML value is not a string.
    input_data_formats = subgraph.get(YAMLKeyword.input_data_formats, [])
    if input_data_formats:
        if not isinstance(input_data_formats, list):
            subgraph[YAMLKeyword.input_data_formats] = \
                [input_data_formats] * input_size
        else:
            mace_check(len(input_data_formats) == input_size,
                       ModuleName.YAML_CONFIG,
                       "input_data_formats should match"
                       " the size of input.")
        for input_data_format in subgraph[YAMLKeyword.input_data_formats]:
            mace_check(input_data_format in DataFormatStrs,
                       ModuleName.YAML_CONFIG,
                       "'input_data_formats' must be in "
                       + str(DataFormatStrs)
                       + ", but got %s" % (input_data_format,))
    else:
        subgraph[YAMLKeyword.input_data_formats] = \
            [DataFormat.NHWC] * input_size

    output_data_formats = subgraph.get(YAMLKeyword.output_data_formats, [])
    if output_data_formats:
        if not isinstance(output_data_formats, list):
            subgraph[YAMLKeyword.output_data_formats] = \
                [output_data_formats] * output_size
        else:
            mace_check(len(output_data_formats) == output_size,
                       ModuleName.YAML_CONFIG,
                       "output_data_formats should match"
                       " the size of output")
        for output_data_format in subgraph[YAMLKeyword.output_data_formats]:
            mace_check(output_data_format in DataFormatStrs,
                       ModuleName.YAML_CONFIG,
                       "'output_data_formats' must be in "
                       + str(DataFormatStrs)
                       + ", but got %s" % (output_data_format,))
    else:
        subgraph[YAMLKeyword.output_data_formats] = \
            [DataFormat.NHWC] * output_size

    # Validation thresholds: start from the defaults and apply overrides;
    # 'DSP' is accepted as an alias for HEXAGON.
    validation_threshold = subgraph.get(YAMLKeyword.validation_threshold, {})
    if not isinstance(validation_threshold, dict):
        raise argparse.ArgumentTypeError(
            'similarity threshold must be a dict.')
    threshold_dict = {
        DeviceType.CPU: ValidationThreshold.cpu_threshold,
        DeviceType.GPU: ValidationThreshold.gpu_threshold,
        DeviceType.HEXAGON: ValidationThreshold.quantize_threshold,
        DeviceType.HTA: ValidationThreshold.quantize_threshold,
        DeviceType.QUANTIZE: ValidationThreshold.quantize_threshold,
    }
    supported_devices = (DeviceType.CPU,
                         DeviceType.GPU,
                         DeviceType.HEXAGON,
                         DeviceType.HTA,
                         DeviceType.QUANTIZE)
    for k, v in validation_threshold.items():
        if k.upper() == 'DSP':
            k = DeviceType.HEXAGON
        if k.upper() not in supported_devices:
            raise argparse.ArgumentTypeError(
                'Unsupported validation threshold runtime: %s' % k)
        threshold_dict[k.upper()] = v
    subgraph[YAMLKeyword.validation_threshold] = threshold_dict

    subgraph[YAMLKeyword.validation_inputs_data] = _as_list(
        subgraph.get(YAMLKeyword.validation_inputs_data, []))

    subgraph[YAMLKeyword.backend] = subgraph.get(
        YAMLKeyword.backend, "tensorflow")

    subgraph[YAMLKeyword.validation_outputs_data] = _as_list(
        subgraph.get(YAMLKeyword.validation_outputs_data, []))

    input_ranges = _as_list(subgraph.get(YAMLKeyword.input_ranges, []))
    subgraph[YAMLKeyword.input_ranges] = [str(v) for v in input_ranges]

    accuracy_validation_script = subgraph.get(
        YAMLKeyword.accuracy_validation_script, "")
    if isinstance(accuracy_validation_script, list):
        mace_check(len(accuracy_validation_script) == 1,
                   ModuleName.YAML_CONFIG,
                   "Only support one accuracy validation script")
        accuracy_validation_script = accuracy_validation_script[0]
    subgraph[YAMLKeyword.accuracy_validation_script] = \
        accuracy_validation_script


def format_model_config(flags):
    """Load, validate and normalize the YAML deployment config.

    Reads the file named by ``flags.config``. The command-line values
    ``flags.target_abis``, ``flags.target_socs``,
    ``flags.model_graph_format`` and ``flags.model_data_format`` override
    the matching config entries. Every model and subgraph entry is checked
    and normalized in place, and remote model files are fetched through
    ``get_model_files``. Validation failures are reported via
    ``mace_check``.

    Returns:
        The normalized configuration dict.

    Raises:
        argparse.ArgumentTypeError: for a malformed subgraph
            ``validation_threshold``.
    """
    with open(flags.config) as f:
        # The config is plain data. safe_load never constructs arbitrary
        # Python objects, and yaml.load() without an explicit Loader is
        # deprecated since PyYAML 5.1 and a TypeError since PyYAML 6.
        configs = yaml.safe_load(f)
    mace_check(isinstance(configs, dict), ModuleName.YAML_CONFIG,
               "config file should contain a YAML mapping")

    library_name = configs.get(YAMLKeyword.library_name, "")
    mace_check(len(library_name) > 0,
               ModuleName.YAML_CONFIG, "library name should not be empty")

    # Target ABIs: the command line overrides the config file.
    if flags.target_abis:
        target_abis = flags.target_abis.split(',')
    else:
        target_abis = configs.get(YAMLKeyword.target_abis, [])
    mace_check((isinstance(target_abis, list) and len(target_abis) > 0),
               ModuleName.YAML_CONFIG, "target_abis list is needed")
    configs[YAMLKeyword.target_abis] = target_abis
    for abi in target_abis:
        mace_check(abi in ABITypeStrs,
                   ModuleName.YAML_CONFIG,
                   "target_abis must be in " + str(ABITypeStrs))

    # Target SoCs: an explicit command-line list wins; the 'random' and
    # 'all' tags fall back to the config file. Always stored lower-cased.
    target_socs = configs.get(YAMLKeyword.target_socs, "")
    if flags.target_socs and flags.target_socs != TargetSOCTag.random \
            and flags.target_socs != TargetSOCTag.all:
        configs[YAMLKeyword.target_socs] = flags.target_socs.split(',')
    elif not target_socs:
        configs[YAMLKeyword.target_socs] = []
    elif not isinstance(target_socs, list):
        configs[YAMLKeyword.target_socs] = [target_socs]
    configs[YAMLKeyword.target_socs] = \
        [soc.lower() for soc in configs[YAMLKeyword.target_socs]]

    # Android builds need the requested SoCs to be attached via adb.
    if ABIType.armeabi_v7a in target_abis \
            or ABIType.arm64_v8a in target_abis:
        available_socs = sh_commands.adb_get_all_socs()
        target_socs = configs[YAMLKeyword.target_socs]
        if TargetSOCTag.all in target_socs:
            mace_check(available_socs,
                       ModuleName.YAML_CONFIG,
                       "Android abi is listed in config file and "
                       "build for all SOCs plugged in computer, "
                       "But no android phone found, "
                       "you at least plug in one phone")
        else:
            for soc in target_socs:
                mace_check(soc in available_socs,
                           ModuleName.YAML_CONFIG,
                           "Build specified SOC library, "
                           "you must plug in a phone using the SOC")

    model_graph_format = (flags.model_graph_format or
                          configs.get(YAMLKeyword.model_graph_format, ""))
    mace_check(model_graph_format in ModelFormatStrs,
               ModuleName.YAML_CONFIG,
               'You must set model_graph_format and '
               "model_graph_format must be in " + str(ModelFormatStrs))
    configs[YAMLKeyword.model_graph_format] = model_graph_format

    model_data_format = (flags.model_data_format or
                         configs.get(YAMLKeyword.model_data_format, ""))
    configs[YAMLKeyword.model_data_format] = model_data_format
    mace_check(model_data_format in ModelFormatStrs,
               ModuleName.YAML_CONFIG,
               'You must set model_data_format and '
               "model_data_format must be in " + str(ModelFormatStrs))

    mace_check(not (model_graph_format == ModelFormat.file
                    and model_data_format == ModelFormat.code),
               ModuleName.YAML_CONFIG,
               "If model_graph format is 'file',"
               " the model_data_format must be 'file' too")

    model_names = configs.get(YAMLKeyword.models, [])
    mace_check(len(model_names) > 0, ModuleName.YAML_CONFIG,
               "no model found in config file")

    model_name_reg = re.compile(r'^[a-zA-Z0-9_]+$')
    for model_name in model_names:
        # Model names are used as C++ identifiers in generated code.
        mace_check(model_name not in CPP_KEYWORDS,
                   ModuleName.YAML_CONFIG,
                   "model name should not be c++ keyword.")
        mace_check((model_name[0] == '_' or model_name[0].isalpha())
                   and bool(model_name_reg.match(model_name)),
                   ModuleName.YAML_CONFIG,
                   "model name should Meet the c++ naming convention"
                   " which start with '_' or alpha"
                   " and only contain alpha, number and '_'")

        model_config = configs[YAMLKeyword.models][model_name]
        platform = model_config.get(YAMLKeyword.platform, "")
        mace_check(platform in PlatformTypeStrs,
                   ModuleName.YAML_CONFIG,
                   "'platform' must be in " + str(PlatformTypeStrs))

        for key in [YAMLKeyword.model_file_path,
                    YAMLKeyword.model_sha256_checksum]:
            value = model_config.get(key, "")
            mace_check(value != "", ModuleName.YAML_CONFIG,
                       "'%s' is necessary" % key)

        # A weight file requires its checksum; otherwise store "".
        weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "")
        if weight_file_path:
            weight_checksum = \
                model_config.get(YAMLKeyword.weight_sha256_checksum, "")
            mace_check(weight_checksum != "", ModuleName.YAML_CONFIG,
                       "'%s' is necessary" %
                       YAMLKeyword.weight_sha256_checksum)
        else:
            model_config[YAMLKeyword.weight_sha256_checksum] = ""

        get_model_files(model_config, BUILD_DOWNLOADS_DIR)

        runtime = model_config.get(YAMLKeyword.runtime, "")
        mace_check(runtime in RuntimeTypeStrs,
                   ModuleName.YAML_CONFIG,
                   "'runtime' must be in " + str(RuntimeTypeStrs))
        if ABIType.host in target_abis:
            mace_check(runtime == RuntimeType.cpu,
                       ModuleName.YAML_CONFIG,
                       "host only support cpu runtime now.")

        _format_model_data_type(model_config, runtime)

        subgraphs = model_config.get(YAMLKeyword.subgraphs, "")
        mace_check(len(subgraphs) > 0, ModuleName.YAML_CONFIG,
                   "at least one subgraph is needed")
        for subgraph in subgraphs:
            _format_subgraph(subgraph)

        # Optional integer switches default to 0 when absent or empty.
        for key in [YAMLKeyword.limit_opencl_kernel_time,
                    YAMLKeyword.nnlib_graph_mode,
                    YAMLKeyword.obfuscate,
                    YAMLKeyword.winograd,
                    YAMLKeyword.quantize,
                    YAMLKeyword.quantize_large_weights,
                    YAMLKeyword.change_concat_ranges]:
            if model_config.get(key, "") == "":
                model_config[key] = 0

        mace_check(model_config[YAMLKeyword.quantize] == 0 or
                   model_config[YAMLKeyword.quantize_large_weights] == 0,
                   ModuleName.YAML_CONFIG,
                   "quantize and quantize_large_weights should not be set to 1"
                   " at the same time.")

        mace_check(model_config[YAMLKeyword.winograd] in WinogradParameters,
                   ModuleName.YAML_CONFIG,
                   "'winograd' parameters must be in "
                   + str(WinogradParameters) +
                   ". 0 for disable winograd convolution")

    return configs
Y
yejianwu 已提交
657

W
wuchenghui 已提交
658

659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679
def clear_build_dirs(library_name):
    """Prepare the build output tree for ``library_name``.

    Ensures BUILD_OUTPUT_DIR exists, recreates an empty temporary build
    directory for the library, and deletes (without recreating) the
    library's output-library directory.
    """
    if not os.path.exists(BUILD_OUTPUT_DIR):
        os.makedirs(BUILD_OUTPUT_DIR)

    library_dir = os.path.join(BUILD_OUTPUT_DIR, library_name)

    # Wipe and recreate the temporary build directory.
    tmp_build_dir = os.path.join(library_dir, BUILD_TMP_DIR_NAME)
    if os.path.exists(tmp_build_dir):
        sh.rm('-rf', tmp_build_dir)
    os.makedirs(tmp_build_dir)

    # Only remove the library output directory; it is not recreated here.
    lib_output_dir = os.path.join(library_dir, OUTPUT_LIBRARY_DIR_NAME)
    if os.path.exists(lib_output_dir):
        sh.rm('-rf', lib_output_dir)


################################
# convert
################################
def print_configuration(configs):
    """Log a "Common Configuration" key/value table of library-wide settings."""
    shown_keys = (YAMLKeyword.library_name,
                  YAMLKeyword.target_abis,
                  YAMLKeyword.target_socs,
                  YAMLKeyword.model_graph_format,
                  YAMLKeyword.model_data_format)
    rows = [[key, configs[key]] for key in shown_keys]
    table = StringFormatter.table(["key", "value"], rows,
                                  "Common Configuration")
    MaceLogger.summary(table)
L
Liangliang He 已提交
694

Y
yejianwu 已提交
695

696
def convert_model(configs, cl_mem_type):
    """Convert every model listed in ``configs`` into MACE format.

    Steps:
    1. Recreate the library's build output directory and clear the model
       and engine codegen directories.
    2. Run ``sh_commands.gen_model_code`` for each model.
    3. Depending on the configured graph/data formats, move the generated
       ``.pb``/``.data`` files into the model output directory and copy the
       generated headers into the model header directory.

    Args:
        configs: Formatted configuration dict (see ``format_model_config``).
        cl_mem_type: OpenCL memory type to store on each model config.
            ``"image"`` is used when it is empty or None.

    Note:
        Each model config inside ``configs`` is mutated: its
        ``YAMLKeyword.cl_mem_type`` entry is overwritten.
    """
    # Remove previous output dirs
    library_name = configs[YAMLKeyword.library_name]
    if not os.path.exists(BUILD_OUTPUT_DIR):
        os.makedirs(BUILD_OUTPUT_DIR)
    elif os.path.exists(os.path.join(BUILD_OUTPUT_DIR, library_name)):
        sh.rm("-rf", os.path.join(BUILD_OUTPUT_DIR, library_name))
    os.makedirs(os.path.join(BUILD_OUTPUT_DIR, library_name))
    if not os.path.exists(BUILD_DOWNLOADS_DIR):
        os.makedirs(BUILD_DOWNLOADS_DIR)

    model_output_dir = \
        '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_OUTPUT_DIR_NAME)
    model_header_dir = \
        '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_HEADER_DIR_PATH)
    # clear output dir
    if os.path.exists(model_output_dir):
        sh.rm("-rf", model_output_dir)
    os.makedirs(model_output_dir)
    if os.path.exists(model_header_dir):
        sh.rm("-rf", model_header_dir)

    embed_model_data = \
        configs[YAMLKeyword.model_data_format] == ModelFormat.code

    if os.path.exists(MODEL_CODEGEN_DIR):
        sh.rm("-rf", MODEL_CODEGEN_DIR)
    if os.path.exists(ENGINE_CODEGEN_DIR):
        sh.rm("-rf", ENGINE_CODEGEN_DIR)

    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        # The header dir is only (re)created when the graph is emitted
        # as code.
        os.makedirs(model_header_dir)
        sh_commands.gen_mace_engine_factory_source(
            configs[YAMLKeyword.models].keys(),
            embed_model_data)
        sh.cp("-f", glob.glob("mace/codegen/engine/*.h"),
              model_header_dir)

    for model_name in configs[YAMLKeyword.models]:
        MaceLogger.header(
            StringFormatter.block("Convert %s model" % model_name))
        model_config = configs[YAMLKeyword.models][model_name]
        runtime = model_config[YAMLKeyword.runtime]
        model_config[YAMLKeyword.cl_mem_type] = cl_mem_type or "image"

        data_type = model_config[YAMLKeyword.data_type]
        # TODO(liuqi): support multiple subgraphs
        subgraph = model_config[YAMLKeyword.subgraphs][0]

        model_codegen_dir = "%s/%s" % (MODEL_CODEGEN_DIR, model_name)
        sh_commands.gen_model_code(
            model_codegen_dir,
            model_config[YAMLKeyword.platform],
            model_config[YAMLKeyword.model_file_path],
            model_config[YAMLKeyword.weight_file_path],
            model_config[YAMLKeyword.model_sha256_checksum],
            model_config[YAMLKeyword.weight_sha256_checksum],
            ",".join(subgraph[YAMLKeyword.input_tensors]),
            ",".join(subgraph[YAMLKeyword.input_data_types]),
            ",".join(subgraph[YAMLKeyword.input_data_formats]),
            ",".join(subgraph[YAMLKeyword.output_tensors]),
            ",".join(subgraph[YAMLKeyword.output_data_types]),
            ",".join(subgraph[YAMLKeyword.output_data_formats]),
            ",".join(subgraph[YAMLKeyword.check_tensors]),
            runtime,
            model_name,
            ":".join(subgraph[YAMLKeyword.input_shapes]),
            ":".join(subgraph[YAMLKeyword.input_ranges]),
            ":".join(subgraph[YAMLKeyword.output_shapes]),
            ":".join(subgraph[YAMLKeyword.check_shapes]),
            model_config[YAMLKeyword.nnlib_graph_mode],
            embed_model_data,
            model_config[YAMLKeyword.winograd],
            model_config[YAMLKeyword.quantize],
            model_config[YAMLKeyword.quantize_large_weights],
            model_config[YAMLKeyword.quantize_range_file],
            model_config[YAMLKeyword.change_concat_ranges],
            model_config[YAMLKeyword.obfuscate],
            configs[YAMLKeyword.model_graph_format],
            data_type,
            model_config[YAMLKeyword.cl_mem_type],
            ",".join(model_config.get(YAMLKeyword.graph_optimize_options, [])))

        if configs[YAMLKeyword.model_graph_format] == ModelFormat.file:
            # Graph and weights both ship as files next to each other.
            sh.mv("-f",
                  '%s/%s.pb' % (model_codegen_dir, model_name),
                  model_output_dir)
            sh.mv("-f",
                  '%s/%s.data' % (model_codegen_dir, model_name),
                  model_output_dir)
        else:
            # Graph is compiled in. Weights are still a file unless they
            # are embedded as code too.
            if not embed_model_data:
                sh.mv("-f",
                      '%s/%s.data' % (model_codegen_dir, model_name),
                      model_output_dir)
            sh.cp("-f", glob.glob("mace/codegen/models/*/*.h"),
                  model_header_dir)

        MaceLogger.summary(
            StringFormatter.block("Model %s converted" % model_name))


801
def build_model_lib(configs, address_sanitizer, debug_mode):
    """Build the model library with bazel for every configured target ABI.

    For each ABI, the built ``MODEL_LIB_PATH`` is copied to the location
    returned by ``get_model_lib_output_path``. That location's parent
    directory is created first if needed.

    Args:
        configs: Formatted configuration dict.
        address_sanitizer: Forwarded to ``sh_commands.bazel_build``.
        debug_mode: Forwarded to ``bazel_build`` and used to derive
            ``symbol_hidden``.
    """
    MaceLogger.header(StringFormatter.block("Building model library"))

    # create model library dir
    library_name = configs[YAMLKeyword.library_name]
    for target_abi in configs[YAMLKeyword.target_abis]:
        model_lib_output_path = get_model_lib_output_path(library_name,
                                                          target_abi)
        library_out_dir = os.path.dirname(model_lib_output_path)
        if not os.path.exists(library_out_dir):
            os.makedirs(library_out_dir)
        toolchain = infer_toolchain(target_abi)
        sh_commands.bazel_build(
            MODEL_LIB_TARGET,
            abi=target_abi,
            toolchain=toolchain,
            enable_hexagon=get_hexagon_mode(configs),
            enable_hta=get_hta_mode(configs),
            enable_opencl=get_opencl_mode(configs),
            enable_quantize=get_quantize_mode(configs),
            address_sanitizer=address_sanitizer,
            symbol_hidden=get_symbol_hidden_mode(debug_mode),
            debug_mode=debug_mode
        )

        sh.cp("-f", MODEL_LIB_PATH, model_lib_output_path)
827 828 829 830 831 832 833


def print_library_summary(configs):
    """Log where the converted model (and, for code graphs, headers) live.

    Args:
        configs: Formatted configuration dict. The library name and model
            graph format are read from it.
    """
    library_name = configs[YAMLKeyword.library_name]
    title = "Library"
    header = ["key", "value"]
    data = list()
    data.append(["MACE Model Path",
                 "%s/%s/%s"
                 % (BUILD_OUTPUT_DIR, library_name, MODEL_OUTPUT_DIR_NAME)])
    # Headers are only produced when the graph is emitted as code.
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        data.append(["MACE Model Header Path",
                     "%s/%s/%s"
                     % (BUILD_OUTPUT_DIR, library_name,
                        MODEL_HEADER_DIR_PATH)])

    MaceLogger.summary(StringFormatter.table(header, data, title))


846
def convert_func(flags):
    """Handler for the ``convert`` sub-command.

    Loads the configuration, converts all models, builds the model library
    when the graph format is ``code``, and prints a summary.

    Args:
        flags: Parsed ``convert`` sub-command arguments.
    """
    configs = format_model_config(flags)

    print_configuration(configs)

    convert_model(configs, flags.cl_mem_type)

    # A compiled-in graph needs the model library to be built as well.
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        build_model_lib(configs, flags.address_sanitizer, flags.debug_mode)

    print_library_summary(configs)


################################
# run
################################
L
liuqi 已提交
862
def build_mace_run(configs, target_abi, toolchain, enable_openmp,
                   address_sanitizer, mace_lib_type, debug_mode):
    """Build the ``mace_run`` tool for ``target_abi``.

    The per-ABI temporary binary directory is recreated empty. Then either
    the static or the dynamic ``mace_run`` target is built, and
    ``sh_commands.update_mace_run_binary`` is called for that directory.

    Args:
        configs: Formatted configuration dict.
        target_abi: ABI to build for.
        toolchain: Toolchain returned by ``infer_toolchain``.
        enable_openmp: Forwarded to ``bazel_build``.
        address_sanitizer: Forwarded to ``bazel_build``.
        mace_lib_type: ``MACELibType``; selects the static or dynamic target.
        debug_mode: Forwarded to ``bazel_build``.

    Raises:
        Whatever ``mace_check`` raises when the graph format is ``code`` and
        the engine codegen directory is missing (model not converted yet).
    """
    library_name = configs[YAMLKeyword.library_name]

    build_tmp_binary_dir = get_build_binary_dir(library_name, target_abi)
    if os.path.exists(build_tmp_binary_dir):
        sh.rm("-rf", build_tmp_binary_dir)
    os.makedirs(build_tmp_binary_dir)

    mace_run_target = MACE_RUN_STATIC_TARGET
    if mace_lib_type == MACELibType.dynamic:
        mace_run_target = MACE_RUN_DYNAMIC_TARGET
    build_arg = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                   ModuleName.RUN,
                   "You should convert model first.")
        build_arg = "--per_file_copt=mace/tools/validation/mace_run.cc@-DMODEL_GRAPH_FORMAT_CODE"  # noqa

    sh_commands.bazel_build(
        mace_run_target,
        abi=target_abi,
        toolchain=toolchain,
        enable_hexagon=get_hexagon_mode(configs),
        enable_hta=get_hta_mode(configs),
        enable_openmp=enable_openmp,
        enable_opencl=get_opencl_mode(configs),
        enable_quantize=get_quantize_mode(configs),
        address_sanitizer=address_sanitizer,
        symbol_hidden=get_symbol_hidden_mode(debug_mode, mace_lib_type),
        debug_mode=debug_mode,
        extra_args=build_arg
    )
    sh_commands.update_mace_run_binary(build_tmp_binary_dir,
                                       mace_lib_type == MACELibType.dynamic)


899 900
def build_example(configs, target_abi, toolchain, enable_openmp, mace_lib_type,
                  cl_binary_to_code, device, debug_mode,
                  address_sanitizer=None):
    """Build the CLI example binary for ``target_abi``.

    Steps:
    1. Recreate the per-ABI temporary binary directory.
    2. Generate the OpenCL binary/parameter ``.cc`` sources. The per-device
       OpenCL outputs are used when ``cl_binary_to_code`` is set; otherwise
       empty paths are passed.
    3. Build libmace (static or dynamic).
    4. Stage the libraries into ``LIB_CODEGEN_DIR``.
    5. Build the example target and copy its binary into the temporary
       binary dir.
    6. Remove ``LIB_CODEGEN_DIR`` again.

    Args:
        configs: Formatted configuration dict.
        target_abi: ABI to build for.
        toolchain: Toolchain returned by ``infer_toolchain``.
        enable_openmp: Forwarded to ``bazel_build``.
        mace_lib_type: ``MACELibType``; selects the static or dynamic targets.
        cl_binary_to_code: Whether to embed the device's OpenCL outputs.
        device: Device passed to the OpenCL output path helpers.
        debug_mode: Forwarded to ``bazel_build``.
        address_sanitizer: Whether to build with address sanitizer. When
            ``None`` (the default, which existing callers rely on), the value
            is read from the module-level ``flags`` if it exists, else False.
    """
    if address_sanitizer is None:
        # Backward compatibility: this function used to read the global
        # ``flags`` directly. That global is only bound when the file runs
        # as a script, so look it up without raising NameError.
        address_sanitizer = getattr(globals().get("flags"),
                                    "address_sanitizer", False)

    library_name = configs[YAMLKeyword.library_name]

    build_tmp_binary_dir = get_build_binary_dir(library_name, target_abi)
    if os.path.exists(build_tmp_binary_dir):
        sh.rm("-rf", build_tmp_binary_dir)
    os.makedirs(build_tmp_binary_dir)

    if cl_binary_to_code:
        sh_commands.gen_opencl_binary_cpps(
            get_opencl_binary_output_path(
                library_name, target_abi, device),
            get_opencl_parameter_output_path(
                library_name, target_abi, device),
            OPENCL_CODEGEN_DIR + '/opencl_binary.cc',
            OPENCL_CODEGEN_DIR + '/opencl_parameter.cc')
    else:
        sh_commands.gen_opencl_binary_cpps(
            "", "",
            OPENCL_CODEGEN_DIR + '/opencl_binary.cc',
            OPENCL_CODEGEN_DIR + '/opencl_parameter.cc')

    libmace_target = LIBMACE_STATIC_TARGET
    if mace_lib_type == MACELibType.dynamic:
        libmace_target = LIBMACE_SO_TARGET

    sh_commands.bazel_build(libmace_target,
                            abi=target_abi,
                            toolchain=toolchain,
                            enable_openmp=enable_openmp,
                            enable_opencl=get_opencl_mode(configs),
                            enable_quantize=get_quantize_mode(configs),
                            enable_hexagon=get_hexagon_mode(configs),
                            enable_hta=get_hta_mode(configs),
                            address_sanitizer=address_sanitizer,
                            symbol_hidden=get_symbol_hidden_mode(debug_mode, mace_lib_type),  # noqa
                            debug_mode=debug_mode)

    if os.path.exists(LIB_CODEGEN_DIR):
        sh.rm("-rf", LIB_CODEGEN_DIR)
    sh.mkdir("-p", LIB_CODEGEN_DIR)

    build_arg = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                   ModuleName.RUN,
                   "You should convert model first.")
        model_lib_path = get_model_lib_output_path(library_name,
                                                   target_abi)
        sh.cp("-f", model_lib_path, LIB_CODEGEN_DIR)
        build_arg = "--per_file_copt=examples/cli/example.cc@-DMODEL_GRAPH_FORMAT_CODE"  # noqa

    if mace_lib_type == MACELibType.dynamic:
        example_target = EXAMPLE_DYNAMIC_TARGET
        sh.cp("-f", LIBMACE_DYNAMIC_PATH, LIB_CODEGEN_DIR)
    else:
        example_target = EXAMPLE_STATIC_TARGET
        sh.cp("-f", LIBMACE_STATIC_PATH, LIB_CODEGEN_DIR)

    sh_commands.bazel_build(example_target,
                            abi=target_abi,
                            toolchain=toolchain,
                            enable_openmp=enable_openmp,
                            enable_opencl=get_opencl_mode(configs),
                            enable_quantize=get_quantize_mode(configs),
                            enable_hexagon=get_hexagon_mode(configs),
                            enable_hta=get_hta_mode(configs),
                            address_sanitizer=address_sanitizer,
                            debug_mode=debug_mode,
                            extra_args=build_arg)

    target_bin = "/".join(sh_commands.bazel_target_to_bin(example_target))
    sh.cp("-f", target_bin, build_tmp_binary_dir)
    if os.path.exists(LIB_CODEGEN_DIR):
        sh.rm("-rf", LIB_CODEGEN_DIR)


L
liuqi 已提交
977 978 979 980 981 982 983 984 985 986
def print_package_summary(package_path):
    """Log a single-row "Library" table that points at ``package_path``."""
    rows = [["MACE Model package Path", package_path]]
    MaceLogger.summary(
        StringFormatter.table(["key", "value"], rows, "Library"))


987
def run_mace(flags):
    """Handler for the ``run`` sub-command.

    For each configured target ABI:
    - Build the device list, filtered by the configured ``target_socs``
      unless it contains the "all" tag. A random device is chosen when
      ``--target_socs`` equals the random tag.
    - For every device that supports the ABI, build either the example
      binary (``--example``) or ``mace_run``, then run it under the device
      lock.

    Afterwards, the build outputs are packaged and the package path is
    printed.

    Args:
        flags: Parsed ``run`` sub-command arguments.
    """
    configs = format_model_config(flags)

    clear_build_dirs(configs[YAMLKeyword.library_name])

    target_socs = configs[YAMLKeyword.target_socs]
    device_list = DeviceManager.list_devices(flags.device_yml)
    if target_socs and TargetSOCTag.all not in target_socs:
        device_list = [dev for dev in device_list
                       if dev[YAMLKeyword.target_socs].lower() in target_socs]
    for target_abi in configs[YAMLKeyword.target_abis]:
        if flags.target_socs == TargetSOCTag.random:
            target_devices = sh_commands.choose_a_random_device(
                device_list, target_abi)
        else:
            target_devices = device_list
        # build target
        for dev in target_devices:
            if target_abi in dev[YAMLKeyword.target_abis]:
                # get toolchain
                toolchain = infer_toolchain(target_abi)
                device = DeviceWrapper(dev)
                if flags.example:
                    build_example(configs,
                                  target_abi,
                                  toolchain,
                                  flags.enable_openmp,
                                  flags.mace_lib_type,
                                  flags.cl_binary_to_code,
                                  device,
                                  flags.debug_mode)
                else:
                    build_mace_run(configs,
                                   target_abi,
                                   toolchain,
                                   flags.enable_openmp,
                                   flags.address_sanitizer,
                                   flags.mace_lib_type,
                                   flags.debug_mode)
                # run
                start_time = time.time()
                with device.lock():
                    device.run_specify_abi(flags, configs, target_abi)
                elapse_minutes = (time.time() - start_time) / 60
                print("Elapse time: %f minutes." % elapse_minutes)
            elif dev[YAMLKeyword.device_name] != SystemType.host:
                # The host device only supports its own ABI; skip warning.
                six.print_('The device with soc %s do not support abi %s' %
                           (dev[YAMLKeyword.target_socs], target_abi),
                           file=sys.stderr)

    # package the output files
    package_path = sh_commands.packaging_lib(BUILD_OUTPUT_DIR,
                                             configs[YAMLKeyword.library_name])
    print_package_summary(package_path)

1042 1043 1044 1045

################################
#  benchmark model
################################
L
liuqi 已提交
1046 1047 1048 1049
def build_benchmark_model(configs,
                          target_abi,
                          toolchain,
                          enable_openmp,
                          mace_lib_type,
                          debug_mode):
    """Build the ``benchmark_model`` tool for ``target_abi``.

    Builds the static or dynamic benchmark target, recreates the per-ABI
    temporary binary directory, and copies the built binary into it.

    Args:
        configs: Formatted configuration dict.
        target_abi: ABI to build for.
        toolchain: Toolchain returned by ``infer_toolchain``.
        enable_openmp: Forwarded to ``bazel_build``.
        mace_lib_type: ``MACELibType``; selects the static or dynamic target.
        debug_mode: Forwarded to ``bazel_build``.
    """
    library_name = configs[YAMLKeyword.library_name]

    link_dynamic = mace_lib_type == MACELibType.dynamic
    if link_dynamic:
        benchmark_target = BM_MODEL_DYNAMIC_TARGET
    else:
        benchmark_target = BM_MODEL_STATIC_TARGET

    build_arg = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                   ModuleName.BENCHMARK,
                   "You should convert model first.")
        build_arg = "--per_file_copt=mace/tools/benchmark/benchmark_model.cc@-DMODEL_GRAPH_FORMAT_CODE"  # noqa

    sh_commands.bazel_build(benchmark_target,
                            abi=target_abi,
                            toolchain=toolchain,
                            enable_openmp=enable_openmp,
                            enable_opencl=get_opencl_mode(configs),
                            enable_quantize=get_quantize_mode(configs),
                            enable_hexagon=get_hexagon_mode(configs),
                            enable_hta=get_hta_mode(configs),
                            symbol_hidden=get_symbol_hidden_mode(debug_mode, mace_lib_type),  # noqa
                            debug_mode=debug_mode,
                            extra_args=build_arg)
    # clear tmp binary dir
    build_tmp_binary_dir = get_build_binary_dir(library_name, target_abi)
    if os.path.exists(build_tmp_binary_dir):
        sh.rm("-rf", build_tmp_binary_dir)
    os.makedirs(build_tmp_binary_dir)

    target_bin = "/".join(sh_commands.bazel_target_to_bin(benchmark_target))
    sh.cp("-f", target_bin, build_tmp_binary_dir)


1088
def benchmark_model(flags):
    """Handler for the ``benchmark`` sub-command.

    Device selection works as in ``run_mace``. For every device supporting
    a target ABI, ``benchmark_model`` is built and then run under the
    device lock.

    Args:
        flags: Parsed ``benchmark`` sub-command arguments.
    """
    configs = format_model_config(flags)

    clear_build_dirs(configs[YAMLKeyword.library_name])

    target_socs = configs[YAMLKeyword.target_socs]
    device_list = DeviceManager.list_devices(flags.device_yml)
    if target_socs and TargetSOCTag.all not in target_socs:
        device_list = [dev for dev in device_list
                       if dev[YAMLKeyword.target_socs].lower() in target_socs]
    for target_abi in configs[YAMLKeyword.target_abis]:
        if flags.target_socs == TargetSOCTag.random:
            target_devices = sh_commands.choose_a_random_device(
                device_list, target_abi)
        else:
            target_devices = device_list
        # build benchmark_model binary
        for dev in target_devices:
            if target_abi in dev[YAMLKeyword.target_abis]:
                toolchain = infer_toolchain(target_abi)
                build_benchmark_model(configs,
                                      target_abi,
                                      toolchain,
                                      flags.enable_openmp,
                                      flags.mace_lib_type,
                                      flags.debug_mode)
                device = DeviceWrapper(dev)
                start_time = time.time()
                with device.lock():
                    device.bm_specific_target(flags, configs, target_abi)
                elapse_minutes = (time.time() - start_time) / 60
                print("Elapse time: %f minutes." % elapse_minutes)
            elif dev[YAMLKeyword.device_name] != SystemType.host:
                # Like run_mace: the host device never supports the device
                # ABIs, so do not emit a spurious warning for it.
                six.print_('There is no abi %s with soc %s' %
                           (target_abi, dev[YAMLKeyword.target_socs]),
                           file=sys.stderr)
L
liuqi 已提交
1124

1125

L
liuqi 已提交
1126
################################
Y
yejianwu 已提交
1127
# parsing arguments
L
liuqi 已提交
1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139
################################
def str2bool(v):
    """Convert a yes/no style string to ``bool`` (argparse ``type=`` helper).

    Matching is case-insensitive. Accepted true values: yes, true, t, y, 1.
    Accepted false values: no, false, f, n, 0.

    Raises:
        argparse.ArgumentTypeError: For any other string.
    """
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def str_to_caffe_env_type(v):
    """Argparse ``type=`` converter for ``--caffe_env``.

    Args:
        v: ``'docker'`` or ``'local'`` (case-insensitive).

    Returns:
        ``CaffeEnvType.DOCKER`` or ``CaffeEnvType.LOCAL``.

    Raises:
        argparse.ArgumentTypeError: For any other value.
    """
    if v.lower() == 'docker':
        return CaffeEnvType.DOCKER
    elif v.lower() == 'local':
        return CaffeEnvType.LOCAL
    else:
        raise argparse.ArgumentTypeError('[docker | local] expected.')


1147 1148 1149 1150 1151 1152 1153 1154 1155
def str_to_mace_lib_type(v):
    """Argparse ``type=`` converter for ``--mace_lib_type``.

    Maps ``'dynamic'`` / ``'static'`` (case-insensitive) to the matching
    ``MACELibType`` member.

    Raises:
        argparse.ArgumentTypeError: For any other value.
    """
    lib_types = {
        'dynamic': MACELibType.dynamic,
        'static': MACELibType.static,
    }
    lib_type = lib_types.get(v.lower())
    if lib_type is None:
        raise argparse.ArgumentTypeError('[dynamic| static] expected.')
    return lib_type


1156
def parse_args():
    """Parses command line arguments.

    Three sub-commands are registered: ``convert``, ``run`` and
    ``benchmark``. Each binds its handler via ``set_defaults(func=...)``.
    Options shared between sub-commands are defined on helper parsers
    (``add_help=False``) that are mixed in through ``parents=[...]``.
    A sub-command is required, so a missing one yields an argparse usage
    error instead of an AttributeError on ``flags.func``.

    Returns:
        The ``(namespace, unparsed_args)`` tuple from
        ``ArgumentParser.parse_known_args``.
    """
    all_type_parent_parser = argparse.ArgumentParser(add_help=False)
    all_type_parent_parser.add_argument(
        '--config',
        type=str,
        default="",
        required=True,
        help="the path of model yaml configuration file.")
    all_type_parent_parser.add_argument(
        "--model_graph_format",
        type=str,
        default="",
        help="[file, code], MACE Model graph format.")
    all_type_parent_parser.add_argument(
        "--model_data_format",
        type=str,
        default="",
        help="['file', 'code'], MACE Model data format.")
    all_type_parent_parser.add_argument(
        "--target_abis",
        type=str,
        default="",
        help="Target ABIs, comma separated list.")
    all_type_parent_parser.add_argument(
        "--target_socs",
        type=str,
        default="",
        help="Target SOCs, comma separated list.")
    all_type_parent_parser.add_argument(
        "--debug_mode",
        action="store_true",
        help="Reserve debug symbols.")
    convert_run_parent_parser = argparse.ArgumentParser(add_help=False)
    convert_run_parent_parser.add_argument(
        '--address_sanitizer',
        action="store_true",
        help="Whether to use address sanitizer to check memory error")
    run_bm_parent_parser = argparse.ArgumentParser(add_help=False)
    run_bm_parent_parser.add_argument(
        "--mace_lib_type",
        type=str_to_mace_lib_type,
        default=DefaultValues.mace_lib_type,
        help="[static | dynamic], Which type MACE library to use.")
    run_bm_parent_parser.add_argument(
        "--enable_openmp",
        action="store_true",
        help="Enable openmp for multiple thread.")
    run_bm_parent_parser.add_argument(
        "--omp_num_threads",
        type=int,
        default=DefaultValues.omp_num_threads,
        help="num of openmp threads")
    run_bm_parent_parser.add_argument(
        "--cpu_affinity_policy",
        type=int,
        default=DefaultValues.cpu_affinity_policy,
        help="0:AFFINITY_NONE/1:AFFINITY_BIG_ONLY/2:AFFINITY_LITTLE_ONLY")
    run_bm_parent_parser.add_argument(
        "--gpu_perf_hint",
        type=int,
        default=DefaultValues.gpu_perf_hint,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    run_bm_parent_parser.add_argument(
        "--gpu_priority_hint",
        type=int,
        default=DefaultValues.gpu_priority_hint,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    run_bm_parent_parser.add_argument(
        "--device_yml",
        type=str,
        default='',
        help='embedded linux device config yml file'
    )
    parser = argparse.ArgumentParser()
    # On Python 3 sub-parsers are optional by default; without a
    # sub-command the namespace would lack ``func``. ``dest`` is needed so
    # argparse can name the missing argument in its error message.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True
    convert = subparsers.add_parser(
        'convert',
        parents=[all_type_parent_parser, convert_run_parent_parser],
        help='convert to mace model (file or code)')
    convert.add_argument(
        "--cl_mem_type",
        type=str,
        default=None,
        help="Which type of OpenCL memory type to use [image | buffer].")
    convert.set_defaults(func=convert_func)
    run = subparsers.add_parser(
        'run',
        parents=[all_type_parent_parser, run_bm_parent_parser,
                 convert_run_parent_parser],
        help='run model in command line')
    run.set_defaults(func=run_mace)
    run.add_argument(
        "--disable_tuning",
        action="store_true",
        help="Disable tuning for specific thread.")
    run.add_argument(
        "--round",
        type=int,
        default=1,
        help="The model running round.")
    run.add_argument(
        "--validate",
        action="store_true",
        help="whether to verify the results are consistent with "
             "the frameworks.")
    run.add_argument(
        "--layers",
        type=str,
        default="-1",
        help="'start_layer:end_layer' or 'layer', similar to python slice."
             " Use with --validate flag.")
    run.add_argument(
        "--caffe_env",
        type=str_to_caffe_env_type,
        default='docker',
        help="[docker | local] you can specific caffe environment for"
             " validation. local environment or caffe docker image.")
    run.add_argument(
        "--vlog_level",
        type=int,
        default=0,
        help="[1~5]. Verbose log level for debug.")
    run.add_argument(
        "--gpu_out_of_range_check",
        action="store_true",
        help="Enable out of memory check for gpu.")
    run.add_argument(
        "--restart_round",
        type=int,
        default=1,
        help="restart round between run.")
    run.add_argument(
        "--report",
        action="store_true",
        help="print run statistics report.")
    run.add_argument(
        "--report_dir",
        type=str,
        default="",
        help="print run statistics report.")
    run.add_argument(
        "--runtime_failure_ratio",
        type=float,
        default=0.0,
        help="[mock runtime failure ratio].")
    run.add_argument(
        "--example",
        action="store_true",
        help="whether to run example.")
    run.add_argument(
        "--quantize_stat",
        action="store_true",
        help="whether to stat quantization range.")
    run.add_argument(
        "--input_dir",
        type=str,
        default="",
        help="quantize stat input dir.")
    run.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="quantize stat output dir.")
    run.add_argument(
        "--cl_binary_to_code",
        action="store_true",
        help="convert OpenCL binaries to cpp.")
    benchmark = subparsers.add_parser(
        'benchmark',
        parents=[all_type_parent_parser, run_bm_parent_parser],
        help='benchmark model for detail information')
    benchmark.set_defaults(func=benchmark_model)
    benchmark.add_argument(
        "--max_num_runs",
        type=int,
        default=100,
        help="max number of runs.")
    benchmark.add_argument(
        "--max_seconds",
        type=float,
        default=10.0,
        help="max number of seconds to run.")
    return parser.parse_known_args()

1341

Y
yejianwu 已提交
1342
if __name__ == "__main__":
1343 1344
    flags, unparsed = parse_args()
    flags.func(flags)