converter.py 46.1 KB
Newer Older
Y
yejianwu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright 2018 Xiaomi, Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15

import argparse
L
liuqi 已提交
16
import glob
17
import hashlib
18
import os
L
liuqi 已提交
19
import re
L
Liangliang He 已提交
20
import sh
21
import sys
22
import urllib
Y
yejianwu 已提交
23
import yaml
L
liuqi 已提交
24

25
from enum import Enum
26
import six
27

28
import sh_commands
L
Liangliang He 已提交
29

L
liuqi 已提交
30 31
from common import *
from device import DeviceWrapper, DeviceManager
32

Y
yejianwu 已提交
33 34 35 36
################################
# set environment
################################
# Presumably set to quiet TensorFlow's C++ logging while converters run
# (TF_CPP_MIN_LOG_LEVEL=2 filters INFO and WARNING messages).
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')
A
Allen 已提交
37

38 39 40 41 42
################################
# common definitions
################################

# ABIs accepted in `target_abis` (YAML key or --target_abis flag).
ABITypeStrs = ['armeabi-v7a', 'arm64-v8a', 'arm64', 'armhf', 'host']
L
liuqi 已提交
49

50 51 52 53 54
# Accepted values for `model_graph_format` and `model_data_format`.
ModelFormatStrs = ["file", "code"]

# Frameworks a model may be converted from (a model's `platform` key).
PlatformTypeStrs = ["tensorflow", "caffe", "onnx"]
# str-valued Enum whose member names equal their values.
PlatformType = Enum('PlatformType',
                    list(zip(PlatformTypeStrs, PlatformTypeStrs)),
                    type=str)

# Accepted values for a model's `runtime` key.
RuntimeTypeStrs = ["cpu", "gpu", "dsp", "cpu+gpu"]

Y
yejianwu 已提交
70 71 72 73 74 75 76 77 78
# Element types accepted in a subgraph's `input_data_types` list.
InputDataTypeStrs = ["int32", "float32"]

# str-valued Enum whose member names equal their values.
InputDataType = Enum('InputDataType',
                     list(zip(InputDataTypeStrs, InputDataTypeStrs)),
                     type=str)

L
liuqi 已提交
79
# Floating-point `data_type` values accepted for non-DSP runtimes.
FPDataTypeStrs = [
    "fp16_fp32",
    "fp32_fp32",
]

# The Enum's name must match the module-level name it is bound to. It used to
# be created as 'GPUDataType', which made reprs misleading and broke pickling
# (pickle resolves the class by its name on this module).
FPDataType = Enum('FPDataType', [(ele, ele) for ele in FPDataTypeStrs],
                  type=str)
86

L
liuqi 已提交
87 88 89 90 91 92 93
# `data_type` values accepted for the DSP runtime (uint8 is also the default
# filled in for DSP models by format_model_config).
DSPDataTypeStrs = ["uint8"]

# str-valued Enum whose member names equal their values.
DSPDataType = Enum('DSPDataType',
                   list(zip(DSPDataTypeStrs, DSPDataTypeStrs)),
                   type=str)

# Allowed values of a model's `winograd` key; 0 disables Winograd convolution.
WinogradParameters = [0, 2, 4]

96 97 98 99 100 101 102 103 104 105
class DataFormat(object):
    """String constants for tensor data formats used in subgraph configs."""
    NONE = "NONE"
    NHWC = "NHWC"  # default when a subgraph specifies no data format


# Accepted values for `input_data_formats` / `output_data_formats`.
DataFormatStrs = [
    DataFormat.NONE,
    DataFormat.NHWC,
]

106 107

class DefaultValues(object):
    """Default option values (presumably argparse defaults defined later in
    this file -- not visible in this chunk)."""
    # No trailing commas here: `x = -1,` binds the tuple (-1,), not the int -1.
    mace_lib_type = MACELibType.static
    omp_num_threads = -1
    cpu_affinity_policy = 1
    gpu_perf_hint = 3
    gpu_priority_hint = 3


115 116 117 118 119 120 121
class ValidationThreshold(object):
    """Default per-device similarity thresholds for output validation.

    format_model_config copies these into each subgraph's threshold dict and
    then merges any user-supplied `validation_threshold` values on top, so
    the defaults must be plain numbers like the user's values.
    """
    # No trailing commas here: `x = 0.999,` binds a one-element tuple.
    cpu_threshold = 0.999
    gpu_threshold = 0.995
    hexagon_threshold = 0.930
    cpu_quantize_threshold = 0.980


122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
# Words a model name may not be: model names must follow C++ identifier rules,
# so C++ reserved words and preprocessor directive names are rejected. The
# list keeps its historical order, including duplicate 'if' / 'else' entries.
CPP_KEYWORDS = """
    alignas alignof and and_eq asm atomic_cancel
    atomic_commit atomic_noexcept auto bitand bitor
    bool break case catch char char16_t char32_t
    class compl concept const constexpr const_cast
    continue co_await co_return co_yield decltype default
    delete do double dynamic_cast else enum explicit
    export extern false float for friend goto if
    import inline int long module mutable namespace
    new noexcept not not_eq nullptr operator or or_eq
    private protected public register reinterpret_cast
    requires return short signed sizeof static
    static_assert static_cast struct switch synchronized
    template this thread_local throw true try typedef
    typeid typename union unsigned using virtual void
    volatile wchar_t while xor xor_eq override final
    transaction_safe transaction_safe_dynamic
""".split() + """
    if elif else endif defined ifdef ifndef define undef include
    line error pragma
""".split()
Y
yejianwu 已提交
142

143

144 145 146
################################
# common functions
################################
147
def parse_device_type(runtime):
    """Map a runtime value to its DeviceType.

    RuntimeType.dsp -> DeviceType.HEXAGON, RuntimeType.gpu -> DeviceType.GPU,
    RuntimeType.cpu -> DeviceType.CPU; any other value (for example
    RuntimeType.cpu_gpu) yields the empty string.
    """
    if runtime == RuntimeType.dsp:
        return DeviceType.HEXAGON
    if runtime == RuntimeType.gpu:
        return DeviceType.GPU
    if runtime == RuntimeType.cpu:
        return DeviceType.CPU
    return ""
158

Y
yejianwu 已提交
159 160

def get_hexagon_mode(configs):
    """Return True if any model in `configs` uses the DSP runtime.

    Each model's `runtime` is lower-cased before comparison; a model without
    a `runtime` key is treated as "".
    """
    model_configs = configs[YAMLKeyword.models]
    runtimes = [model_configs[name].get(YAMLKeyword.runtime, "").lower()
                for name in model_configs]
    return RuntimeType.dsp in runtimes


Y
yejianwu 已提交
173 174 175
def get_opencl_mode(configs):
    """Return True if any model runs on the GPU, alone or as cpu+gpu.

    Each model's `runtime` is lower-cased before comparison; a model without
    a `runtime` key is treated as "".
    """
    model_configs = configs[YAMLKeyword.models]
    runtimes = [model_configs[name].get(YAMLKeyword.runtime, "").lower()
                for name in model_configs]
    return RuntimeType.gpu in runtimes or RuntimeType.cpu_gpu in runtimes


186 187 188 189 190 191 192 193 194 195 196
def get_quantize_mode(configs):
    """Return True if at least one model sets `quantize` to 1 (default 0)."""
    model_configs = configs[YAMLKeyword.models]
    return any(model_configs[name].get(YAMLKeyword.quantize, 0) == 1
               for name in model_configs)


197 198
def md5sum(str):
    """Return the hex MD5 digest of the UTF-8 encoding of `str`."""
    # The parameter name shadows the builtin `str`; it is kept unchanged so
    # keyword callers keep working.
    return hashlib.md5(str.encode('utf-8')).hexdigest()
201

Y
yejianwu 已提交
202

203 204 205 206 207 208
def sha256_checksum(fname):
    """Return the hex SHA-256 digest of the file at `fname`.

    The file is read in 4096-byte chunks so large files are never loaded
    into memory at once.
    """
    digest = hashlib.sha256()
    with open(fname, "rb") as stream:
        while True:
            block = stream.read(4096)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
Y
yejianwu 已提交
209

W
wuchenghui 已提交
210

211 212
def format_model_config(flags):
    """Load, validate and normalize the YAML deployment config.

    The command-line flags ``target_abis``, ``target_socs``,
    ``model_graph_format`` and ``model_data_format`` take precedence over the
    matching YAML keys. Every value is checked with ``mace_check``, and
    missing optional keys are filled with defaults so later stages can index
    the config directly.

    Args:
        flags: parsed command-line namespace providing ``config`` (path to the
            YAML file) and the override flags listed above.

    Returns:
        The loaded config dict, normalized in place.

    Raises:
        argparse.ArgumentTypeError: if a subgraph's ``validation_threshold``
            is not a dict or names an unsupported runtime.
    """
    with open(flags.config) as f:
        # The deployment config is plain data, so SafeLoader is sufficient.
        # A bare yaml.load(f) is unsafe on older PyYAML and raises TypeError
        # (missing Loader argument) on PyYAML >= 6.
        configs = yaml.safe_load(f)

    library_name = configs.get(YAMLKeyword.library_name, "")
    mace_check(len(library_name) > 0,
               ModuleName.YAML_CONFIG, "library name should not be empty")

    target_abis = _format_target_abis(configs, flags)
    _format_target_socs(configs, flags, target_abis)
    _format_model_formats(configs, flags)

    model_names = configs.get(YAMLKeyword.models, [])
    mace_check(len(model_names) > 0, ModuleName.YAML_CONFIG,
               "no model found in config file")

    for model_name in model_names:
        _check_model_name(model_name)
        _format_model(configs[YAMLKeyword.models][model_name], target_abis)

    return configs


# Model names become C++ identifiers: letters, digits and '_' only.
_MODEL_NAME_PATTERN = re.compile(r'^[a-zA-Z0-9_]+$')


def _as_list(value):
    """Return `value` unchanged if it is a list, otherwise wrap it in one."""
    return value if isinstance(value, list) else [value]


def _format_target_abis(configs, flags):
    """Resolve (flag overrides YAML), validate and store `target_abis`."""
    if flags.target_abis:
        target_abis = flags.target_abis.split(',')
    else:
        target_abis = configs.get(YAMLKeyword.target_abis, [])
    mace_check((isinstance(target_abis, list) and len(target_abis) > 0),
               ModuleName.YAML_CONFIG, "target_abis list is needed")
    configs[YAMLKeyword.target_abis] = target_abis
    for abi in target_abis:
        mace_check(abi in ABITypeStrs,
                   ModuleName.YAML_CONFIG,
                   "target_abis must be in " + str(ABITypeStrs))
    return target_abis


def _format_target_socs(configs, flags, target_abis):
    """Normalize `target_socs` to a lower-cased list and, for Android ABIs,
    check the requested SoCs against the phones found by adb."""
    target_socs = configs.get(YAMLKeyword.target_socs, "")
    if flags.target_socs:
        configs[YAMLKeyword.target_socs] = \
            [soc.lower() for soc in flags.target_socs.split(',')]
    elif not target_socs:
        configs[YAMLKeyword.target_socs] = []
    elif not isinstance(target_socs, list):
        configs[YAMLKeyword.target_socs] = [target_socs]

    configs[YAMLKeyword.target_socs] = \
        [soc.lower() for soc in configs[YAMLKeyword.target_socs]]

    if ABIType.armeabi_v7a in target_abis \
            or ABIType.arm64_v8a in target_abis:
        available_socs = sh_commands.adb_get_all_socs()
        target_socs = configs[YAMLKeyword.target_socs]
        if ALL_SOC_TAG in target_socs:
            mace_check(available_socs,
                       ModuleName.YAML_CONFIG,
                       "Android abi is listed in config file and "
                       "build for all SOCs plugged in computer, "
                       "But no android phone found, "
                       "you at least plug in one phone")
        else:
            for soc in target_socs:
                mace_check(soc in available_socs,
                           ModuleName.YAML_CONFIG,
                           "Build specified SOC library, "
                           "you must plug in a phone using the SOC")


def _format_model_formats(configs, flags):
    """Resolve (flag overrides YAML) and validate the graph/data formats."""
    if flags.model_graph_format:
        model_graph_format = flags.model_graph_format
    else:
        model_graph_format = configs.get(YAMLKeyword.model_graph_format, "")
    mace_check(model_graph_format in ModelFormatStrs,
               ModuleName.YAML_CONFIG,
               'You must set model_graph_format and '
               "model_graph_format must be in " + str(ModelFormatStrs))
    configs[YAMLKeyword.model_graph_format] = model_graph_format

    if flags.model_data_format:
        model_data_format = flags.model_data_format
    else:
        model_data_format = configs.get(YAMLKeyword.model_data_format, "")
    configs[YAMLKeyword.model_data_format] = model_data_format
    mace_check(model_data_format in ModelFormatStrs,
               ModuleName.YAML_CONFIG,
               'You must set model_data_format and '
               "model_data_format must be in " + str(ModelFormatStrs))

    mace_check(not (model_graph_format == ModelFormat.file
                    and model_data_format == ModelFormat.code),
               ModuleName.YAML_CONFIG,
               "If model_graph format is 'file',"
               " the model_data_format must be 'file' too")


def _check_model_name(model_name):
    """Check that `model_name` is a legal, non-keyword C++ identifier."""
    mace_check(model_name not in CPP_KEYWORDS,
               ModuleName.YAML_CONFIG,
               "model name should not be c++ keyword.")
    # Slice instead of index so an empty name fails the check cleanly
    # rather than raising IndexError.
    first_char = model_name[:1]
    mace_check((first_char == '_' or first_char.isalpha())
               and bool(_MODEL_NAME_PATTERN.match(model_name)),
               ModuleName.YAML_CONFIG,
               "model name should Meet the c++ naming convention"
               " which start with '_' or alpha"
               " and only contain alpha, number and '_'")


def _format_model(model_config, target_abis):
    """Validate one model's config and fill in its defaults in place."""
    platform = model_config.get(YAMLKeyword.platform, "")
    mace_check(platform in PlatformTypeStrs,
               ModuleName.YAML_CONFIG,
               "'platform' must be in " + str(PlatformTypeStrs))

    for key in [YAMLKeyword.model_file_path,
                YAMLKeyword.model_sha256_checksum]:
        value = model_config.get(key, "")
        mace_check(value != "", ModuleName.YAML_CONFIG,
                   "'%s' is necessary" % key)

    weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "")
    if weight_file_path:
        weight_checksum = \
            model_config.get(YAMLKeyword.weight_sha256_checksum, "")
        mace_check(weight_checksum != "", ModuleName.YAML_CONFIG,
                   "'%s' is necessary" %
                   YAMLKeyword.weight_sha256_checksum)
    else:
        model_config[YAMLKeyword.weight_sha256_checksum] = ""

    runtime = model_config.get(YAMLKeyword.runtime, "")
    mace_check(runtime in RuntimeTypeStrs,
               ModuleName.YAML_CONFIG,
               "'runtime' must be in " + str(RuntimeTypeStrs))
    if ABIType.host in target_abis:
        mace_check(runtime == RuntimeType.cpu,
                   ModuleName.YAML_CONFIG,
                   "host only support cpu runtime now.")

    data_type = model_config.get(YAMLKeyword.data_type, "")
    if runtime == RuntimeType.dsp:
        if len(data_type) > 0:
            mace_check(data_type in DSPDataTypeStrs,
                       ModuleName.YAML_CONFIG,
                       "'data_type' must be in " + str(DSPDataTypeStrs)
                       + " for dsp runtime")
        else:
            model_config[YAMLKeyword.data_type] = \
                DSPDataType.uint8.value
    elif len(data_type) > 0:
        # Name the actual runtime; this branch also covers gpu and cpu+gpu.
        mace_check(data_type in FPDataTypeStrs,
                   ModuleName.YAML_CONFIG,
                   "'data_type' must be in " + str(FPDataTypeStrs)
                   + " for %s runtime" % runtime)
    elif runtime == RuntimeType.cpu:
        model_config[YAMLKeyword.data_type] = FPDataType.fp32_fp32.value
    else:
        model_config[YAMLKeyword.data_type] = FPDataType.fp16_fp32.value

    subgraphs = model_config.get(YAMLKeyword.subgraphs, "")
    mace_check(len(subgraphs) > 0, ModuleName.YAML_CONFIG,
               "at least one subgraph is needed")
    for subgraph in subgraphs:
        _format_subgraph(subgraph)

    # Optional integer switches default to 0 when absent or empty.
    for key in [YAMLKeyword.limit_opencl_kernel_time,
                YAMLKeyword.nnlib_graph_mode,
                YAMLKeyword.obfuscate,
                YAMLKeyword.winograd,
                YAMLKeyword.quantize,
                YAMLKeyword.change_concat_ranges]:
        value = model_config.get(key, "")
        if value == "":
            model_config[key] = 0

    mace_check(model_config[YAMLKeyword.winograd] in WinogradParameters,
               ModuleName.YAML_CONFIG,
               "'winograd' parameters must be in "
               + str(WinogradParameters) +
               ". 0 for disable winograd convolution")

    model_config[YAMLKeyword.weight_file_path] = weight_file_path


def _format_subgraph(subgraph):
    """Validate one subgraph entry and normalize its fields in place."""
    # Required tensor names/shapes: always stored as lists of strings.
    for key in [YAMLKeyword.input_tensors,
                YAMLKeyword.input_shapes,
                YAMLKeyword.output_tensors,
                YAMLKeyword.output_shapes]:
        value = subgraph.get(key, "")
        mace_check(value != "", ModuleName.YAML_CONFIG,
                   "'%s' is necessary in subgraph" % key)
        subgraph[key] = [str(v) for v in _as_list(value)]

    # Optional check tensors/shapes: lists of strings, [] when absent.
    for key in [YAMLKeyword.check_tensors,
                YAMLKeyword.check_shapes]:
        value = subgraph.get(key, "")
        if value != "":
            subgraph[key] = [str(v) for v in _as_list(value)]
        else:
            subgraph[key] = []

    input_data_types = subgraph.get(YAMLKeyword.input_data_types, "")
    if input_data_types:
        subgraph[YAMLKeyword.input_data_types] = _as_list(input_data_types)
        for input_data_type in subgraph[YAMLKeyword.input_data_types]:
            mace_check(input_data_type in InputDataTypeStrs,
                       ModuleName.YAML_CONFIG,
                       "'input_data_types' must be in "
                       + str(InputDataTypeStrs))
    else:
        subgraph[YAMLKeyword.input_data_types] = []

    input_data_formats = subgraph.get(YAMLKeyword.input_data_formats, [])
    if input_data_formats:
        # Only an explicit list is length-checked; a scalar is wrapped.
        if isinstance(input_data_formats, list):
            mace_check(len(input_data_formats)
                       == len(subgraph[YAMLKeyword.input_tensors]),
                       ModuleName.YAML_CONFIG,
                       "input_data_formats should match"
                       " the size of input")
        subgraph[YAMLKeyword.input_data_formats] = \
            _as_list(input_data_formats)
        for input_data_format in subgraph[YAMLKeyword.input_data_formats]:
            # str() so a non-string value is reported, not a TypeError.
            mace_check(input_data_format in DataFormatStrs,
                       ModuleName.YAML_CONFIG,
                       "'input_data_formats' must be in "
                       + str(DataFormatStrs) + ", but got "
                       + str(input_data_format))
    else:
        subgraph[YAMLKeyword.input_data_formats] = [DataFormat.NHWC]

    output_data_formats = subgraph.get(YAMLKeyword.output_data_formats, [])
    if output_data_formats:
        if isinstance(output_data_formats, list):
            mace_check(len(output_data_formats)
                       == len(subgraph[YAMLKeyword.output_tensors]),
                       ModuleName.YAML_CONFIG,
                       "output_data_formats should match"
                       " the size of output")
        subgraph[YAMLKeyword.output_data_formats] = \
            _as_list(output_data_formats)
        for output_data_format in subgraph[YAMLKeyword.output_data_formats]:
            mace_check(output_data_format in DataFormatStrs,
                       ModuleName.YAML_CONFIG,
                       "'output_data_formats' must be in "
                       + str(DataFormatStrs) + ", but got "
                       + str(output_data_format))
    else:
        subgraph[YAMLKeyword.output_data_formats] = [DataFormat.NHWC]

    validation_threshold = subgraph.get(
        YAMLKeyword.validation_threshold, {})
    if not isinstance(validation_threshold, dict):
        raise argparse.ArgumentTypeError(
            'similarity threshold must be a dict.')

    # Start from the defaults, then overlay user-provided thresholds.
    threshold_dict = {
        DeviceType.CPU: ValidationThreshold.cpu_threshold,
        DeviceType.GPU: ValidationThreshold.gpu_threshold,
        DeviceType.HEXAGON + "_QUANTIZE":
            ValidationThreshold.hexagon_threshold,
        DeviceType.CPU + "_QUANTIZE":
            ValidationThreshold.cpu_quantize_threshold,
    }
    for k, v in six.iteritems(validation_threshold):
        if k.upper() == 'DSP':
            k = DeviceType.HEXAGON
        if k.upper() not in (DeviceType.CPU,
                             DeviceType.GPU,
                             DeviceType.HEXAGON,
                             DeviceType.CPU + "_QUANTIZE"):
            raise argparse.ArgumentTypeError(
                'Unsupported validation threshold runtime: %s' % k)
        threshold_dict[k.upper()] = v

    subgraph[YAMLKeyword.validation_threshold] = threshold_dict

    subgraph[YAMLKeyword.validation_inputs_data] = _as_list(
        subgraph.get(YAMLKeyword.validation_inputs_data, []))

    subgraph[YAMLKeyword.backend] = subgraph.get(
        YAMLKeyword.backend, "tensorflow")

    subgraph[YAMLKeyword.input_ranges] = [
        str(v) for v in _as_list(subgraph.get(YAMLKeyword.input_ranges, []))]
Y
yejianwu 已提交
507

W
wuchenghui 已提交
508

509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529
def clear_build_dirs(library_name):
    """Prepare the build tree for `library_name`.

    Creates BUILD_OUTPUT_DIR if missing, recreates an empty temporary build
    directory for the library, and deletes the library's previous output
    library directory (without recreating it).
    """
    library_root = os.path.join(BUILD_OUTPUT_DIR, library_name)

    if not os.path.exists(BUILD_OUTPUT_DIR):
        os.makedirs(BUILD_OUTPUT_DIR)

    # Fresh, empty temp dir.
    tmp_dir = os.path.join(library_root, BUILD_TMP_DIR_NAME)
    if os.path.exists(tmp_dir):
        sh.rm('-rf', tmp_dir)
    os.makedirs(tmp_dir)

    # Stale library outputs are removed.
    output_lib_dir = os.path.join(library_root, OUTPUT_LIBRARY_DIR_NAME)
    if os.path.exists(output_lib_dir):
        sh.rm('-rf', output_lib_dir)


################################
# convert
################################
def print_configuration(configs):
    """Log a key/value table of the library-wide configuration entries."""
    shown_keys = [YAMLKeyword.library_name,
                  YAMLKeyword.target_abis,
                  YAMLKeyword.target_socs,
                  YAMLKeyword.model_graph_format,
                  YAMLKeyword.model_data_format]
    rows = [[key, configs[key]] for key in shown_keys]
    MaceLogger.summary(StringFormatter.table(["key", "value"], rows,
                                             "Common Configuration"))
L
Liangliang He 已提交
544

Y
yejianwu 已提交
545

546 547 548 549
def download_file(url, dst, num_retries=3):
    """Download `url` to the local path `dst`.

    On ContentTooShortError, HTTPError or URLError a warning is logged and
    the download is retried, up to `num_retries` additional attempts.

    Returns:
        True once a download succeeds, False when every attempt failed.
    """
    from six.moves import urllib

    retries_left = num_retries
    while True:
        try:
            urllib.request.urlretrieve(url, dst)
            MaceLogger.info('\nDownloaded successfully.')
        except (urllib.error.ContentTooShortError, urllib.error.HTTPError,
                urllib.error.URLError) as err:
            MaceLogger.warning('Download error:' + str(err))
            if retries_left <= 0:
                return False
            retries_left -= 1
            continue
        return True


B
Bin Li 已提交
562 563 564 565 566 567 568
def get_model_files(model_file_path,
                    model_sha256_checksum,
                    model_output_dir,
                    weight_file_path="",
                    weight_sha256_checksum=""):
    """Return local paths for the model (and optional weight) file.

    An http(s) path is cached under `model_output_dir` as
    ``<md5 of URL>.pb`` (model) or ``<md5 of URL>.caffemodel`` (weights). It
    is downloaded only when no cached copy with the expected SHA-256 exists.
    The resulting files' SHA-256 checksums are then verified, and mismatches
    are reported through MaceLogger.error.

    Returns:
        (model_file, weight_file); weight_file is "" when no weight path was
        given.
    """
    def is_remote(path):
        return path.startswith(("http://", "https://"))

    def fetch(url, local_path, expected_checksum, banner):
        # Reuse a cached download if its checksum already matches.
        if not os.path.exists(local_path) or \
                sha256_checksum(local_path) != expected_checksum:
            MaceLogger.info(banner)
            if not download_file(url, local_path):
                MaceLogger.error(ModuleName.MODEL_CONVERTER,
                                 "Model download failed.")

    model_file = model_file_path
    if is_remote(model_file_path):
        model_file = model_output_dir + "/" + md5sum(model_file_path) + ".pb"
        fetch(model_file_path, model_file, model_sha256_checksum,
              "Downloading model, please wait ...")
    if sha256_checksum(model_file) != model_sha256_checksum:
        MaceLogger.error(ModuleName.MODEL_CONVERTER,
                         "model file sha256checksum not match")

    weight_file = weight_file_path
    if is_remote(weight_file_path):
        weight_file = \
            model_output_dir + "/" + md5sum(weight_file_path) + ".caffemodel"
        fetch(weight_file_path, weight_file, weight_sha256_checksum,
              "Downloading model weight, please wait ...")
    if weight_file and \
            sha256_checksum(weight_file) != weight_sha256_checksum:
        MaceLogger.error(ModuleName.MODEL_CONVERTER,
                         "weight file sha256checksum not match")

    return model_file, weight_file
L
Liangliang He 已提交
601

L
liuqi 已提交
602

603
def convert_model(configs, cl_mem_type):
    """Convert every model in `configs` and lay out the generated artifacts.

    Steps:
    - Recreate BUILD_OUTPUT_DIR/<library_name> and its model output dir.
    - Clear the model/engine codegen dirs.
    - Run sh_commands.gen_model_code per model.
    - Move or copy the generated .pb/.data files and headers into the
      library's output/header dirs, according to model_graph_format and
      model_data_format.

    Args:
        configs: config dict as returned by format_model_config.
        cl_mem_type: OpenCL memory type recorded on each model; "image" is
            used when it is falsy.
    """
    def remove_if_exists(path):
        if os.path.exists(path):
            sh.rm("-rf", path)

    library_name = configs[YAMLKeyword.library_name]
    library_dir = os.path.join(BUILD_OUTPUT_DIR, library_name)
    # Start from an empty library directory.
    if not os.path.exists(BUILD_OUTPUT_DIR):
        os.makedirs(BUILD_OUTPUT_DIR)
    else:
        remove_if_exists(library_dir)
    os.makedirs(library_dir)

    if not os.path.exists(BUILD_DOWNLOADS_DIR):
        os.makedirs(BUILD_DOWNLOADS_DIR)

    model_output_dir = \
        '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_OUTPUT_DIR_NAME)
    model_header_dir = \
        '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_HEADER_DIR_PATH)
    remove_if_exists(model_output_dir)
    os.makedirs(model_output_dir)
    remove_if_exists(model_header_dir)

    graph_format = configs[YAMLKeyword.model_graph_format]
    graph_as_code = graph_format == ModelFormat.code
    graph_as_file = graph_format == ModelFormat.file
    embed_model_data = \
        configs[YAMLKeyword.model_data_format] == ModelFormat.code

    remove_if_exists(MODEL_CODEGEN_DIR)
    remove_if_exists(ENGINE_CODEGEN_DIR)

    if graph_as_code:
        os.makedirs(model_header_dir)
        sh_commands.gen_mace_engine_factory_source(
            configs[YAMLKeyword.models].keys(),
            embed_model_data)
        sh.cp("-f", glob.glob("mace/codegen/engine/*.h"),
              model_header_dir)

    for model_name, model_config in six.iteritems(
            configs[YAMLKeyword.models]):
        MaceLogger.header(
            StringFormatter.block("Convert %s model" % model_name))
        runtime = model_config[YAMLKeyword.runtime]
        model_config[YAMLKeyword.cl_mem_type] = cl_mem_type or "image"

        model_file_path, weight_file_path = get_model_files(
            model_config[YAMLKeyword.model_file_path],
            model_config[YAMLKeyword.model_sha256_checksum],
            BUILD_DOWNLOADS_DIR,
            model_config[YAMLKeyword.weight_file_path],
            model_config[YAMLKeyword.weight_sha256_checksum])

        data_type = model_config[YAMLKeyword.data_type]
        # TODO(liuqi): support multiple subgraphs
        graph = model_config[YAMLKeyword.subgraphs][0]

        model_codegen_dir = "%s/%s" % (MODEL_CODEGEN_DIR, model_name)
        sh_commands.gen_model_code(
            model_codegen_dir,
            model_config[YAMLKeyword.platform],
            model_file_path,
            weight_file_path,
            model_config[YAMLKeyword.model_sha256_checksum],
            model_config[YAMLKeyword.weight_sha256_checksum],
            ",".join(graph[YAMLKeyword.input_tensors]),
            ",".join(graph[YAMLKeyword.input_data_formats]),
            ",".join(graph[YAMLKeyword.output_tensors]),
            ",".join(graph[YAMLKeyword.output_data_formats]),
            ",".join(graph[YAMLKeyword.check_tensors]),
            runtime,
            model_name,
            ":".join(graph[YAMLKeyword.input_shapes]),
            ":".join(graph[YAMLKeyword.input_ranges]),
            ":".join(graph[YAMLKeyword.output_shapes]),
            ":".join(graph[YAMLKeyword.check_shapes]),
            model_config[YAMLKeyword.nnlib_graph_mode],
            embed_model_data,
            model_config[YAMLKeyword.winograd],
            model_config[YAMLKeyword.quantize],
            model_config.get(YAMLKeyword.quantize_range_file, ""),
            model_config[YAMLKeyword.change_concat_ranges],
            model_config[YAMLKeyword.obfuscate],
            graph_format,
            data_type,
            model_config[YAMLKeyword.cl_mem_type],
            ",".join(model_config.get(YAMLKeyword.graph_optimize_options, [])))

        if graph_as_file:
            # Both the serialized graph and its data ship as files.
            for suffix in ("pb", "data"):
                sh.mv("-f",
                      '%s/%s.%s' % (model_codegen_dir, model_name, suffix),
                      model_output_dir)
        else:
            if not embed_model_data:
                sh.mv("-f",
                      '%s/%s.data' % (model_codegen_dir, model_name),
                      model_output_dir)
            sh.cp("-f", glob.glob("mace/codegen/models/*/*.h"),
                  model_header_dir)

        MaceLogger.summary(
            StringFormatter.block("Model %s converted" % model_name))


712 713
def build_model_lib(configs, address_sanitizer):
    """Build the model library target once per configured target ABI.

    For each ABI in ``configs[YAMLKeyword.target_abis]`` the output directory
    of the model library is created if missing, MODEL_LIB_TARGET is built via
    bazel with hidden symbols, and the artifact at MODEL_LIB_PATH is copied to
    the per-ABI output path.

    Args:
        configs: the model configuration dict (from format_model_config).
        address_sanitizer: whether the build enables address sanitizer.
    """
    MaceLogger.header(StringFormatter.block("Building model library"))

    lib_name = configs[YAMLKeyword.library_name]
    for abi in configs[YAMLKeyword.target_abis]:
        dsp_mode = get_hexagon_mode(configs)

        # Ensure the directory that will receive the library exists.
        output_path = get_model_lib_output_path(lib_name, abi)
        output_dir = os.path.dirname(output_path)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        build_options = dict(
            abi=abi,
            toolchain=infer_toolchain(abi),
            hexagon_mode=dsp_mode,
            enable_opencl=get_opencl_mode(configs),
            enable_quantize=get_quantize_mode(configs),
            address_sanitizer=address_sanitizer,
            symbol_hidden=True,
        )
        sh_commands.bazel_build(MODEL_LIB_TARGET, **build_options)

        sh.cp("-f", MODEL_LIB_PATH, output_path)
737 738 739 740 741 742 743


def print_library_summary(configs):
    """Log a "Library" table with the locations of the generated model.

    The model header directory row is only listed when the model graph
    format is ``code``.
    """
    lib_dir = "%s/%s" % (BUILD_OUTPUT_DIR,
                         configs[YAMLKeyword.library_name])
    rows = [["MACE Model Path",
             "%s/%s" % (lib_dir, MODEL_OUTPUT_DIR_NAME)]]
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        rows.append(["MACE Model Header Path",
                     "%s/%s" % (lib_dir, MODEL_HEADER_DIR_PATH)])

    summary_table = StringFormatter.table(["key", "value"], rows, "Library")
    MaceLogger.summary(summary_table)


756
def convert_func(flags):
    """Handler of the ``convert`` sub-command.

    Loads and prints the model configuration, converts the model(s) with the
    requested OpenCL memory type, builds the model library when the graph is
    emitted as code, and finally prints the library summary.
    """
    configs = format_model_config(flags)
    print_configuration(configs)

    convert_model(configs, flags.cl_mem_type)

    graph_as_code = (configs[YAMLKeyword.model_graph_format] ==
                     ModelFormat.code)
    if graph_as_code:
        build_model_lib(configs, flags.address_sanitizer)

    print_library_summary(configs)


################################
# run
################################
def report_run_statistics(stdout,
                          abi,
                          serialno,
                          model_name,
                          device_type,
                          output_dir,
                          tuned):
    """Append one row of run statistics to ``<output_dir>/report.csv``.

    Timings come from the first line of ``stdout`` that splits into exactly
    four whitespace-separated fields and whose first field starts with
    ``time``. Fields 2-4 are reported as init, warmup and average run time
    (each normalised through ``float``). When no such line exists, all three
    are reported as ``0``.

    For non-host ABIs the device model and board platform are read from the
    device's properties via adb. The CSV header is written only when the
    report file does not exist yet.
    """
    init_ms = warmup_ms = run_avg_ms = 0
    for raw_line in stdout.split('\n'):
        fields = raw_line.split()
        if len(fields) != 4 or not fields[0].startswith("time"):
            continue
        init_ms, warmup_ms, run_avg_ms = [str(float(x)) for x in fields[1:]]
        break

    device_name = target_soc = ""
    if abi != "host":
        props = sh_commands.adb_getprop_by_serialno(serialno)
        device_name = props.get("ro.product.model", "")
        target_soc = props.get("ro.board.platform", "")

    report_path = output_dir + "/report.csv"
    if not os.path.exists(report_path):
        with open(report_path, 'w') as report:
            report.write("model_name,device_name,soc,abi,runtime,"
                         "init(ms),warmup(ms),run_avg(ms),tuned\n")

    # Render each value via format() (not str()) to match "{...}".format.
    columns = (model_name, device_name, target_soc, abi, device_type,
               init_ms, warmup_ms, run_avg_ms, tuned)
    row = ",".join("{0}".format(col) for col in columns) + "\n"
    with open(report_path, 'a') as report:
        report.write(row)


L
liuqi 已提交
817 818
def build_mace_run(configs, target_abi, toolchain, enable_openmp,
                   address_sanitizer, mace_lib_type):
    """Build the ``mace_run`` tool for ``target_abi`` and stage its binary.

    The per-library temporary binary directory is recreated from scratch.
    The static or dynamic mace_run target is then built, with symbols hidden
    only for the static flavour. Finally update_mace_run_binary is invoked on
    the staging directory.

    When the model graph is emitted as code, the engine codegen directory
    must already exist, i.e. ``convert`` must have been run first.
    """
    lib_name = configs[YAMLKeyword.library_name]
    dsp_mode = get_hexagon_mode(configs)

    binary_dir = get_build_binary_dir(lib_name, target_abi)
    if os.path.exists(binary_dir):
        sh.rm("-rf", binary_dir)
    os.makedirs(binary_dir)

    is_dynamic = mace_lib_type == MACELibType.dynamic
    if is_dynamic:
        target, hide_symbols = MACE_RUN_DYNAMIC_TARGET, False
    else:
        target, hide_symbols = MACE_RUN_STATIC_TARGET, True

    extra_args = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                   ModuleName.RUN,
                   "You should convert model first.")
        extra_args = "--per_file_copt=mace/tools/validation/mace_run.cc@-DMODEL_GRAPH_FORMAT_CODE"  # noqa

    sh_commands.bazel_build(
        target,
        abi=target_abi,
        toolchain=toolchain,
        hexagon_mode=dsp_mode,
        enable_openmp=enable_openmp,
        enable_opencl=get_opencl_mode(configs),
        enable_quantize=get_quantize_mode(configs),
        address_sanitizer=address_sanitizer,
        symbol_hidden=hide_symbols,
        extra_args=extra_args,
    )
    sh_commands.update_mace_run_binary(binary_dir, is_dynamic)


L
liuqi 已提交
855 856
def build_example(configs, target_abi, toolchain,
                  enable_openmp, mace_lib_type, address_sanitizer=False):
    """Build the CLI example binary for ``target_abi`` and stage it.

    Steps:
    1. Build libmace (static or shared, per ``mace_lib_type``).
    2. Stage libmace in LIB_CODEGEN_DIR, together with the model library
       when the graph is emitted as code.
    3. Build the matching example target.
    4. Copy the resulting binary into the per-library temporary binary
       directory and remove LIB_CODEGEN_DIR afterwards.

    Args:
        configs: the model configuration dict (from format_model_config).
        target_abi: ABI to build for.
        toolchain: toolchain name passed to bazel.
        enable_openmp: whether OpenMP is enabled in the build.
        mace_lib_type: MACELibType.dynamic or MACELibType.static.
        address_sanitizer: whether to build with address sanitizer. This
            used to be read from a module-level ``flags`` object, which only
            exists when this file is executed as a script.
    """
    library_name = configs[YAMLKeyword.library_name]
    hexagon_mode = get_hexagon_mode(configs)

    build_tmp_binary_dir = get_build_binary_dir(library_name, target_abi)
    if os.path.exists(build_tmp_binary_dir):
        sh.rm("-rf", build_tmp_binary_dir)
    os.makedirs(build_tmp_binary_dir)

    # Symbols are only hidden for the static library flavour.
    if mace_lib_type == MACELibType.dynamic:
        libmace_target = LIBMACE_SO_TARGET
        symbol_hidden = False
    else:
        libmace_target = LIBMACE_STATIC_TARGET
        symbol_hidden = True

    sh_commands.bazel_build(libmace_target,
                            abi=target_abi,
                            toolchain=toolchain,
                            enable_openmp=enable_openmp,
                            enable_opencl=get_opencl_mode(configs),
                            enable_quantize=get_quantize_mode(configs),
                            hexagon_mode=hexagon_mode,
                            address_sanitizer=address_sanitizer,
                            symbol_hidden=symbol_hidden)

    # LIB_CODEGEN_DIR is a scratch directory holding the libraries the
    # example target is built against; start from an empty one.
    if os.path.exists(LIB_CODEGEN_DIR):
        sh.rm("-rf", LIB_CODEGEN_DIR)
    sh.mkdir("-p", LIB_CODEGEN_DIR)

    build_arg = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                   ModuleName.RUN,
                   "You should convert model first.")
        model_lib_path = get_model_lib_output_path(library_name,
                                                   target_abi)
        sh.cp("-f", model_lib_path, LIB_CODEGEN_DIR)
        build_arg = "--per_file_copt=mace/examples/cli/example.cc@-DMODEL_GRAPH_FORMAT_CODE"  # noqa

    if mace_lib_type == MACELibType.dynamic:
        example_target = EXAMPLE_DYNAMIC_TARGET
        sh.cp("-f", LIBMACE_DYNAMIC_PATH, LIB_CODEGEN_DIR)
    else:
        example_target = EXAMPLE_STATIC_TARGET
        sh.cp("-f", LIBMACE_STATIC_PATH, LIB_CODEGEN_DIR)

    sh_commands.bazel_build(example_target,
                            abi=target_abi,
                            toolchain=toolchain,
                            enable_openmp=enable_openmp,
                            enable_opencl=get_opencl_mode(configs),
                            enable_quantize=get_quantize_mode(configs),
                            hexagon_mode=hexagon_mode,
                            address_sanitizer=address_sanitizer,
                            extra_args=build_arg)

    target_bin = "/".join(sh_commands.bazel_target_to_bin(example_target))
    sh.cp("-f", target_bin, build_tmp_binary_dir)
    if os.path.exists(LIB_CODEGEN_DIR):
        sh.rm("-rf", LIB_CODEGEN_DIR)


def print_package_summary(package_path):
    """Log a "Library" table containing the path of the packaged output."""
    title = "Library"
    header = ["key", "value"]
    data = list()
    data.append(["MACE Model package Path",
                 package_path])

    MaceLogger.summary(StringFormatter.table(header, data, title))


def run_mace(flags):
    """Handler of the ``run`` sub-command.

    For each configured ABI and each selected device supporting it:
    builds either the example binary (``--example``) or mace_run, then runs
    it on that device while holding the device lock. Afterwards the build
    output directory is packaged and its path is printed.
    """
    configs = format_model_config(flags)

    clear_build_dirs(configs[YAMLKeyword.library_name])

    target_socs = configs[YAMLKeyword.target_socs]
    device_list = DeviceManager.list_devices(flags.device_yml)
    # An empty soc list, or one containing ALL_SOC_TAG, keeps every device.
    if target_socs and ALL_SOC_TAG not in target_socs:
        device_list = [dev for dev in device_list
                       if dev[YAMLKeyword.target_socs].lower() in target_socs]
    for target_abi in configs[YAMLKeyword.target_abis]:
        # build target
        for dev in device_list:
            if target_abi in dev[YAMLKeyword.target_abis]:
                # get toolchain
                toolchain = infer_toolchain(target_abi)
                if flags.example:
                    build_example(configs,
                                  target_abi,
                                  toolchain,
                                  not flags.disable_openmp,
                                  flags.mace_lib_type,
                                  address_sanitizer=flags.address_sanitizer)
                else:
                    build_mace_run(configs,
                                   target_abi,
                                   toolchain,
                                   not flags.disable_openmp,
                                   flags.address_sanitizer,
                                   flags.mace_lib_type)
                # run
                device = DeviceWrapper(dev)
                with device.lock():
                    device.run_specify_abi(flags, configs, target_abi)
            elif dev[YAMLKeyword.device_name] != SystemType.host:
                six.print_('The device with soc %s do not support abi %s' %
                           (dev[YAMLKeyword.target_socs], target_abi),
                           file=sys.stderr)

    # package the output files
    package_path = sh_commands.packaging_lib(BUILD_OUTPUT_DIR,
                                             configs[YAMLKeyword.library_name])
    print_package_summary(package_path)

972 973 974 975

################################
#  benchmark model
################################
def build_benchmark_model(configs,
                          target_abi,
                          toolchain,
                          enable_openmp,
                          mace_lib_type):
    """Build the benchmark_model tool for ``target_abi`` and stage it.

    The static or dynamic benchmark target is built, with symbols hidden only
    for the static flavour. The per-library temporary binary directory is
    then recreated and the built binary is copied into it.

    When the model graph is emitted as code, the engine codegen directory
    must already exist, i.e. ``convert`` must have been run first.
    """
    lib_name = configs[YAMLKeyword.library_name]
    dsp_mode = get_hexagon_mode(configs)

    if mace_lib_type == MACELibType.dynamic:
        target, hide_symbols = BM_MODEL_DYNAMIC_TARGET, False
    else:
        target, hide_symbols = BM_MODEL_STATIC_TARGET, True

    extra_args = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                   ModuleName.BENCHMARK,
                   "You should convert model first.")
        extra_args = "--per_file_copt=mace/benchmark/benchmark_model.cc@-DMODEL_GRAPH_FORMAT_CODE"  # noqa

    sh_commands.bazel_build(target,
                            abi=target_abi,
                            toolchain=toolchain,
                            enable_openmp=enable_openmp,
                            enable_opencl=get_opencl_mode(configs),
                            enable_quantize=get_quantize_mode(configs),
                            hexagon_mode=dsp_mode,
                            symbol_hidden=hide_symbols,
                            extra_args=extra_args)

    # Start from an empty staging directory, then copy the binary in.
    staging_dir = get_build_binary_dir(lib_name, target_abi)
    if os.path.exists(staging_dir):
        sh.rm("-rf", staging_dir)
    os.makedirs(staging_dir)

    binary_path = "/".join(sh_commands.bazel_target_to_bin(target))
    sh.cp("-f", binary_path, staging_dir)


1018
def benchmark_model(flags):
    """Handler of the ``benchmark`` sub-command.

    For each configured ABI, builds benchmark_model for every selected device
    that supports the ABI and benchmarks on it while holding the device
    lock. Devices are filtered by ``target_socs`` unless that list is empty
    or contains ALL_SOC_TAG.
    """
    configs = format_model_config(flags)

    clear_build_dirs(configs[YAMLKeyword.library_name])

    target_socs = configs[YAMLKeyword.target_socs]
    device_list = DeviceManager.list_devices(flags.device_yml)
    if target_socs and ALL_SOC_TAG not in target_socs:
        device_list = [dev for dev in device_list
                       if dev[YAMLKeyword.target_socs].lower() in target_socs]

    for target_abi in configs[YAMLKeyword.target_abis]:
        # build benchmark_model binary
        for dev in device_list:
            if target_abi in dev[YAMLKeyword.target_abis]:
                toolchain = infer_toolchain(target_abi)
                build_benchmark_model(configs,
                                      target_abi,
                                      toolchain,
                                      not flags.disable_openmp,
                                      flags.mace_lib_type)
                device = DeviceWrapper(dev)
                with device.lock():
                    device.bm_specific_target(flags, configs, target_abi)
            elif dev[YAMLKeyword.device_name] != SystemType.host:
                # Consistent with run_mace: no ABI-mismatch warning is
                # printed for the host device.
                six.print_('There is no abi %s with soc %s' %
                           (target_abi, dev[YAMLKeyword.target_socs]),
                           file=sys.stderr)
L
liuqi 已提交
1046

1047

L
liuqi 已提交
1048
################################
# parsing arguments
################################
def str2bool(v):
    """Convert a command-line string to a bool (case-insensitive).

    ``yes/true/t/y/1`` map to True and ``no/false/f/n/0`` map to False.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def str_to_caffe_env_type(v):
    """Map ``docker``/``local`` (case-insensitive) to a CaffeEnvType member.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    choice = v.lower()
    if choice == 'docker':
        return CaffeEnvType.DOCKER
    if choice == 'local':
        return CaffeEnvType.LOCAL
    raise argparse.ArgumentTypeError('[docker | local] expected.')


1069 1070 1071 1072 1073 1074 1075 1076 1077
def str_to_mace_lib_type(v):
    """Map ``dynamic``/``static`` (case-insensitive) to a MACELibType member.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    choice = v.lower()
    if choice == 'dynamic':
        return MACELibType.dynamic
    if choice == 'static':
        return MACELibType.static
    raise argparse.ArgumentTypeError('[dynamic| static] expected.')


1078
def parse_args():
    """Parse command line arguments.

    Builds a parser with three sub-commands -- ``convert``, ``run`` and
    ``benchmark`` -- each of which stores its handler function in ``func``.
    A sub-command is mandatory.

    Returns:
        The ``(namespace, unknown_args)`` pair produced by
        ``ArgumentParser.parse_known_args``.
    """
    # Options shared by every sub-command.
    all_type_parent_parser = argparse.ArgumentParser(add_help=False)
    all_type_parent_parser.add_argument(
        '--config',
        type=str,
        default="",
        required=True,
        help="the path of model yaml configuration file.")
    all_type_parent_parser.add_argument(
        "--model_graph_format",
        type=str,
        default="",
        help="[file, code], MACE Model graph format.")
    all_type_parent_parser.add_argument(
        "--model_data_format",
        type=str,
        default="",
        help="['file', 'code'], MACE Model data format.")
    all_type_parent_parser.add_argument(
        "--target_abis",
        type=str,
        default="",
        help="Target ABIs, comma separated list.")
    all_type_parent_parser.add_argument(
        "--target_socs",
        type=str,
        default="",
        help="Target SOCs, comma separated list.")

    # Options shared by ``convert`` and ``run``.
    convert_run_parent_parser = argparse.ArgumentParser(add_help=False)
    convert_run_parent_parser.add_argument(
        '--address_sanitizer',
        action="store_true",
        help="Whether to use address sanitizer to check memory error")

    # Options shared by ``run`` and ``benchmark``.
    run_bm_parent_parser = argparse.ArgumentParser(add_help=False)
    run_bm_parent_parser.add_argument(
        "--mace_lib_type",
        type=str_to_mace_lib_type,
        default=DefaultValues.mace_lib_type,
        help="[static | dynamic], Which type MACE library to use.")
    run_bm_parent_parser.add_argument(
        "--disable_openmp",
        action="store_true",
        help="Disable openmp for multiple thread.")
    run_bm_parent_parser.add_argument(
        "--omp_num_threads",
        type=int,
        default=DefaultValues.omp_num_threads,
        help="num of openmp threads")
    run_bm_parent_parser.add_argument(
        "--cpu_affinity_policy",
        type=int,
        default=DefaultValues.cpu_affinity_policy,
        help="0:AFFINITY_NONE/1:AFFINITY_BIG_ONLY/2:AFFINITY_LITTLE_ONLY")
    run_bm_parent_parser.add_argument(
        "--gpu_perf_hint",
        type=int,
        default=DefaultValues.gpu_perf_hint,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    run_bm_parent_parser.add_argument(
        "--gpu_priority_hint",
        type=int,
        default=DefaultValues.gpu_priority_hint,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    run_bm_parent_parser.add_argument(
        "--device_yml",
        type=str,
        default='',
        help='embedded linux device config yml file'
    )

    parser = argparse.ArgumentParser()
    # On Python 3 sub-commands are optional by default; omitting one would
    # leave ``func`` unset and the caller would crash on ``flags.func``.
    # A ``dest`` is required for argparse to format the resulting error.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    convert = subparsers.add_parser(
        'convert',
        parents=[all_type_parent_parser, convert_run_parent_parser],
        help='convert to mace model (file or code)')
    convert.add_argument(
        "--cl_mem_type",
        type=str,
        default=None,
        help="Which type of OpenCL memory type to use [image | buffer].")
    convert.set_defaults(func=convert_func)

    run = subparsers.add_parser(
        'run',
        parents=[all_type_parent_parser, run_bm_parent_parser,
                 convert_run_parent_parser],
        help='run model in command line')
    run.set_defaults(func=run_mace)
    run.add_argument(
        "--disable_tuning",
        action="store_true",
        help="Disable tuning for specific thread.")
    run.add_argument(
        "--round",
        type=int,
        default=1,
        help="The model running round.")
    run.add_argument(
        "--validate",
        action="store_true",
        help="whether to verify the results are consistent with "
             "the frameworks.")
    run.add_argument(
        "--caffe_env",
        type=str_to_caffe_env_type,
        default='docker',
        help="[docker | local] specify the caffe environment used for"
             " validation: local environment or caffe docker image.")
    run.add_argument(
        "--vlog_level",
        type=int,
        default=0,
        help="[1~5]. Verbose log level for debug.")
    run.add_argument(
        "--gpu_out_of_range_check",
        action="store_true",
        help="Enable out of range check for gpu.")
    run.add_argument(
        "--restart_round",
        type=int,
        default=1,
        help="restart round between run.")
    run.add_argument(
        "--report",
        action="store_true",
        help="print run statistics report.")
    run.add_argument(
        "--report_dir",
        type=str,
        default="",
        help="directory to save the run statistics report.")
    run.add_argument(
        "--runtime_failure_ratio",
        type=float,
        default=0.0,
        help="[mock runtime failure ratio].")
    run.add_argument(
        "--example",
        action="store_true",
        help="whether to run example.")
    run.add_argument(
        "--quantize_stat",
        action="store_true",
        help="whether to stat quantization range.")
    run.add_argument(
        "--input_dir",
        type=str,
        default="",
        help="quantize stat input dir.")
    run.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="quantize stat output dir.")

    benchmark = subparsers.add_parser(
        'benchmark',
        parents=[all_type_parent_parser, run_bm_parent_parser],
        help='benchmark model for detail information')
    benchmark.set_defaults(func=benchmark_model)

    return parser.parse_known_args()

1239

Y
yejianwu 已提交
1240
if __name__ == "__main__":
    flags, unparsed = parse_args()
    # Dispatch to the handler each sub-command registered through
    # ``set_defaults(func=...)`` in parse_args.
    run_subcommand = flags.func
    run_subcommand(flags)