converter.py 45.9 KB
Newer Older
Y
yejianwu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright 2018 Xiaomi, Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15

import argparse
L
liuqi 已提交
16
import glob
17
import hashlib
18
import os
L
liuqi 已提交
19
import re
L
Liangliang He 已提交
20
import sh
21
import sys
22
import urllib
Y
yejianwu 已提交
23
import yaml
L
liuqi 已提交
24

25
from enum import Enum
26
import six
27

28
import sh_commands
L
Liangliang He 已提交
29

L
liuqi 已提交
30 31
from common import *
from device import DeviceWrapper, DeviceManager
32

Y
yejianwu 已提交
33 34 35 36
################################
# set environment
################################
# TF_CPP_MIN_LOG_LEVEL is read by TensorFlow's C++ runtime; level '2' filters
# out its INFO and WARNING messages. It is set at module import time,
# presumably so TensorFlow-based conversion steps launched from this script
# (subprocesses inherit os.environ) stay quiet -- confirm with callers.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
A
Allen 已提交
37

38 39 40 41 42
################################
# common definitions
################################

# ABIs accepted for 'target_abis' (YAML list or comma-separated flag).
ABITypeStrs = [
    'armeabi-v7a',
    'arm64-v8a',
    'arm64',
    'armhf',
    'host',
]

# Accepted values for 'model_graph_format' and 'model_data_format'.
ModelFormatStrs = [
    "file",
    "code",
]

# Accepted values for a model's 'platform' key.
PlatformTypeStrs = [
    "tensorflow",
    "caffe",
]
PlatformType = Enum('PlatformType', [(ele, ele) for ele in PlatformTypeStrs],
                    type=str)

# Accepted values for a model's 'runtime' key.
RuntimeTypeStrs = [
    "cpu",
    "gpu",
    "dsp",
    "cpu+gpu",
]

# Accepted entries of a subgraph's 'input_data_types'.
InputDataTypeStrs = [
    "int32",
    "float32",
]

InputDataType = Enum('InputDataType',
                     [(ele, ele) for ele in InputDataTypeStrs],
                     type=str)

# Accepted 'data_type' values for every runtime except dsp.
FPDataTypeStrs = [
    "fp16_fp32",
    "fp32_fp32",
]

# The Enum's class name matches the variable it is bound to, so reprs and
# error messages say FPDataType rather than an unrelated name.
FPDataType = Enum('FPDataType', [(ele, ele) for ele in FPDataTypeStrs],
                  type=str)

# Accepted 'data_type' values for the dsp runtime.
DSPDataTypeStrs = [
    "uint8",
]

DSPDataType = Enum('DSPDataType', [(ele, ele) for ele in DSPDataTypeStrs],
                   type=str)

# Accepted 'winograd' values; 0 disables winograd convolution.
WinogradParameters = [0, 2, 4]

95 96 97 98 99 100 101 102 103 104
class DataFormat(object):
    """Layout names accepted in input/output_data_formats."""
    NONE, NHWC = "NONE", "NHWC"


# Every layout name a subgraph may declare, in declaration order.
DataFormatStrs = [DataFormat.NONE, DataFormat.NHWC]

105 106

class DefaultValues(object):
    """Default settings referenced elsewhere in the tooling.

    The attribute names suggest these back command-line options defined
    outside this section -- confirm with callers. Each value is a plain
    scalar: a trailing comma after an assignment would silently turn it
    into a one-element tuple such as ``(-1,)``.
    """
    mace_lib_type = MACELibType.static
    omp_num_threads = -1
    cpu_affinity_policy = 1
    gpu_perf_hint = 3
    gpu_priority_hint = 3


114 115 116 117 118 119 120
class ValidationThreshold(object):
    """Default per-device similarity thresholds for output validation.

    format_model_config seeds each subgraph's threshold table with these
    and lets YAML entries override them. User overrides are plain numbers,
    so the defaults are plain floats too (no trailing comma, which would
    make each one a one-element tuple).
    """
    cpu_threshold = 0.999
    gpu_threshold = 0.995
    hexagon_threshold = 0.930
    cpu_quantize_threshold = 0.980


121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
CPP_KEYWORDS = [
    'alignas', 'alignof', 'and', 'and_eq', 'asm', 'atomic_cancel',
    'atomic_commit', 'atomic_noexcept', 'auto', 'bitand', 'bitor',
    'bool', 'break', 'case', 'catch', 'char', 'char16_t', 'char32_t',
    'class', 'compl', 'concept', 'const', 'constexpr', 'const_cast',
    'continue', 'co_await', 'co_return', 'co_yield', 'decltype', 'default',
    'delete', 'do', 'double', 'dynamic_cast', 'else', 'enum', 'explicit',
    'export', 'extern', 'false', 'float', 'for', 'friend', 'goto', 'if',
    'import', 'inline', 'int', 'long', 'module', 'mutable', 'namespace',
    'new', 'noexcept', 'not', 'not_eq', 'nullptr', 'operator', 'or', 'or_eq',
    'private', 'protected', 'public', 'register', 'reinterpret_cast',
    'requires', 'return', 'short', 'signed', 'sizeof', 'static',
    'static_assert', 'static_cast', 'struct', 'switch', 'synchronized',
    'template', 'this', 'thread_local', 'throw', 'true', 'try', 'typedef',
    'typeid', 'typename', 'union', 'unsigned', 'using', 'virtual', 'void',
    'volatile', 'wchar_t', 'while', 'xor', 'xor_eq', 'override', 'final',
    'transaction_safe', 'transaction_safe_dynamic', 'if', 'elif', 'else',
    'endif', 'defined', 'ifdef', 'ifndef', 'define', 'undef', 'include',
    'line', 'error', 'pragma',
]
Y
yejianwu 已提交
141

142

143 144 145
################################
# common functions
################################
146
def parse_device_type(runtime):
    """Translate a runtime name into the matching DeviceType.

    dsp maps to DeviceType.HEXAGON, gpu to DeviceType.GPU and cpu to
    DeviceType.CPU. Any other value (for example the combined cpu+gpu
    runtime) yields an empty string.
    """
    runtime_to_device = (
        (RuntimeType.dsp, DeviceType.HEXAGON),
        (RuntimeType.gpu, DeviceType.GPU),
        (RuntimeType.cpu, DeviceType.CPU),
    )
    for candidate, device in runtime_to_device:
        if runtime == candidate:
            return device
    return ""
157

Y
yejianwu 已提交
158 159

def get_hexagon_mode(configs):
    """Return True when at least one model is configured for the dsp runtime.

    Each model's 'runtime' value is lower-cased before the comparison; a
    model without a 'runtime' entry counts as an empty runtime.
    """
    models = configs[YAMLKeyword.models]
    runtimes = [models[name].get(YAMLKeyword.runtime, "").lower()
                for name in models]
    return RuntimeType.dsp in runtimes


Y
yejianwu 已提交
172 173 174
def get_opencl_mode(configs):
    """Return True when any model uses the gpu or the cpu+gpu runtime.

    Each model's 'runtime' value is lower-cased before the comparison; a
    model without a 'runtime' entry counts as an empty runtime.
    """
    models = configs[YAMLKeyword.models]
    runtimes = [models[name].get(YAMLKeyword.runtime, "").lower()
                for name in models]
    uses_gpu = RuntimeType.gpu in runtimes
    return uses_gpu or RuntimeType.cpu_gpu in runtimes


185 186 187 188 189 190 191 192 193 194 195
def get_quantize_mode(configs):
    """Return True if some model sets 'quantize' to 1 (missing means 0)."""
    models = configs[YAMLKeyword.models]
    return any(models[name].get(YAMLKeyword.quantize, 0) == 1
               for name in models)


196 197
def md5sum(str):
    """Return the hexadecimal MD5 digest of ``str`` encoded as UTF-8.

    Used below to derive cache file names from download URLs. The parameter
    name shadows the builtin ``str`` but is kept for caller compatibility.
    """
    return hashlib.md5(str.encode('utf-8')).hexdigest()
200

Y
yejianwu 已提交
201

202 203 204 205 206 207
def sha256_checksum(fname):
    """Return the hex SHA-256 digest of the file at ``fname``.

    The file is read in binary mode in 4096-byte chunks, so large model
    files are never loaded into memory at once.
    """
    digest = hashlib.sha256()
    with open(fname, "rb") as handle:
        while True:
            block = handle.read(4096)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
Y
yejianwu 已提交
208

W
wuchenghui 已提交
209

210 211
def format_model_config(flags):
    """Load the YAML config named by ``flags.config``, then validate and
    normalize it.

    Values given on the command line (``flags.target_abis``,
    ``flags.target_socs``, ``flags.model_graph_format`` and
    ``flags.model_data_format``) take precedence over the YAML file. Every
    model and subgraph entry is checked and completed with defaults in
    place.

    Validation failures are reported through ``mace_check``. A malformed
    ``validation_threshold`` raises ``argparse.ArgumentTypeError``.

    Returns the normalized configuration dict.
    """
    with open(flags.config) as f:
        # safe_load: the config holds only plain YAML data. A bare
        # yaml.load() permits arbitrary object construction, and on
        # PyYAML >= 6 it raises TypeError without an explicit Loader.
        configs = yaml.safe_load(f)
    mace_check(isinstance(configs, dict), ModuleName.YAML_CONFIG,
               "config file should contain a YAML mapping")

    library_name = configs.get(YAMLKeyword.library_name, "")
    mace_check(len(library_name) > 0,
               ModuleName.YAML_CONFIG, "library name should not be empty")

    if flags.target_abis:
        target_abis = flags.target_abis.split(',')
    else:
        target_abis = configs.get(YAMLKeyword.target_abis, [])
    mace_check((isinstance(target_abis, list) and len(target_abis) > 0),
               ModuleName.YAML_CONFIG, "target_abis list is needed")
    configs[YAMLKeyword.target_abis] = target_abis
    for abi in target_abis:
        mace_check(abi in ABITypeStrs,
                   ModuleName.YAML_CONFIG,
                   "target_abis must be in " + str(ABITypeStrs))

    target_socs = configs.get(YAMLKeyword.target_socs, "")
    if flags.target_socs:
        configs[YAMLKeyword.target_socs] = \
            [soc.lower() for soc in flags.target_socs.split(',')]
    elif not target_socs:
        configs[YAMLKeyword.target_socs] = []
    elif not isinstance(target_socs, list):
        configs[YAMLKeyword.target_socs] = [target_socs]
    # SOC names are handled in lower case from here on.
    configs[YAMLKeyword.target_socs] = \
        [soc.lower() for soc in configs[YAMLKeyword.target_socs]]

    # For Android ABIs, every requested SOC must be on a connected device.
    if ABIType.armeabi_v7a in target_abis \
            or ABIType.arm64_v8a in target_abis:
        available_socs = sh_commands.adb_get_all_socs()
        target_socs = configs[YAMLKeyword.target_socs]
        if ALL_SOC_TAG in target_socs:
            mace_check(available_socs,
                       ModuleName.YAML_CONFIG,
                       "Android abi is listed in config file and "
                       "build for all SOCs plugged in computer, "
                       "But no android phone found, "
                       "you at least plug in one phone")
        else:
            for soc in target_socs:
                mace_check(soc in available_socs,
                           ModuleName.YAML_CONFIG,
                           "Build specified SOC library, "
                           "you must plug in a phone using the SOC")

    if flags.model_graph_format:
        model_graph_format = flags.model_graph_format
    else:
        model_graph_format = configs.get(YAMLKeyword.model_graph_format, "")
    mace_check(model_graph_format in ModelFormatStrs,
               ModuleName.YAML_CONFIG,
               'You must set model_graph_format and '
               "model_graph_format must be in " + str(ModelFormatStrs))
    configs[YAMLKeyword.model_graph_format] = model_graph_format

    if flags.model_data_format:
        model_data_format = flags.model_data_format
    else:
        model_data_format = configs.get(YAMLKeyword.model_data_format, "")
    configs[YAMLKeyword.model_data_format] = model_data_format
    mace_check(model_data_format in ModelFormatStrs,
               ModuleName.YAML_CONFIG,
               'You must set model_data_format and '
               "model_data_format must be in " + str(ModelFormatStrs))

    mace_check(not (model_graph_format == ModelFormat.file
                    and model_data_format == ModelFormat.code),
               ModuleName.YAML_CONFIG,
               "If model_graph format is 'file',"
               " the model_data_format must be 'file' too")

    model_names = configs.get(YAMLKeyword.models, [])
    mace_check(len(model_names) > 0, ModuleName.YAML_CONFIG,
               "no model found in config file")

    for model_name in model_names:
        _format_model(model_name, configs[YAMLKeyword.models][model_name],
                      target_abis)

    return configs


# Legal characters of a model name (the first character is checked
# separately in _format_model).
_MODEL_NAME_REG = re.compile(r'^[a-zA-Z0-9_]+$')


def _format_model(model_name, model_config, target_abis):
    """Validate one entry of ``configs[models]`` and fill defaults in place."""
    # Model names must be legal C++ identifiers.
    mace_check(model_name not in CPP_KEYWORDS,
               ModuleName.YAML_CONFIG,
               "model name should not be c++ keyword.")
    # Slice rather than index, so an empty name fails the check instead of
    # raising IndexError.
    first_char = model_name[:1]
    mace_check((first_char == '_' or first_char.isalpha())
               and bool(_MODEL_NAME_REG.match(model_name)),
               ModuleName.YAML_CONFIG,
               "model name should Meet the c++ naming convention"
               " which start with '_' or alpha"
               " and only contain alpha, number and '_'")

    platform = model_config.get(YAMLKeyword.platform, "")
    mace_check(platform in PlatformTypeStrs,
               ModuleName.YAML_CONFIG,
               "'platform' must be in " + str(PlatformTypeStrs))

    for key in [YAMLKeyword.model_file_path,
                YAMLKeyword.model_sha256_checksum]:
        value = model_config.get(key, "")
        mace_check(value != "", ModuleName.YAML_CONFIG,
                   "'%s' is necessary" % key)

    # A weight file is optional, but when present it needs a checksum.
    weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "")
    if weight_file_path:
        weight_checksum = \
            model_config.get(YAMLKeyword.weight_sha256_checksum, "")
        mace_check(weight_checksum != "", ModuleName.YAML_CONFIG,
                   "'%s' is necessary" %
                   YAMLKeyword.weight_sha256_checksum)
    else:
        model_config[YAMLKeyword.weight_sha256_checksum] = ""

    runtime = model_config.get(YAMLKeyword.runtime, "")
    mace_check(runtime in RuntimeTypeStrs,
               ModuleName.YAML_CONFIG,
               "'runtime' must be in " + str(RuntimeTypeStrs))
    if ABIType.host in target_abis:
        mace_check(runtime == RuntimeType.cpu,
                   ModuleName.YAML_CONFIG,
                   "host only support cpu runtime now.")

    # data_type defaults: uint8 on dsp, fp32_fp32 on cpu, fp16_fp32 else.
    data_type = model_config.get(YAMLKeyword.data_type, "")
    if runtime == RuntimeType.dsp:
        if len(data_type) > 0:
            mace_check(data_type in DSPDataTypeStrs,
                       ModuleName.YAML_CONFIG,
                       "'data_type' must be in " + str(DSPDataTypeStrs)
                       + " for dsp runtime")
        else:
            model_config[YAMLKeyword.data_type] = \
                DSPDataType.uint8.value
    else:
        if len(data_type) > 0:
            mace_check(data_type in FPDataTypeStrs,
                       ModuleName.YAML_CONFIG,
                       "'data_type' must be in " + str(FPDataTypeStrs)
                       + " for %s runtime" % runtime)
        elif runtime == RuntimeType.cpu:
            model_config[YAMLKeyword.data_type] = \
                FPDataType.fp32_fp32.value
        else:
            model_config[YAMLKeyword.data_type] = \
                FPDataType.fp16_fp32.value

    subgraphs = model_config.get(YAMLKeyword.subgraphs, "")
    mace_check(len(subgraphs) > 0, ModuleName.YAML_CONFIG,
               "at least one subgraph is needed")
    for subgraph in subgraphs:
        _format_subgraph(subgraph)

    # Optional integer switches default to 0 when absent or empty.
    for key in [YAMLKeyword.limit_opencl_kernel_time,
                YAMLKeyword.nnlib_graph_mode,
                YAMLKeyword.obfuscate,
                YAMLKeyword.winograd,
                YAMLKeyword.quantize,
                YAMLKeyword.change_concat_ranges]:
        value = model_config.get(key, "")
        if value == "":
            model_config[key] = 0

    mace_check(model_config[YAMLKeyword.winograd] in WinogradParameters,
               ModuleName.YAML_CONFIG,
               "'winograd' parameters must be in "
               + str(WinogradParameters) +
               ". 0 for disable winograd convolution")

    weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "")
    model_config[YAMLKeyword.weight_file_path] = weight_file_path


def _format_subgraph(subgraph):
    """Validate one subgraph entry and normalize its fields in place."""
    # Required tensor/shape lists: scalars are wrapped, entries stringified.
    for key in [YAMLKeyword.input_tensors,
                YAMLKeyword.input_shapes,
                YAMLKeyword.output_tensors,
                YAMLKeyword.output_shapes]:
        value = subgraph.get(key, "")
        mace_check(value != "", ModuleName.YAML_CONFIG,
                   "'%s' is necessary in subgraph" % key)
        if not isinstance(value, list):
            subgraph[key] = [value]
        subgraph[key] = [str(v) for v in subgraph[key]]

    # Optional check lists default to empty.
    for key in [YAMLKeyword.check_tensors,
                YAMLKeyword.check_shapes]:
        value = subgraph.get(key, "")
        if value != "":
            if not isinstance(value, list):
                subgraph[key] = [value]
            subgraph[key] = [str(v) for v in subgraph[key]]
        else:
            subgraph[key] = []

    input_data_types = subgraph.get(YAMLKeyword.input_data_types, "")
    if input_data_types:
        if not isinstance(input_data_types, list):
            subgraph[YAMLKeyword.input_data_types] = [input_data_types]
        for input_data_type in subgraph[YAMLKeyword.input_data_types]:
            mace_check(input_data_type in InputDataTypeStrs,
                       ModuleName.YAML_CONFIG,
                       "'input_data_types' must be in "
                       + str(InputDataTypeStrs))
    else:
        subgraph[YAMLKeyword.input_data_types] = []

    _format_data_formats(subgraph, YAMLKeyword.input_data_formats,
                         YAMLKeyword.input_tensors, "input")
    _format_data_formats(subgraph, YAMLKeyword.output_data_formats,
                         YAMLKeyword.output_tensors, "output")

    validation_threshold = subgraph.get(
        YAMLKeyword.validation_threshold, {})
    if not isinstance(validation_threshold, dict):
        raise argparse.ArgumentTypeError(
            'similarity threshold must be a dict.')

    # Start from the defaults and let YAML entries override them; 'DSP' is
    # accepted as an alias of the HEXAGON device.
    threshold_dict = {
        DeviceType.CPU: ValidationThreshold.cpu_threshold,
        DeviceType.GPU: ValidationThreshold.gpu_threshold,
        DeviceType.HEXAGON: ValidationThreshold.hexagon_threshold,
        DeviceType.CPU + "_QUANTIZE":
            ValidationThreshold.cpu_quantize_threshold,
    }
    for k, v in six.iteritems(validation_threshold):
        if k.upper() == 'DSP':
            k = DeviceType.HEXAGON
        if k.upper() not in (DeviceType.CPU,
                             DeviceType.GPU,
                             DeviceType.HEXAGON,
                             DeviceType.CPU + "_QUANTIZE"):
            raise argparse.ArgumentTypeError(
                'Unsupported validation threshold runtime: %s' % k)
        threshold_dict[k.upper()] = v
    subgraph[YAMLKeyword.validation_threshold] = threshold_dict

    validation_inputs_data = subgraph.get(
        YAMLKeyword.validation_inputs_data, [])
    if not isinstance(validation_inputs_data, list):
        validation_inputs_data = [validation_inputs_data]
    subgraph[YAMLKeyword.validation_inputs_data] = validation_inputs_data

    input_ranges = subgraph.get(YAMLKeyword.input_ranges, [])
    if not isinstance(input_ranges, list):
        input_ranges = [input_ranges]
    subgraph[YAMLKeyword.input_ranges] = [str(v) for v in input_ranges]


def _format_data_formats(subgraph, formats_key, tensors_key, direction):
    """Normalize and validate the input or output data formats of a subgraph.

    ``direction`` is "input" or "output" and is used only in messages.
    - A scalar value is wrapped in a one-element list.
    - An explicit list must have one entry per tensor under ``tensors_key``.
    - A missing or empty value defaults to ``[DataFormat.NHWC]``.
    """
    data_formats = subgraph.get(formats_key, [])
    if not data_formats:
        subgraph[formats_key] = [DataFormat.NHWC]
        return
    if not isinstance(data_formats, list):
        subgraph[formats_key] = [data_formats]
    else:
        mace_check(len(data_formats) == len(subgraph[tensors_key]),
                   ModuleName.YAML_CONFIG,
                   "%s_data_formats should match the size of %s"
                   % (direction, direction))
    for data_format in subgraph[formats_key]:
        # The message is built before mace_check runs. Format the single
        # offending entry with %s: concatenating the whole (list) value onto
        # a str would raise TypeError.
        mace_check(data_format in DataFormatStrs,
                   ModuleName.YAML_CONFIG,
                   "'%s_data_formats' must be in %s, but got %s"
                   % (direction, str(DataFormatStrs), str(data_format)))
Y
yejianwu 已提交
501

W
wuchenghui 已提交
502

503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523
def clear_build_dirs(library_name):
    """Prepare the build tree for ``library_name``.

    Creates BUILD_OUTPUT_DIR if needed, then:
    - recreates an empty temporary build directory for the library;
    - deletes the library's output directory if it exists, without
      recreating it.
    """
    if not os.path.exists(BUILD_OUTPUT_DIR):
        os.makedirs(BUILD_OUTPUT_DIR)

    library_root = os.path.join(BUILD_OUTPUT_DIR, library_name)
    # (sub-directory, whether to recreate it after removal)
    for dir_name, recreate in ((BUILD_TMP_DIR_NAME, True),
                               (OUTPUT_LIBRARY_DIR_NAME, False)):
        target_dir = os.path.join(library_root, dir_name)
        if os.path.exists(target_dir):
            sh.rm('-rf', target_dir)
        if recreate:
            os.makedirs(target_dir)


################################
# convert
################################
def print_configuration(configs):
    """Log the library-wide settings from ``configs`` as a key/value table."""
    shown_keys = (YAMLKeyword.library_name,
                  YAMLKeyword.target_abis,
                  YAMLKeyword.target_socs,
                  YAMLKeyword.model_graph_format,
                  YAMLKeyword.model_data_format)
    rows = [[key, configs[key]] for key in shown_keys]
    MaceLogger.summary(
        StringFormatter.table(["key", "value"], rows, "Common Configuration"))
L
Liangliang He 已提交
538

Y
yejianwu 已提交
539

540 541 542 543
def download_file(url, dst, num_retries=3):
    """Download ``url`` into ``dst``, retrying on URL/HTTP errors.

    Each failed attempt logs a warning. After a failure, another attempt is
    made only while the remaining retry count is positive, so there are at
    most ``num_retries + 1`` attempts. Returns True on success and False
    once retries are exhausted.
    """
    from six.moves import urllib

    retries_left = num_retries
    while True:
        try:
            urllib.request.urlretrieve(url, dst)
            MaceLogger.info('\nDownloaded successfully.')
            return True
        except (urllib.error.ContentTooShortError, urllib.error.HTTPError,
                urllib.error.URLError) as e:
            MaceLogger.warning('Download error:' + str(e))
            if retries_left > 0:
                retries_left -= 1
                continue
            return False


B
Bin Li 已提交
556 557 558 559 560 561 562
def get_model_files(model_file_path,
                    model_sha256_checksum,
                    model_output_dir,
                    weight_file_path="",
                    weight_sha256_checksum=""):
    """Resolve the local model (and optional weight) file, downloading URLs.

    An http(s) path is cached in ``model_output_dir`` under the MD5 of the
    URL: ``.pb`` for the model, ``.caffemodel`` for the weights. A cached
    file is fetched again only when it is missing or its SHA-256 differs
    from the expected one.

    The model file, and the weight file when one is given, are then checked
    against their expected SHA-256. Download failures and checksum
    mismatches are reported through MaceLogger.error.

    Returns ``(model_file, weight_file)``; ``weight_file`` is "" when no
    weight file was given.
    """
    def _is_remote(path):
        return path.startswith("http://") or path.startswith("https://")

    def _fetch(url, checksum, suffix, progress_msg):
        local_path = model_output_dir + "/" + md5sum(url) + suffix
        if not os.path.exists(local_path) or \
                sha256_checksum(local_path) != checksum:
            MaceLogger.info(progress_msg)
            if not download_file(url, local_path):
                MaceLogger.error(ModuleName.MODEL_CONVERTER,
                                 "Model download failed.")
        return local_path

    model_file = model_file_path
    if _is_remote(model_file_path):
        model_file = _fetch(model_file_path, model_sha256_checksum, ".pb",
                            "Downloading model, please wait ...")
    if sha256_checksum(model_file) != model_sha256_checksum:
        MaceLogger.error(ModuleName.MODEL_CONVERTER,
                         "model file sha256checksum not match")

    weight_file = weight_file_path
    if _is_remote(weight_file_path):
        weight_file = _fetch(weight_file_path, weight_sha256_checksum,
                             ".caffemodel",
                             "Downloading model weight, please wait ...")
    if weight_file and \
            sha256_checksum(weight_file) != weight_sha256_checksum:
        MaceLogger.error(ModuleName.MODEL_CONVERTER,
                         "weight file sha256checksum not match")

    return model_file, weight_file
L
Liangliang He 已提交
595

L
liuqi 已提交
596

597
def convert_model(configs, cl_mem_type):
    """Convert every model listed in ``configs`` into MACE output files.

    1. Recreate ``BUILD_OUTPUT_DIR/<library_name>`` from scratch and ensure
       the download cache directory exists.
    2. Clear previous model/engine codegen output.
    3. When the graph format is 'code', generate the engine factory and
       copy its headers.
    4. For each model, fetch/verify its files and run
       ``sh_commands.gen_model_code``. Depending on the graph/data formats,
       the generated ``.pb``/``.data`` files are moved into the model output
       dir and the model headers are copied into the header dir.

    ``cl_mem_type`` is written into every model config; a falsy value
    means "image".
    """
    def _remove_if_exists(path):
        if os.path.exists(path):
            sh.rm("-rf", path)

    library_name = configs[YAMLKeyword.library_name]
    library_dir = os.path.join(BUILD_OUTPUT_DIR, library_name)

    # Start from an empty library directory.
    if not os.path.exists(BUILD_OUTPUT_DIR):
        os.makedirs(BUILD_OUTPUT_DIR)
    elif os.path.exists(library_dir):
        sh.rm("-rf", library_dir)
    os.makedirs(library_dir)
    if not os.path.exists(BUILD_DOWNLOADS_DIR):
        os.makedirs(BUILD_DOWNLOADS_DIR)

    model_output_dir = \
        '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_OUTPUT_DIR_NAME)
    model_header_dir = \
        '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_HEADER_DIR_PATH)

    # The output dir is recreated now; the header dir only when the graph
    # is emitted as code (below).
    _remove_if_exists(model_output_dir)
    os.makedirs(model_output_dir)
    _remove_if_exists(model_header_dir)

    graph_format = configs[YAMLKeyword.model_graph_format]
    embed_model_data = \
        configs[YAMLKeyword.model_data_format] == ModelFormat.code

    _remove_if_exists(MODEL_CODEGEN_DIR)
    _remove_if_exists(ENGINE_CODEGEN_DIR)

    models = configs[YAMLKeyword.models]
    if graph_format == ModelFormat.code:
        os.makedirs(model_header_dir)
        sh_commands.gen_mace_engine_factory_source(
            models.keys(),
            embed_model_data)
        sh.cp("-f", glob.glob("mace/codegen/engine/*.h"),
              model_header_dir)

    for model_name in models:
        MaceLogger.header(
            StringFormatter.block("Convert %s model" % model_name))
        model_config = models[model_name]
        runtime = model_config[YAMLKeyword.runtime]
        model_config[YAMLKeyword.cl_mem_type] = cl_mem_type or "image"

        model_file_path, weight_file_path = get_model_files(
            model_config[YAMLKeyword.model_file_path],
            model_config[YAMLKeyword.model_sha256_checksum],
            BUILD_DOWNLOADS_DIR,
            model_config[YAMLKeyword.weight_file_path],
            model_config[YAMLKeyword.weight_sha256_checksum])

        data_type = model_config[YAMLKeyword.data_type]
        # TODO(liuqi): support multiple subgraphs
        subgraphs = model_config[YAMLKeyword.subgraphs]

        model_codegen_dir = "%s/%s" % (MODEL_CODEGEN_DIR, model_name)
        sh_commands.gen_model_code(
            model_codegen_dir,
            model_config[YAMLKeyword.platform],
            model_file_path,
            weight_file_path,
            model_config[YAMLKeyword.model_sha256_checksum],
            model_config[YAMLKeyword.weight_sha256_checksum],
            ",".join(subgraphs[0][YAMLKeyword.input_tensors]),
            ",".join(subgraphs[0][YAMLKeyword.input_data_formats]),
            ",".join(subgraphs[0][YAMLKeyword.output_tensors]),
            ",".join(subgraphs[0][YAMLKeyword.output_data_formats]),
            ",".join(subgraphs[0][YAMLKeyword.check_tensors]),
            runtime,
            model_name,
            ":".join(subgraphs[0][YAMLKeyword.input_shapes]),
            ":".join(subgraphs[0][YAMLKeyword.input_ranges]),
            ":".join(subgraphs[0][YAMLKeyword.output_shapes]),
            ":".join(subgraphs[0][YAMLKeyword.check_shapes]),
            model_config[YAMLKeyword.nnlib_graph_mode],
            embed_model_data,
            model_config[YAMLKeyword.winograd],
            model_config[YAMLKeyword.quantize],
            model_config.get(YAMLKeyword.quantize_range_file, ""),
            model_config[YAMLKeyword.change_concat_ranges],
            model_config[YAMLKeyword.obfuscate],
            graph_format,
            data_type,
            model_config[YAMLKeyword.cl_mem_type],
            ",".join(model_config.get(YAMLKeyword.graph_optimize_options, [])))

        if graph_format == ModelFormat.file:
            # Both the graph and its data ship as files.
            for pattern in ('%s/%s.pb', '%s/%s.data'):
                sh.mv("-f",
                      pattern % (model_codegen_dir, model_name),
                      model_output_dir)
        else:
            # Graph compiled in as code; data is a file unless embedded.
            if not embed_model_data:
                sh.mv("-f",
                      '%s/%s.data' % (model_codegen_dir, model_name),
                      model_output_dir)
            sh.cp("-f", glob.glob("mace/codegen/models/*/*.h"),
                  model_header_dir)

        MaceLogger.summary(
            StringFormatter.block("Model %s converted" % model_name))


706 707
def build_model_lib(configs, address_sanitizer):
    """Compile the generated model code into a library for every target ABI.

    For each ABI listed under ``target_abis`` in *configs*:
    - the model library Bazel target is built with symbols hidden;
    - the produced archive is copied to the per-ABI output path of the
      library, whose directory is created when missing.

    Args:
        configs: the formatted model configuration dictionary.
        address_sanitizer: whether to build with address sanitizer enabled.
    """
    MaceLogger.header(StringFormatter.block("Building model library"))

    library_name = configs[YAMLKeyword.library_name]
    for abi in configs[YAMLKeyword.target_abis]:
        hexagon_mode = get_hexagon_mode(configs)

        # Make sure the destination directory of the library exists.
        output_path = get_model_lib_output_path(library_name, abi)
        output_dir = os.path.dirname(output_path)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        sh_commands.bazel_build(
            MODEL_LIB_TARGET,
            abi=abi,
            toolchain=infer_toolchain(abi),
            hexagon_mode=hexagon_mode,
            enable_opencl=get_opencl_mode(configs),
            enable_quantize=get_quantize_mode(configs),
            address_sanitizer=address_sanitizer,
            symbol_hidden=True,
        )
        sh.cp("-f", MODEL_LIB_PATH, output_path)
731 732 733 734 735 736 737


def print_library_summary(configs):
    """Log a key/value table telling where the built model artifacts live.

    The MACE model output directory is always listed. The generated header
    directory is added only when the model graph format is ``code``.
    """
    library_name = configs[YAMLKeyword.library_name]
    library_root = "%s/%s" % (BUILD_OUTPUT_DIR, library_name)

    rows = [["MACE Model Path",
             "%s/%s" % (library_root, MODEL_OUTPUT_DIR_NAME)]]
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        rows.append(["MACE Model Header Path",
                     "%s/%s" % (library_root, MODEL_HEADER_DIR_PATH)])

    table = StringFormatter.table(["key", "value"], rows, "Library")
    MaceLogger.summary(table)


750
def convert_func(flags):
    """Handler of the ``convert`` sub-command.

    Steps:
    1. Format and print the model configuration.
    2. Convert the model(s).
    3. Build the model library when the graph is emitted as code.
    4. Print a summary of the produced artifacts.
    """
    configs = format_model_config(flags)
    print_configuration(configs)
    convert_model(configs, flags.cl_mem_type)

    graph_as_code = (
        configs[YAMLKeyword.model_graph_format] == ModelFormat.code)
    if graph_as_code:
        build_model_lib(configs, flags.address_sanitizer)

    print_library_summary(configs)


################################
# run
################################
def report_run_statistics(stdout,
                          abi,
                          serialno,
                          model_name,
                          device_type,
                          output_dir,
                          tuned):
    """Append one row of timing statistics to ``<output_dir>/report.csv``.

    Timings come from the first line of *stdout* that has exactly four
    whitespace-separated fields, the first of which starts with ``time``.
    That line supplies the init, warmup and average run times, written as
    ``str(float(...))``. When no such line exists, all three are written
    as ``0``.

    For non-host ABIs, the device model and board platform are read from the
    device properties of *serialno*. The CSV header is written only when the
    report file does not exist yet.
    """
    init_ms = warmup_ms = run_avg_ms = 0
    for raw_line in stdout.split('\n'):
        fields = raw_line.strip().split()
        if len(fields) != 4 or not fields[0].startswith("time"):
            continue
        init_ms, warmup_ms, run_avg_ms = [str(float(value))
                                          for value in fields[1:]]
        break

    device_name = target_soc = ""
    if abi != "host":
        props = sh_commands.adb_getprop_by_serialno(serialno)
        device_name = props.get("ro.product.model", "")
        target_soc = props.get("ro.board.platform", "")

    report_filename = output_dir + "/report.csv"
    if not os.path.exists(report_filename):
        with open(report_filename, 'w') as report:
            report.write("model_name,device_name,soc,abi,runtime,"
                         "init(ms),warmup(ms),run_avg(ms),tuned\n")

    row_template = ("{model_name},{device_name},{soc},{abi},{device_type},"
                    "{init},{warmup},{run_avg},{tuned}\n")
    row = row_template.format(model_name=model_name,
                              device_name=device_name,
                              soc=target_soc,
                              abi=abi,
                              device_type=device_type,
                              init=init_ms,
                              warmup=warmup_ms,
                              run_avg=run_avg_ms,
                              tuned=tuned)
    with open(report_filename, 'a') as report:
        report.write(row)


L
liuqi 已提交
811 812
def build_mace_run(configs, target_abi, toolchain, enable_openmp,
                   address_sanitizer, mace_lib_type):
    """Build the ``mace_run`` validation binary for one target ABI.

    Steps:
    1. Recreate the library's temporary binary directory from scratch.
    2. Build the static or dynamic ``mace_run`` target, as selected by
       *mace_lib_type*.
    3. Hand the result to ``sh_commands.update_mace_run_binary``.

    Symbols are hidden only for the static flavour. When the model graph is
    compiled into code, the engine codegen directory must already exist,
    i.e. the model must have been converted first.
    """
    library_name = configs[YAMLKeyword.library_name]
    hexagon_mode = get_hexagon_mode(configs)

    binary_dir = get_build_binary_dir(library_name, target_abi)
    if os.path.exists(binary_dir):
        sh.rm("-rf", binary_dir)
    os.makedirs(binary_dir)

    link_dynamic = mace_lib_type == MACELibType.dynamic
    if link_dynamic:
        run_target, hide_symbols = MACE_RUN_DYNAMIC_TARGET, False
    else:
        run_target, hide_symbols = MACE_RUN_STATIC_TARGET, True

    extra_args = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                   ModuleName.RUN,
                   "You should convert model first.")
        extra_args = "--per_file_copt=mace/tools/validation/mace_run.cc@-DMODEL_GRAPH_FORMAT_CODE"  # noqa

    sh_commands.bazel_build(
        run_target,
        abi=target_abi,
        toolchain=toolchain,
        hexagon_mode=hexagon_mode,
        enable_openmp=enable_openmp,
        enable_opencl=get_opencl_mode(configs),
        enable_quantize=get_quantize_mode(configs),
        address_sanitizer=address_sanitizer,
        symbol_hidden=hide_symbols,
        extra_args=extra_args,
    )
    sh_commands.update_mace_run_binary(binary_dir, link_dynamic)


L
liuqi 已提交
849 850
def build_example(configs, target_abi, toolchain,
                  enable_openmp, mace_lib_type, address_sanitizer=None):
    """Build the MACE CLI example binary for one target ABI.

    Steps:
    1. Recreate the temporary binary directory.
    2. Build libmace, static or dynamic per *mace_lib_type*.
    3. Stage into ``LIB_CODEGEN_DIR`` the model library (graph format
       ``code`` only) and the libmace artifact.
    4. Build the example target against them.
    5. Copy the resulting binary into the temporary binary directory.

    ``LIB_CODEGEN_DIR`` is removed afterwards, even when a step fails.

    Args:
        configs: the formatted model configuration dictionary.
        target_abi: ABI to build for.
        toolchain: toolchain name for *target_abi*.
        enable_openmp: whether to build with OpenMP support.
        mace_lib_type: ``MACELibType.static`` or ``MACELibType.dynamic``.
        address_sanitizer: whether to build with address sanitizer.
            ``None`` is the default, used by existing callers. It keeps the
            historical behaviour of reading ``address_sanitizer`` from the
            module-level ``flags`` namespace, and falls back to ``False``
            when that global does not exist (e.g. when imported).
    """
    if address_sanitizer is None:
        # The previous implementation dereferenced the global ``flags`` (set
        # only under ``__main__``) directly, which raised NameError when the
        # module was imported. Look it up defensively instead.
        cli_flags = globals().get("flags")
        address_sanitizer = getattr(cli_flags, "address_sanitizer", False)

    library_name = configs[YAMLKeyword.library_name]
    hexagon_mode = get_hexagon_mode(configs)

    build_tmp_binary_dir = get_build_binary_dir(library_name, target_abi)
    if os.path.exists(build_tmp_binary_dir):
        sh.rm("-rf", build_tmp_binary_dir)
    os.makedirs(build_tmp_binary_dir)

    symbol_hidden = True
    libmace_target = LIBMACE_STATIC_TARGET
    if mace_lib_type == MACELibType.dynamic:
        symbol_hidden = False
        libmace_target = LIBMACE_SO_TARGET

    sh_commands.bazel_build(libmace_target,
                            abi=target_abi,
                            toolchain=toolchain,
                            enable_openmp=enable_openmp,
                            enable_opencl=get_opencl_mode(configs),
                            enable_quantize=get_quantize_mode(configs),
                            hexagon_mode=hexagon_mode,
                            address_sanitizer=address_sanitizer,
                            symbol_hidden=symbol_hidden)

    if os.path.exists(LIB_CODEGEN_DIR):
        sh.rm("-rf", LIB_CODEGEN_DIR)
    sh.mkdir("-p", LIB_CODEGEN_DIR)

    try:
        build_arg = ""
        if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
            mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                       ModuleName.RUN,
                       "You should convert model first.")
            model_lib_path = get_model_lib_output_path(library_name,
                                                       target_abi)
            sh.cp("-f", model_lib_path, LIB_CODEGEN_DIR)
            build_arg = "--per_file_copt=mace/examples/cli/example.cc@-DMODEL_GRAPH_FORMAT_CODE"  # noqa

        if mace_lib_type == MACELibType.dynamic:
            example_target = EXAMPLE_DYNAMIC_TARGET
            sh.cp("-f", LIBMACE_DYNAMIC_PATH, LIB_CODEGEN_DIR)
        else:
            example_target = EXAMPLE_STATIC_TARGET
            sh.cp("-f", LIBMACE_STATIC_PATH, LIB_CODEGEN_DIR)

        sh_commands.bazel_build(example_target,
                                abi=target_abi,
                                toolchain=toolchain,
                                enable_openmp=enable_openmp,
                                enable_opencl=get_opencl_mode(configs),
                                enable_quantize=get_quantize_mode(configs),
                                hexagon_mode=hexagon_mode,
                                address_sanitizer=address_sanitizer,
                                extra_args=build_arg)

        target_bin = "/".join(
            sh_commands.bazel_target_to_bin(example_target))
        sh.cp("-f", target_bin, build_tmp_binary_dir)
    finally:
        # Always drop the staging directory so stale libraries never leak
        # into a later build, even when a step above failed.
        if os.path.exists(LIB_CODEGEN_DIR):
            sh.rm("-rf", LIB_CODEGEN_DIR)


L
liuqi 已提交
913 914 915 916 917 918 919 920 921 922
def print_package_summary(package_path):
    """Log a one-row key/value table with the location of the package."""
    rows = [["MACE Model package Path", package_path]]
    table = StringFormatter.table(["key", "value"], rows, "Library")
    MaceLogger.summary(table)


923
def run_mace(flags):
    """Handler of the ``run`` sub-command.

    Steps:
    1. Clear the library's build directories.
    2. List the devices, restricted to the configured SoCs unless the
       all-SoC tag is present.
    3. For every target ABI and every device supporting it: build either
       the example binary or ``mace_run``, then run the model on the device
       while holding its lock. Non-host devices lacking the ABI are
       reported on stderr.
    4. Package the output files and print where the package is.
    """
    configs = format_model_config(flags)
    clear_build_dirs(configs[YAMLKeyword.library_name])

    target_socs = configs[YAMLKeyword.target_socs]
    devices = DeviceManager.list_devices(flags.device_yml)
    if target_socs and ALL_SOC_TAG not in target_socs:
        devices = [dev for dev in devices
                   if dev[YAMLKeyword.target_socs].lower() in target_socs]

    for abi in configs[YAMLKeyword.target_abis]:
        for dev in devices:
            if abi not in dev[YAMLKeyword.target_abis]:
                if dev[YAMLKeyword.device_name] != SystemType.host:
                    six.print_('The device with soc %s do not support abi %s' %
                               (dev[YAMLKeyword.target_socs], abi),
                               file=sys.stderr)
                continue

            toolchain = infer_toolchain(abi)
            use_openmp = not flags.disable_openmp
            if flags.example:
                build_example(configs, abi, toolchain, use_openmp,
                              flags.mace_lib_type)
            else:
                build_mace_run(configs, abi, toolchain, use_openmp,
                               flags.address_sanitizer, flags.mace_lib_type)

            device = DeviceWrapper(dev)
            with device.lock():
                device.run_specify_abi(flags, configs, abi)

    # Bundle everything produced under the build output directory.
    package_path = sh_commands.packaging_lib(
        BUILD_OUTPUT_DIR, configs[YAMLKeyword.library_name])
    print_package_summary(package_path)

966 967 968 969

################################
#  benchmark model
################################
L
liuqi 已提交
970 971 972 973 974
def build_benchmark_model(configs,
                          target_abi,
                          toolchain,
                          enable_openmp,
                          mace_lib_type):
    """Build the ``benchmark_model`` binary and stage it for one target ABI.

    The static or dynamic benchmark target is chosen by *mace_lib_type*;
    symbols are hidden only for the static flavour. When the model graph is
    compiled into code, the engine codegen directory must already exist.
    After the build, the library's temporary binary directory is recreated
    empty and the built binary is copied into it.
    """
    library_name = configs[YAMLKeyword.library_name]
    hexagon_mode = get_hexagon_mode(configs)

    if mace_lib_type == MACELibType.dynamic:
        bm_target, hide_symbols = BM_MODEL_DYNAMIC_TARGET, False
    else:
        bm_target, hide_symbols = BM_MODEL_STATIC_TARGET, True

    extra_args = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                   ModuleName.BENCHMARK,
                   "You should convert model first.")
        extra_args = "--per_file_copt=mace/benchmark/benchmark_model.cc@-DMODEL_GRAPH_FORMAT_CODE"  # noqa

    sh_commands.bazel_build(bm_target,
                            abi=target_abi,
                            toolchain=toolchain,
                            enable_openmp=enable_openmp,
                            enable_opencl=get_opencl_mode(configs),
                            enable_quantize=get_quantize_mode(configs),
                            hexagon_mode=hexagon_mode,
                            symbol_hidden=hide_symbols,
                            extra_args=extra_args)

    # Start from an empty staging directory for the freshly built binary.
    binary_dir = get_build_binary_dir(library_name, target_abi)
    if os.path.exists(binary_dir):
        sh.rm("-rf", binary_dir)
    os.makedirs(binary_dir)

    built_binary = "/".join(sh_commands.bazel_target_to_bin(bm_target))
    sh.cp("-f", built_binary, binary_dir)


1012
def benchmark_model(flags):
    """Handler of the ``benchmark`` sub-command.

    Steps:
    1. Clear the library's build directories.
    2. List the devices, restricted to the configured SoCs unless the
       all-SoC tag is present.
    3. For every target ABI and every device supporting it: build the
       ``benchmark_model`` binary, then run the benchmark on the device
       while holding its lock.

    Devices that do not support an ABI are reported on stderr. The host
    device is excluded from that report, consistent with ``run_mace``, so
    benchmarking for device ABIs does not flood stderr with host noise.
    """
    configs = format_model_config(flags)

    clear_build_dirs(configs[YAMLKeyword.library_name])

    target_socs = configs[YAMLKeyword.target_socs]
    device_list = DeviceManager.list_devices(flags.device_yml)
    if target_socs and ALL_SOC_TAG not in target_socs:
        device_list = [dev for dev in device_list
                       if dev[YAMLKeyword.target_socs].lower() in target_socs]

    for target_abi in configs[YAMLKeyword.target_abis]:
        # build benchmark_model binary
        for dev in device_list:
            if target_abi in dev[YAMLKeyword.target_abis]:
                toolchain = infer_toolchain(target_abi)
                build_benchmark_model(configs,
                                      target_abi,
                                      toolchain,
                                      not flags.disable_openmp,
                                      flags.mace_lib_type)
                device = DeviceWrapper(dev)
                with device.lock():
                    device.bm_specific_target(flags, configs, target_abi)
            elif dev[YAMLKeyword.device_name] != SystemType.host:
                six.print_('There is no abi %s with soc %s' %
                           (target_abi, dev[YAMLKeyword.target_socs]),
                           file=sys.stderr)
L
liuqi 已提交
1040

1041

L
liuqi 已提交
1042
################################
Y
yejianwu 已提交
1043
# parsing arguments
L
liuqi 已提交
1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055
################################
def str2bool(v):
    """Convert a yes/no style word into a ``bool`` (argparse ``type=`` hook).

    Matching is case-insensitive.

    Raises:
        argparse.ArgumentTypeError: if *v* is not a recognised boolean word.
    """
    answer = v.lower()
    if answer in ('yes', 'true', 't', 'y', '1'):
        return True
    if answer in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def str_to_caffe_env_type(v):
    """Map ``docker``/``local`` (any case) to the matching ``CaffeEnvType``.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    env = v.lower()
    if env == 'docker':
        return CaffeEnvType.DOCKER
    if env == 'local':
        return CaffeEnvType.LOCAL
    raise argparse.ArgumentTypeError('[docker | local] expected.')


1063 1064 1065 1066 1067 1068 1069 1070 1071
def str_to_mace_lib_type(v):
    """Map ``dynamic``/``static`` (any case) to the matching ``MACELibType``.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    lib_kind = v.lower()
    if lib_kind == 'dynamic':
        return MACELibType.dynamic
    if lib_kind == 'static':
        return MACELibType.static
    raise argparse.ArgumentTypeError('[dynamic| static] expected.')


1072
def parse_args():
    """Parses command line arguments.

    Three sub-commands are registered: ``convert``, ``run`` and
    ``benchmark``. Each stores its handler in ``func`` so the caller can
    dispatch with ``flags.func(flags)``. Options shared between sub-commands
    come from parent parsers.

    A sub-command is mandatory. Without one, argparse now prints a usage
    error instead of the caller later failing on a missing ``func``
    attribute.

    Returns:
        A ``(namespace, unknown_args)`` tuple as produced by
        ``argparse.ArgumentParser.parse_known_args``.
    """
    # Options accepted by every sub-command.
    all_type_parent_parser = argparse.ArgumentParser(add_help=False)
    all_type_parent_parser.add_argument(
        '--config',
        type=str,
        default="",
        required=True,
        help="the path of model yaml configuration file.")
    all_type_parent_parser.add_argument(
        "--model_graph_format",
        type=str,
        default="",
        help="[file, code], MACE Model graph format.")
    all_type_parent_parser.add_argument(
        "--model_data_format",
        type=str,
        default="",
        help="['file', 'code'], MACE Model data format.")
    all_type_parent_parser.add_argument(
        "--target_abis",
        type=str,
        default="",
        help="Target ABIs, comma separated list.")
    all_type_parent_parser.add_argument(
        "--target_socs",
        type=str,
        default="",
        help="Target SOCs, comma separated list.")

    # Options shared by ``convert`` and ``run``.
    convert_run_parent_parser = argparse.ArgumentParser(add_help=False)
    convert_run_parent_parser.add_argument(
        '--address_sanitizer',
        action="store_true",
        help="Whether to use address sanitizer to check memory error")

    # Options shared by ``run`` and ``benchmark``.
    run_bm_parent_parser = argparse.ArgumentParser(add_help=False)
    run_bm_parent_parser.add_argument(
        "--mace_lib_type",
        type=str_to_mace_lib_type,
        default=DefaultValues.mace_lib_type,
        help="[static | dynamic], Which type MACE library to use.")
    run_bm_parent_parser.add_argument(
        "--disable_openmp",
        action="store_true",
        help="Disable openmp for multiple thread.")
    run_bm_parent_parser.add_argument(
        "--omp_num_threads",
        type=int,
        default=DefaultValues.omp_num_threads,
        help="num of openmp threads")
    run_bm_parent_parser.add_argument(
        "--cpu_affinity_policy",
        type=int,
        default=DefaultValues.cpu_affinity_policy,
        help="0:AFFINITY_NONE/1:AFFINITY_BIG_ONLY/2:AFFINITY_LITTLE_ONLY")
    run_bm_parent_parser.add_argument(
        "--gpu_perf_hint",
        type=int,
        default=DefaultValues.gpu_perf_hint,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    run_bm_parent_parser.add_argument(
        "--gpu_priority_hint",
        type=int,
        default=DefaultValues.gpu_priority_hint,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    run_bm_parent_parser.add_argument(
        "--device_yml",
        type=str,
        default='',
        help='embedded linux device config yml file'
    )

    parser = argparse.ArgumentParser()
    # ``dest`` is needed so argparse can name the missing sub-command in its
    # error message; ``required`` is set as an attribute because the keyword
    # argument does not exist on older Python versions.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    convert = subparsers.add_parser(
        'convert',
        parents=[all_type_parent_parser, convert_run_parent_parser],
        help='convert to mace model (file or code)')
    convert.add_argument(
        "--cl_mem_type",
        type=str,
        default=None,
        help="Which type of OpenCL memory type to use [image | buffer].")
    convert.set_defaults(func=convert_func)

    run = subparsers.add_parser(
        'run',
        parents=[all_type_parent_parser, run_bm_parent_parser,
                 convert_run_parent_parser],
        help='run model in command line')
    run.set_defaults(func=run_mace)
    run.add_argument(
        "--disable_tuning",
        action="store_true",
        help="Disable tuning for specific thread.")
    run.add_argument(
        "--round",
        type=int,
        default=1,
        help="The model running round.")
    run.add_argument(
        "--validate",
        action="store_true",
        help="whether to verify the results are consistent with "
             "the frameworks.")
    run.add_argument(
        "--caffe_env",
        type=str_to_caffe_env_type,
        default='docker',
        help="[docker | local] you can specific caffe environment for"
             " validation. local environment or caffe docker image.")
    run.add_argument(
        "--vlog_level",
        type=int,
        default=0,
        help="[1~5]. Verbose log level for debug.")
    run.add_argument(
        "--gpu_out_of_range_check",
        action="store_true",
        help="Enable out of range check for gpu.")
    run.add_argument(
        "--restart_round",
        type=int,
        default=1,
        help="restart round between run.")
    run.add_argument(
        "--report",
        action="store_true",
        help="print run statistics report.")
    run.add_argument(
        "--report_dir",
        type=str,
        default="",
        help="directory to write the run statistics report to.")
    run.add_argument(
        "--runtime_failure_ratio",
        type=float,
        default=0.0,
        help="[mock runtime failure ratio].")
    run.add_argument(
        "--example",
        action="store_true",
        help="whether to run example.")
    run.add_argument(
        "--quantize_stat",
        action="store_true",
        help="whether to stat quantization range.")
    run.add_argument(
        "--input_dir",
        type=str,
        default="",
        help="quantize stat input dir.")
    run.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="quantize stat output dir.")

    benchmark = subparsers.add_parser(
        'benchmark',
        parents=[all_type_parent_parser, run_bm_parent_parser],
        help='benchmark model for detail information')
    benchmark.set_defaults(func=benchmark_model)

    return parser.parse_known_args()

1233

Y
yejianwu 已提交
1234
if __name__ == "__main__":
    # ``flags`` stays a module-level global: some helpers in this file read
    # it directly.
    flags, unparsed = parse_args()
    if unparsed:
        # parse_known_args() silently drops options it does not recognise
        # (e.g. a misspelled flag); surface them so the user notices.
        six.print_("Warning: ignoring unrecognized arguments: %s"
                   % " ".join(unparsed), file=sys.stderr)
    flags.func(flags)