converter.py 46.0 KB
Newer Older
Y
yejianwu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright 2018 Xiaomi, Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15

import argparse
L
liuqi 已提交
16
import glob
17
import hashlib
18
import os
L
liuqi 已提交
19
import re
L
Liangliang He 已提交
20
import sh
21
import sys
22
import urllib
Y
yejianwu 已提交
23
import yaml
L
liuqi 已提交
24

25
from enum import Enum
26
import six
27

28
import sh_commands
L
Liangliang He 已提交
29

L
liuqi 已提交
30 31
from common import *
from device import DeviceWrapper, DeviceManager
32

Y
yejianwu 已提交
33 34 35 36
################################
# set environment
################################
# Presumably quiets TensorFlow's C++ logging during model conversion.
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '2'})
A
Allen 已提交
37

38 39 40 41 42
################################
# common definitions
################################

# ABIs accepted in ``target_abis`` (config file or --target_abis).
ABITypeStrs = ['armeabi-v7a', 'arm64-v8a', 'arm64', 'armhf', 'host']

# Accepted values for ``model_graph_format`` / ``model_data_format``.
ModelFormatStrs = ["file", "code"]

# Source frameworks accepted in a model's ``platform`` entry.
PlatformTypeStrs = ["tensorflow", "caffe", "onnx"]
# Each member's name equals its string value.
PlatformType = Enum('PlatformType',
                    list(zip(PlatformTypeStrs, PlatformTypeStrs)),
                    type=str)

# Runtimes accepted in a model's ``runtime`` entry.
RuntimeTypeStrs = ["cpu", "gpu", "dsp", "cpu+gpu"]

Y
yejianwu 已提交
70 71 72 73 74 75 76 77 78
# Element types accepted in a subgraph's ``input_data_types`` entry.
InputDataTypeStrs = [
    "int32",
    "float32",
]

InputDataType = Enum('InputDataType',
                     [(ele, ele) for ele in InputDataTypeStrs],
                     type=str)

# ``data_type`` values accepted for non-DSP runtimes.
FPDataTypeStrs = [
    "fp16_fp32",
    "fp32_fp32",
]

# The Enum's class name must match the module-level name it is bound to:
# pickle looks enum classes up by name in their module. A mismatched
# name (this used to be 'GPUDataType') also makes repr() misleading.
FPDataType = Enum('FPDataType', [(ele, ele) for ele in FPDataTypeStrs],
                  type=str)

# ``data_type`` values accepted for the DSP runtime.
DSPDataTypeStrs = [
    "uint8",
]

DSPDataType = Enum('DSPDataType', [(ele, ele) for ele in DSPDataTypeStrs],
                   type=str)

94 95
# Allowed values of a model's 'winograd' option; 0 disables winograd
# convolution.
WinogradParameters = [0, 2, 4]


class DataFormat(object):
    """Names of the data formats accepted in ``input_data_formats`` /
    ``output_data_formats`` subgraph entries."""
    NONE = "NONE"
    NHWC = "NHWC"


# Every accepted data-format name, in declaration order.
DataFormatStrs = [DataFormat.NONE, DataFormat.NHWC]

106 107

class DefaultValues(object):
    """Default settings shared by the tool's commands.

    NOTE(review): the consumers of these values are outside this view;
    presumably they back command-line option defaults - confirm.
    """
    mace_lib_type = MACELibType.static
    # Plain scalars on purpose: a trailing comma (``x = -1,``) would bind
    # the one-element tuple ``(-1,)`` instead of the number ``-1``.
    omp_num_threads = -1
    cpu_affinity_policy = 1
    gpu_perf_hint = 3
    gpu_priority_hint = 3


115 116 117 118 119 120 121
class ValidationThreshold(object):
    """Default per-device similarity thresholds for model validation.

    ``format_model_config`` seeds each subgraph's threshold table from these
    values; the config file may override them per device.
    """
    # Plain floats: a trailing comma (``x = 0.999,``) would make each value
    # a one-element tuple rather than a number.
    cpu_threshold = 0.999
    gpu_threshold = 0.995
    hexagon_threshold = 0.930
    cpu_quantize_threshold = 0.980


122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
# Words a model name must not equal, since model names are checked against
# C++ naming rules. The first group is C++ keywords (including TS/TM
# extensions); the second is preprocessor words. The list intentionally
# keeps 'if' and 'else' twice, once per group.
CPP_KEYWORDS = """
    alignas alignof and and_eq asm atomic_cancel
    atomic_commit atomic_noexcept auto bitand bitor
    bool break case catch char char16_t char32_t
    class compl concept const constexpr const_cast
    continue co_await co_return co_yield decltype default
    delete do double dynamic_cast else enum explicit
    export extern false float for friend goto if
    import inline int long module mutable namespace
    new noexcept not not_eq nullptr operator or or_eq
    private protected public register reinterpret_cast
    requires return short signed sizeof static
    static_assert static_cast struct switch synchronized
    template this thread_local throw true try typedef
    typeid typename union unsigned using virtual void
    volatile wchar_t while xor xor_eq override final
    transaction_safe transaction_safe_dynamic
""".split() + """
    if elif else
    endif defined ifdef ifndef define undef include
    line error pragma
""".split()
Y
yejianwu 已提交
142

143

144 145 146
################################
# common functions
################################
147
def parse_device_type(runtime):
    """Map a runtime name to its DeviceType.

    dsp -> HEXAGON, gpu -> GPU, cpu -> CPU; any other runtime (e.g.
    cpu+gpu) yields the empty string.
    """
    if runtime == RuntimeType.dsp:
        return DeviceType.HEXAGON
    if runtime == RuntimeType.gpu:
        return DeviceType.GPU
    if runtime == RuntimeType.cpu:
        return DeviceType.CPU
    return ""
158

Y
yejianwu 已提交
159 160

def get_hexagon_mode(configs):
    """Return True if any configured model uses the DSP runtime.

    Runtime names are compared lower-cased; a model without a ``runtime``
    entry counts as "".
    """
    models = configs[YAMLKeyword.models]
    runtimes = [models[name].get(YAMLKeyword.runtime, "").lower()
                for name in models]
    return RuntimeType.dsp in runtimes


Y
yejianwu 已提交
173 174 175
def get_opencl_mode(configs):
    """Return True if any configured model runs on gpu or cpu+gpu.

    Runtime names are compared lower-cased; a model without a ``runtime``
    entry counts as "".
    """
    models = configs[YAMLKeyword.models]
    runtimes = [models[name].get(YAMLKeyword.runtime, "").lower()
                for name in models]
    return RuntimeType.gpu in runtimes or RuntimeType.cpu_gpu in runtimes


186 187 188 189 190 191 192 193 194 195 196
def get_quantize_mode(configs):
    """Return True if at least one model sets ``quantize`` to 1.

    A missing ``quantize`` entry is treated as 0.
    """
    models = configs[YAMLKeyword.models]
    return any(models[name].get(YAMLKeyword.quantize, 0) == 1
               for name in models)


197 198
def md5sum(str):
    """Return the hex MD5 digest of the UTF-8 encoding of ``str``.

    Within this file it derives stable cache file names from download
    URLs. (The parameter name shadows the builtin; it is kept for callers
    that may pass it by keyword.)
    """
    return hashlib.md5(str.encode('utf-8')).hexdigest()
201

Y
yejianwu 已提交
202

203 204 205 206 207 208
def sha256_checksum(fname):
    """Return the hex SHA-256 digest of the file at ``fname``.

    The file is read in binary mode, 4096 bytes at a time, so large model
    files are never loaded into memory at once.
    """
    digest = hashlib.sha256()
    with open(fname, "rb") as stream:
        while True:
            block = stream.read(4096)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
Y
yejianwu 已提交
209

W
wuchenghui 已提交
210

211 212
def format_model_config(flags):
    """Load, validate and normalize the YAML deployment config.

    Reads the file named by ``flags.config``. Command-line values in
    ``flags`` (target_abis, target_socs, model_graph_format,
    model_data_format) override the file's values. Scalars are wrapped
    into lists, defaults are filled in, and every model and subgraph entry
    is checked.

    Returns:
        The normalized config dict.

    Raises:
        argparse.ArgumentTypeError: a subgraph's validation_threshold is
            malformed. Other validation failures go through ``mace_check``.
    """
    with open(flags.config) as f:
        # The config is plain data. safe_load avoids constructing arbitrary
        # Python objects, and yaml.load() without an explicit Loader is
        # rejected by PyYAML >= 6.
        configs = yaml.safe_load(f)

    library_name = configs.get(YAMLKeyword.library_name, "")
    mace_check(len(library_name) > 0,
               ModuleName.YAML_CONFIG, "library name should not be empty")

    if flags.target_abis:
        target_abis = flags.target_abis.split(',')
    else:
        target_abis = configs.get(YAMLKeyword.target_abis, [])
    mace_check((isinstance(target_abis, list) and len(target_abis) > 0),
               ModuleName.YAML_CONFIG, "target_abis list is needed")
    configs[YAMLKeyword.target_abis] = target_abis
    for abi in target_abis:
        mace_check(abi in ABITypeStrs,
                   ModuleName.YAML_CONFIG,
                   "target_abis must be in " + str(ABITypeStrs))

    # target_socs: the command line wins; the config may give a scalar or a
    # list. Either way the result is a lower-cased list.
    if flags.target_socs:
        target_socs = flags.target_socs.split(',')
    else:
        target_socs = configs.get(YAMLKeyword.target_socs, "")
        if not target_socs:
            target_socs = []
        elif not isinstance(target_socs, list):
            target_socs = [target_socs]
    configs[YAMLKeyword.target_socs] = [soc.lower() for soc in target_socs]

    if ABIType.armeabi_v7a in target_abis \
            or ABIType.arm64_v8a in target_abis:
        available_socs = sh_commands.adb_get_all_socs()
        target_socs = configs[YAMLKeyword.target_socs]
        if ALL_SOC_TAG in target_socs:
            mace_check(available_socs,
                       ModuleName.YAML_CONFIG,
                       "Android abi is listed in config file and "
                       "build for all SOCs plugged in computer, "
                       "But no android phone found, "
                       "you at least plug in one phone")
        else:
            for soc in target_socs:
                mace_check(soc in available_socs,
                           ModuleName.YAML_CONFIG,
                           "Build specified SOC library, "
                           "you must plug in a phone using the SOC")

    model_graph_format = (flags.model_graph_format or
                          configs.get(YAMLKeyword.model_graph_format, ""))
    mace_check(model_graph_format in ModelFormatStrs,
               ModuleName.YAML_CONFIG,
               'You must set model_graph_format and '
               "model_graph_format must be in " + str(ModelFormatStrs))
    configs[YAMLKeyword.model_graph_format] = model_graph_format

    model_data_format = (flags.model_data_format or
                         configs.get(YAMLKeyword.model_data_format, ""))
    configs[YAMLKeyword.model_data_format] = model_data_format
    mace_check(model_data_format in ModelFormatStrs,
               ModuleName.YAML_CONFIG,
               'You must set model_data_format and '
               "model_data_format must be in " + str(ModelFormatStrs))

    mace_check(not (model_graph_format == ModelFormat.file
                    and model_data_format == ModelFormat.code),
               ModuleName.YAML_CONFIG,
               "If model_graph format is 'file',"
               " the model_data_format must be 'file' too")

    model_names = configs.get(YAMLKeyword.models, [])
    mace_check(len(model_names) > 0, ModuleName.YAML_CONFIG,
               "no model found in config file")

    model_name_reg = re.compile(r'^[a-zA-Z0-9_]+$')
    for model_name in model_names:
        # Model names must be valid C++ identifiers (see the checks below).
        mace_check(model_name not in CPP_KEYWORDS,
                   ModuleName.YAML_CONFIG,
                   "model name should not be c++ keyword.")
        # Test the regex first: it requires at least one character, so an
        # empty name fails the check instead of raising IndexError on [0].
        mace_check(bool(model_name_reg.match(model_name))
                   and (model_name[0] == '_' or model_name[0].isalpha()),
                   ModuleName.YAML_CONFIG,
                   "model name should Meet the c++ naming convention"
                   " which start with '_' or alpha"
                   " and only contain alpha, number and '_'")

        model_config = configs[YAMLKeyword.models][model_name]
        platform = model_config.get(YAMLKeyword.platform, "")
        mace_check(platform in PlatformTypeStrs,
                   ModuleName.YAML_CONFIG,
                   "'platform' must be in " + str(PlatformTypeStrs))

        for key in [YAMLKeyword.model_file_path,
                    YAMLKeyword.model_sha256_checksum]:
            value = model_config.get(key, "")
            mace_check(value != "", ModuleName.YAML_CONFIG,
                       "'%s' is necessary" % key)

        # A weight file is optional, but when present its checksum is not.
        weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "")
        if weight_file_path:
            weight_checksum = \
                model_config.get(YAMLKeyword.weight_sha256_checksum, "")
            mace_check(weight_checksum != "", ModuleName.YAML_CONFIG,
                       "'%s' is necessary" %
                       YAMLKeyword.weight_sha256_checksum)
        else:
            model_config[YAMLKeyword.weight_sha256_checksum] = ""
        model_config[YAMLKeyword.weight_file_path] = weight_file_path

        runtime = model_config.get(YAMLKeyword.runtime, "")
        mace_check(runtime in RuntimeTypeStrs,
                   ModuleName.YAML_CONFIG,
                   "'runtime' must be in " + str(RuntimeTypeStrs))
        if ABIType.host in target_abis:
            mace_check(runtime == RuntimeType.cpu,
                       ModuleName.YAML_CONFIG,
                       "host only support cpu runtime now.")

        _format_model_data_type(model_config, runtime)

        subgraphs = model_config.get(YAMLKeyword.subgraphs, "")
        mace_check(len(subgraphs) > 0, ModuleName.YAML_CONFIG,
                   "at least one subgraph is needed")
        for subgraph in subgraphs:
            _format_subgraph_config(subgraph)

        # Optional integer switches default to 0 when absent.
        for key in [YAMLKeyword.limit_opencl_kernel_time,
                    YAMLKeyword.nnlib_graph_mode,
                    YAMLKeyword.obfuscate,
                    YAMLKeyword.winograd,
                    YAMLKeyword.quantize,
                    YAMLKeyword.change_concat_ranges]:
            if model_config.get(key, "") == "":
                model_config[key] = 0

        mace_check(model_config[YAMLKeyword.winograd] in WinogradParameters,
                   ModuleName.YAML_CONFIG,
                   "'winograd' parameters must be in "
                   + str(WinogradParameters) +
                   ". 0 for disable winograd convolution")

    return configs


def _format_model_data_type(model_config, runtime):
    """Validate a model's ``data_type``, or fill in the runtime default.

    The dsp runtime accepts DSPDataTypeStrs (default uint8). Every other
    runtime accepts FPDataTypeStrs (default fp32_fp32 on cpu, fp16_fp32
    otherwise).
    """
    data_type = model_config.get(YAMLKeyword.data_type, "")
    if runtime == RuntimeType.dsp:
        if len(data_type) > 0:
            mace_check(data_type in DSPDataTypeStrs,
                       ModuleName.YAML_CONFIG,
                       "'data_type' must be in " + str(DSPDataTypeStrs)
                       + " for dsp runtime")
        else:
            model_config[YAMLKeyword.data_type] = DSPDataType.uint8.value
    elif len(data_type) > 0:
        # This branch covers gpu and cpu+gpu too, so name the real runtime
        # instead of always saying "cpu".
        mace_check(data_type in FPDataTypeStrs,
                   ModuleName.YAML_CONFIG,
                   "'data_type' must be in " + str(FPDataTypeStrs)
                   + " for %s runtime" % runtime)
    elif runtime == RuntimeType.cpu:
        model_config[YAMLKeyword.data_type] = FPDataType.fp32_fp32.value
    else:
        model_config[YAMLKeyword.data_type] = FPDataType.fp16_fp32.value


def _format_subgraph_data_formats(subgraph, formats_key, tensors_key,
                                  direction):
    """Normalize and validate ``<direction>_data_formats`` of a subgraph.

    ``direction`` is "input" or "output" and is used only in messages.
    A missing or empty entry defaults to ``[DataFormat.NHWC]``. A scalar
    is wrapped in a list. A list must have one format per tensor in
    ``subgraph[tensors_key]``. Every format must be in DataFormatStrs.
    """
    data_formats = subgraph.get(formats_key, [])
    if not data_formats:
        subgraph[formats_key] = [DataFormat.NHWC]
        return

    if not isinstance(data_formats, list):
        subgraph[formats_key] = [data_formats]
    else:
        mace_check(len(data_formats) == len(subgraph[tensors_key]),
                   ModuleName.YAML_CONFIG,
                   "%s_data_formats should match the size of %s"
                   % (direction, direction))
    for data_format in subgraph[formats_key]:
        # The message is built before mace_check runs, so use %-formatting:
        # the old "+ list" concatenation raised TypeError whenever the
        # formats were given as a list.
        mace_check(data_format in DataFormatStrs,
                   ModuleName.YAML_CONFIG,
                   "'%s_data_formats' must be in %s, but got %s"
                   % (direction, str(DataFormatStrs), data_format))


def _format_subgraph_config(subgraph):
    """Validate and normalize one entry of a model's ``subgraphs`` in place."""
    # Mandatory lists: scalars are wrapped and every item is stringified.
    for key in [YAMLKeyword.input_tensors,
                YAMLKeyword.input_shapes,
                YAMLKeyword.output_tensors,
                YAMLKeyword.output_shapes]:
        value = subgraph.get(key, "")
        mace_check(value != "", ModuleName.YAML_CONFIG,
                   "'%s' is necessary in subgraph" % key)
        if not isinstance(value, list):
            value = [value]
        subgraph[key] = [str(v) for v in value]

    # Optional lists: same normalization, defaulting to empty.
    for key in [YAMLKeyword.check_tensors,
                YAMLKeyword.check_shapes]:
        value = subgraph.get(key, "")
        if value != "":
            if not isinstance(value, list):
                value = [value]
            subgraph[key] = [str(v) for v in value]
        else:
            subgraph[key] = []

    input_data_types = subgraph.get(YAMLKeyword.input_data_types, "")
    if input_data_types:
        if not isinstance(input_data_types, list):
            subgraph[YAMLKeyword.input_data_types] = [input_data_types]
        for input_data_type in subgraph[YAMLKeyword.input_data_types]:
            mace_check(input_data_type in InputDataTypeStrs,
                       ModuleName.YAML_CONFIG,
                       "'input_data_types' must be in "
                       + str(InputDataTypeStrs))
    else:
        subgraph[YAMLKeyword.input_data_types] = []

    _format_subgraph_data_formats(subgraph,
                                  YAMLKeyword.input_data_formats,
                                  YAMLKeyword.input_tensors,
                                  "input")
    _format_subgraph_data_formats(subgraph,
                                  YAMLKeyword.output_data_formats,
                                  YAMLKeyword.output_tensors,
                                  "output")

    # Per-device similarity thresholds: start from the defaults and apply
    # overrides. 'dsp' is accepted as an alias for HEXAGON.
    validation_threshold = subgraph.get(
        YAMLKeyword.validation_threshold, {})
    if not isinstance(validation_threshold, dict):
        raise argparse.ArgumentTypeError(
            'similarity threshold must be a dict.')

    threshold_dict = {
        DeviceType.CPU: ValidationThreshold.cpu_threshold,
        DeviceType.GPU: ValidationThreshold.gpu_threshold,
        DeviceType.HEXAGON: ValidationThreshold.hexagon_threshold,
        DeviceType.CPU + "_QUANTIZE":
            ValidationThreshold.cpu_quantize_threshold,
    }
    for k, v in six.iteritems(validation_threshold):
        if k.upper() == 'DSP':
            k = DeviceType.HEXAGON
        if k.upper() not in (DeviceType.CPU,
                             DeviceType.GPU,
                             DeviceType.HEXAGON,
                             DeviceType.CPU + "_QUANTIZE"):
            raise argparse.ArgumentTypeError(
                'Unsupported validation threshold runtime: %s' % k)
        threshold_dict[k.upper()] = v
    subgraph[YAMLKeyword.validation_threshold] = threshold_dict

    validation_inputs_data = subgraph.get(
        YAMLKeyword.validation_inputs_data, [])
    if not isinstance(validation_inputs_data, list):
        validation_inputs_data = [validation_inputs_data]
    subgraph[YAMLKeyword.validation_inputs_data] = validation_inputs_data

    subgraph[YAMLKeyword.backend] = subgraph.get(
        YAMLKeyword.backend, "tensorflow")

    input_ranges = subgraph.get(YAMLKeyword.input_ranges, [])
    if not isinstance(input_ranges, list):
        input_ranges = [input_ranges]
    subgraph[YAMLKeyword.input_ranges] = [str(v) for v in input_ranges]
Y
yejianwu 已提交
506

W
wuchenghui 已提交
507

508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528
def clear_build_dirs(library_name):
    """Prepare the build tree for ``library_name``.

    Ensures BUILD_OUTPUT_DIR exists, recreates an empty temporary build
    directory for the library, and deletes any previously produced library
    output directory.
    """
    if not os.path.exists(BUILD_OUTPUT_DIR):
        os.makedirs(BUILD_OUTPUT_DIR)

    library_root = os.path.join(BUILD_OUTPUT_DIR, library_name)

    # Temporary build dir: always recreated empty.
    tmp_dir = os.path.join(library_root, BUILD_TMP_DIR_NAME)
    if os.path.exists(tmp_dir):
        sh.rm('-rf', tmp_dir)
    os.makedirs(tmp_dir)

    # Library output dir: removed, left for the build to recreate.
    output_dir = os.path.join(library_root, OUTPUT_LIBRARY_DIR_NAME)
    if os.path.exists(output_dir):
        sh.rm('-rf', output_dir)


################################
# convert
################################
def print_configuration(configs):
    """Log a key/value table of the library-wide settings in ``configs``."""
    shown_keys = (YAMLKeyword.library_name,
                  YAMLKeyword.target_abis,
                  YAMLKeyword.target_socs,
                  YAMLKeyword.model_graph_format,
                  YAMLKeyword.model_data_format)
    rows = [[key, configs[key]] for key in shown_keys]
    MaceLogger.summary(StringFormatter.table(["key", "value"], rows,
                                             "Common Configuration"))
L
Liangliang He 已提交
543

Y
yejianwu 已提交
544

545 546 547 548
def download_file(url, dst, num_retries=3):
    """Download ``url`` to ``dst``, retrying on urllib errors.

    Makes at most ``num_retries + 1`` attempts (always at least one). Each
    failure is logged as a warning.

    Returns:
        True once a download succeeds, False after the last failed attempt.
    """
    from six.moves import urllib

    attempts = max(num_retries, 0) + 1
    for _ in range(attempts):
        try:
            urllib.request.urlretrieve(url, dst)
            MaceLogger.info('\nDownloaded successfully.')
        except (urllib.error.ContentTooShortError, urllib.error.HTTPError,
                urllib.error.URLError) as err:
            MaceLogger.warning('Download error:' + str(err))
        else:
            return True
    return False


B
Bin Li 已提交
561 562 563 564 565 566 567
def get_model_files(model_file_path,
                    model_sha256_checksum,
                    model_output_dir,
                    weight_file_path="",
                    weight_sha256_checksum=""):
    """Resolve local paths for a model file and its optional weight file.

    An http(s) URL is downloaded into ``model_output_dir`` under the MD5 of
    the URL. The download is skipped when a cached copy with the expected
    SHA-256 already exists. Checksum mismatches and download failures are
    reported via ``MaceLogger.error``.

    Returns:
        ``(model_file, weight_file)`` local paths; ``weight_file`` is
        whatever ``weight_file_path`` resolved to (possibly "").
    """
    def _is_remote(path):
        return path.startswith(("http://", "https://"))

    def _cached_path(url, suffix):
        return model_output_dir + "/" + md5sum(url) + suffix

    model_file = model_file_path
    if _is_remote(model_file_path):
        model_file = _cached_path(model_file_path, ".pb")
        stale = (not os.path.exists(model_file) or
                 sha256_checksum(model_file) != model_sha256_checksum)
        if stale:
            MaceLogger.info("Downloading model, please wait ...")
            if not download_file(model_file_path, model_file):
                MaceLogger.error(ModuleName.MODEL_CONVERTER,
                                 "Model download failed.")

    if sha256_checksum(model_file) != model_sha256_checksum:
        MaceLogger.error(ModuleName.MODEL_CONVERTER,
                         "model file sha256checksum not match")

    weight_file = weight_file_path
    if _is_remote(weight_file_path):
        weight_file = _cached_path(weight_file_path, ".caffemodel")
        stale = (not os.path.exists(weight_file) or
                 sha256_checksum(weight_file) != weight_sha256_checksum)
        if stale:
            MaceLogger.info("Downloading model weight, please wait ...")
            if not download_file(weight_file_path, weight_file):
                MaceLogger.error(ModuleName.MODEL_CONVERTER,
                                 "Model download failed.")

    # The weight file is optional; only verify it when one was given.
    if weight_file and \
            sha256_checksum(weight_file) != weight_sha256_checksum:
        MaceLogger.error(ModuleName.MODEL_CONVERTER,
                         "weight file sha256checksum not match")

    return model_file, weight_file
L
Liangliang He 已提交
600

L
liuqi 已提交
601

602
def convert_model(configs, cl_mem_type):
    """Convert every model in ``configs`` into MACE model artifacts.

    The library's output directory is wiped and recreated, and stale codegen
    directories are removed. Then, for each model:

    * its source files are fetched via ``get_model_files``;
    * it is converted with ``sh_commands.gen_model_code``;
    * the results are moved or copied into the library's model directory
      (and header directory for the ``code`` graph format).

    ``cl_mem_type`` overrides every model's ``cl_mem_type``; when it is
    falsy the value "image" is used.
    """
    def _remove_if_exists(path):
        if os.path.exists(path):
            sh.rm("-rf", path)

    library_name = configs[YAMLKeyword.library_name]
    library_dir = os.path.join(BUILD_OUTPUT_DIR, library_name)

    # Start from an empty library output directory.
    if not os.path.exists(BUILD_OUTPUT_DIR):
        os.makedirs(BUILD_OUTPUT_DIR)
    else:
        _remove_if_exists(library_dir)
    os.makedirs(library_dir)
    if not os.path.exists(BUILD_DOWNLOADS_DIR):
        os.makedirs(BUILD_DOWNLOADS_DIR)

    model_output_dir = \
        '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_OUTPUT_DIR_NAME)
    model_header_dir = \
        '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_HEADER_DIR_PATH)
    _remove_if_exists(model_output_dir)
    os.makedirs(model_output_dir)
    _remove_if_exists(model_header_dir)

    embed_model_data = \
        configs[YAMLKeyword.model_data_format] == ModelFormat.code

    _remove_if_exists(MODEL_CODEGEN_DIR)
    _remove_if_exists(ENGINE_CODEGEN_DIR)

    graph_format = configs[YAMLKeyword.model_graph_format]
    if graph_format == ModelFormat.code:
        os.makedirs(model_header_dir)
        sh_commands.gen_mace_engine_factory_source(
            configs[YAMLKeyword.models].keys(),
            embed_model_data)
        sh.cp("-f", glob.glob("mace/codegen/engine/*.h"), model_header_dir)

    for model_name in configs[YAMLKeyword.models]:
        MaceLogger.header(
            StringFormatter.block("Convert %s model" % model_name))
        model_config = configs[YAMLKeyword.models][model_name]
        runtime = model_config[YAMLKeyword.runtime]
        model_config[YAMLKeyword.cl_mem_type] = cl_mem_type or "image"

        model_file_path, weight_file_path = get_model_files(
            model_config[YAMLKeyword.model_file_path],
            model_config[YAMLKeyword.model_sha256_checksum],
            BUILD_DOWNLOADS_DIR,
            model_config[YAMLKeyword.weight_file_path],
            model_config[YAMLKeyword.weight_sha256_checksum])

        data_type = model_config[YAMLKeyword.data_type]
        # TODO(liuqi): support multiple subgraphs
        subgraph = model_config[YAMLKeyword.subgraphs][0]

        model_codegen_dir = "%s/%s" % (MODEL_CODEGEN_DIR, model_name)
        sh_commands.gen_model_code(
            model_codegen_dir,
            model_config[YAMLKeyword.platform],
            model_file_path,
            weight_file_path,
            model_config[YAMLKeyword.model_sha256_checksum],
            model_config[YAMLKeyword.weight_sha256_checksum],
            ",".join(subgraph[YAMLKeyword.input_tensors]),
            ",".join(subgraph[YAMLKeyword.input_data_formats]),
            ",".join(subgraph[YAMLKeyword.output_tensors]),
            ",".join(subgraph[YAMLKeyword.output_data_formats]),
            ",".join(subgraph[YAMLKeyword.check_tensors]),
            runtime,
            model_name,
            ":".join(subgraph[YAMLKeyword.input_shapes]),
            ":".join(subgraph[YAMLKeyword.input_ranges]),
            ":".join(subgraph[YAMLKeyword.output_shapes]),
            ":".join(subgraph[YAMLKeyword.check_shapes]),
            model_config[YAMLKeyword.nnlib_graph_mode],
            embed_model_data,
            model_config[YAMLKeyword.winograd],
            model_config[YAMLKeyword.quantize],
            model_config.get(YAMLKeyword.quantize_range_file, ""),
            model_config[YAMLKeyword.change_concat_ranges],
            model_config[YAMLKeyword.obfuscate],
            graph_format,
            data_type,
            model_config[YAMLKeyword.cl_mem_type],
            ",".join(model_config.get(YAMLKeyword.graph_optimize_options, [])))

        data_file = '%s/%s.data' % (model_codegen_dir, model_name)
        if graph_format == ModelFormat.file:
            sh.mv("-f",
                  '%s/%s.pb' % (model_codegen_dir, model_name),
                  model_output_dir)
            sh.mv("-f", data_file, model_output_dir)
        else:
            # Graph compiled as code: data stays separate unless embedded,
            # and the generated headers are published.
            if not embed_model_data:
                sh.mv("-f", data_file, model_output_dir)
            sh.cp("-f", glob.glob("mace/codegen/models/*/*.h"),
                  model_header_dir)

        MaceLogger.summary(
            StringFormatter.block("Model %s converted" % model_name))


711 712
def build_model_lib(configs, address_sanitizer):
    """Build the model library once per configured target ABI.

    For each ABI the MODEL_LIB_TARGET bazel target is built and the file at
    MODEL_LIB_PATH is copied to the per-ABI location returned by
    get_model_lib_output_path() (its parent directory is created if needed).

    Args:
        configs: the formatted model configuration dict.
        address_sanitizer: whether to build with address sanitizer enabled.
    """
    MaceLogger.header(StringFormatter.block("Building model library"))

    lib_name = configs[YAMLKeyword.library_name]
    for abi in configs[YAMLKeyword.target_abis]:
        hexagon = get_hexagon_mode(configs)

        # Make sure the destination directory of the library exists.
        output_path = get_model_lib_output_path(lib_name, abi)
        output_dir = os.path.dirname(output_path)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        sh_commands.bazel_build(
            MODEL_LIB_TARGET,
            abi=abi,
            toolchain=infer_toolchain(abi),
            hexagon_mode=hexagon,
            enable_opencl=get_opencl_mode(configs),
            enable_quantize=get_quantize_mode(configs),
            address_sanitizer=address_sanitizer,
            symbol_hidden=True
        )

        sh.cp("-f", MODEL_LIB_PATH, output_path)
736 737 738 739 740 741 742


def print_library_summary(configs):
    """Log a key/value table with the output locations of the library.

    The model header path row is only included when the model graph format
    is ModelFormat.code.
    """
    lib_name = configs[YAMLKeyword.library_name]
    model_path = "%s/%s/%s" % (BUILD_OUTPUT_DIR, lib_name,
                               MODEL_OUTPUT_DIR_NAME)
    rows = [["MACE Model Path", model_path]]

    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        header_path = "%s/%s/%s" % (BUILD_OUTPUT_DIR, lib_name,
                                    MODEL_HEADER_DIR_PATH)
        rows.append(["MACE Model Header Path", header_path])

    table = StringFormatter.table(["key", "value"], rows, "Library")
    MaceLogger.summary(table)


755
def convert_func(flags):
    """Handler of the `convert` sub-command.

    Formats and prints the configuration, calls convert_model() with the
    requested OpenCL memory type and, when the model graph format is
    `code`, additionally builds the model library. Finishes by printing
    the library summary.
    """
    configs = format_model_config(flags)
    print_configuration(configs)

    convert_model(configs, flags.cl_mem_type)

    graph_is_code = \
        configs[YAMLKeyword.model_graph_format] == ModelFormat.code
    if graph_is_code:
        build_model_lib(configs, flags.address_sanitizer)

    print_library_summary(configs)


################################
# run
################################
def report_run_statistics(stdout,
                          abi,
                          serialno,
                          model_name,
                          device_type,
                          output_dir,
                          tuned):
    """Append one row of run timings to ``<output_dir>/report.csv``.

    The first line of `stdout` that splits into exactly four fields, the
    first of which starts with "time", supplies the init / warmup / run_avg
    columns (each normalized through float()). If no such line exists the
    three columns are written as 0.

    For a non-"host" ABI the device model and SoC are read from the device
    properties ro.product.model / ro.board.platform; for "host" both are
    left empty. The CSV header is written only when the file is new.
    """
    init_ms, warmup_ms, run_avg_ms = 0, 0, 0
    for raw_line in stdout.split('\n'):
        fields = raw_line.strip().split()
        if len(fields) == 4 and fields[0].startswith("time"):
            init_ms, warmup_ms, run_avg_ms = \
                (str(float(value)) for value in fields[1:])
            break

    if abi == "host":
        device_name, target_soc = "", ""
    else:
        props = sh_commands.adb_getprop_by_serialno(serialno)
        device_name = props.get("ro.product.model", "")
        target_soc = props.get("ro.board.platform", "")

    report_path = output_dir + "/report.csv"
    is_new_report = not os.path.exists(report_path)

    # str.format (not %-formatting) is kept on purpose: it formats values
    # such as str-based enums exactly as before.
    row = "{},{},{},{},{},{},{},{},{}\n".format(
        model_name, device_name, target_soc, abi, device_type,
        init_ms, warmup_ms, run_avg_ms, tuned)

    with open(report_path, 'a') as report:
        if is_new_report:
            report.write("model_name,device_name,soc,abi,runtime,"
                         "init(ms),warmup(ms),run_avg(ms),tuned\n")
        report.write(row)


L
liuqi 已提交
816 817
def build_mace_run(configs, target_abi, toolchain, enable_openmp,
                   address_sanitizer, mace_lib_type):
    """Build the mace_run validation tool for `target_abi`.

    The per-ABI tmp binary directory is recreated empty, the static or
    dynamic mace_run target is built (symbols are hidden only for the
    static variant), and sh_commands.update_mace_run_binary() is invoked
    on the directory. With the `code` model graph format the converted
    model must already exist, and mace_run.cc is compiled with
    -DMODEL_GRAPH_FORMAT_CODE.
    """
    library_name = configs[YAMLKeyword.library_name]
    hexagon_mode = get_hexagon_mode(configs)

    binary_dir = get_build_binary_dir(library_name, target_abi)
    if os.path.exists(binary_dir):
        sh.rm("-rf", binary_dir)
    os.makedirs(binary_dir)

    link_dynamic = mace_lib_type == MACELibType.dynamic
    if link_dynamic:
        run_target, hide_symbols = MACE_RUN_DYNAMIC_TARGET, False
    else:
        run_target, hide_symbols = MACE_RUN_STATIC_TARGET, True

    extra_copts = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                   ModuleName.RUN,
                   "You should convert model first.")
        extra_copts = "--per_file_copt=mace/tools/validation/mace_run.cc@-DMODEL_GRAPH_FORMAT_CODE"  # noqa

    sh_commands.bazel_build(
        run_target,
        abi=target_abi,
        toolchain=toolchain,
        hexagon_mode=hexagon_mode,
        enable_openmp=enable_openmp,
        enable_opencl=get_opencl_mode(configs),
        enable_quantize=get_quantize_mode(configs),
        address_sanitizer=address_sanitizer,
        symbol_hidden=hide_symbols,
        extra_args=extra_copts
    )
    sh_commands.update_mace_run_binary(binary_dir, link_dynamic)


L
liuqi 已提交
854 855
def build_example(configs, target_abi, toolchain,
                  enable_openmp, mace_lib_type, address_sanitizer=None):
    """Build the CLI example binary for `target_abi`.

    Steps: recreate the per-ABI tmp binary directory, build libmace
    (static or shared depending on `mace_lib_type`), copy libmace -- and,
    for the `code` model graph format, the model library -- into
    LIB_CODEGEN_DIR, build the matching example target and copy the
    resulting binary into the tmp binary directory. LIB_CODEGEN_DIR is
    removed afterwards, also when one of the steps fails.

    Args:
        configs: the formatted model configuration dict.
        target_abi: the ABI to build for.
        toolchain: the bazel toolchain inferred for `target_abi`.
        enable_openmp: whether to build with OpenMP.
        mace_lib_type: MACELibType.static or MACELibType.dynamic.
        address_sanitizer: whether to build with address sanitizer. When
            None (the default) the value is taken from the module-level
            `flags` namespace, which is what this function always did;
            callers should prefer passing it explicitly.
    """
    if address_sanitizer is None:
        # Backward compatibility: existing callers do not pass this value
        # and rely on the global `flags` parsed in __main__.
        address_sanitizer = flags.address_sanitizer

    library_name = configs[YAMLKeyword.library_name]
    hexagon_mode = get_hexagon_mode(configs)

    build_tmp_binary_dir = get_build_binary_dir(library_name, target_abi)
    if os.path.exists(build_tmp_binary_dir):
        sh.rm("-rf", build_tmp_binary_dir)
    os.makedirs(build_tmp_binary_dir)

    # Symbols are hidden only when building the static libmace.
    if mace_lib_type == MACELibType.dynamic:
        symbol_hidden = False
        libmace_target = LIBMACE_SO_TARGET
    else:
        symbol_hidden = True
        libmace_target = LIBMACE_STATIC_TARGET

    sh_commands.bazel_build(libmace_target,
                            abi=target_abi,
                            toolchain=toolchain,
                            enable_openmp=enable_openmp,
                            enable_opencl=get_opencl_mode(configs),
                            enable_quantize=get_quantize_mode(configs),
                            hexagon_mode=hexagon_mode,
                            address_sanitizer=address_sanitizer,
                            symbol_hidden=symbol_hidden)

    if os.path.exists(LIB_CODEGEN_DIR):
        sh.rm("-rf", LIB_CODEGEN_DIR)
    sh.mkdir("-p", LIB_CODEGEN_DIR)

    try:
        build_arg = ""
        if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
            mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                       ModuleName.RUN,
                       "You should convert model first.")
            model_lib_path = get_model_lib_output_path(library_name,
                                                       target_abi)
            sh.cp("-f", model_lib_path, LIB_CODEGEN_DIR)
            build_arg = "--per_file_copt=mace/examples/cli/example.cc@-DMODEL_GRAPH_FORMAT_CODE"  # noqa

        if mace_lib_type == MACELibType.dynamic:
            example_target = EXAMPLE_DYNAMIC_TARGET
            sh.cp("-f", LIBMACE_DYNAMIC_PATH, LIB_CODEGEN_DIR)
        else:
            example_target = EXAMPLE_STATIC_TARGET
            sh.cp("-f", LIBMACE_STATIC_PATH, LIB_CODEGEN_DIR)

        sh_commands.bazel_build(example_target,
                                abi=target_abi,
                                toolchain=toolchain,
                                enable_openmp=enable_openmp,
                                enable_opencl=get_opencl_mode(configs),
                                enable_quantize=get_quantize_mode(configs),
                                hexagon_mode=hexagon_mode,
                                address_sanitizer=address_sanitizer,
                                extra_args=build_arg)

        target_bin = "/".join(
            sh_commands.bazel_target_to_bin(example_target))
        sh.cp("-f", target_bin, build_tmp_binary_dir)
    finally:
        # Remove the copied libraries even if a copy or build step above
        # failed, so no stale files are left in LIB_CODEGEN_DIR.
        if os.path.exists(LIB_CODEGEN_DIR):
            sh.rm("-rf", LIB_CODEGEN_DIR)


L
liuqi 已提交
918 919 920 921 922 923 924 925 926 927
def print_package_summary(package_path):
    """Log a one-row key/value table holding the packaged output path."""
    rows = [["MACE Model package Path", package_path]]
    table = StringFormatter.table(["key", "value"], rows, "Library")
    MaceLogger.summary(table)


928
def run_mace(flags):
    """Handler of the `run` sub-command.

    Clears the build directories, then, for every target ABI and every
    selected device supporting that ABI, builds mace_run (or the CLI
    example when --example is given) and runs the models on the device
    while holding the device lock. Non-host devices lacking the ABI are
    reported on stderr. Finally the build output is packaged.
    """
    configs = format_model_config(flags)

    clear_build_dirs(configs[YAMLKeyword.library_name])

    wanted_socs = configs[YAMLKeyword.target_socs]
    devices = DeviceManager.list_devices(flags.device_yml)
    if wanted_socs and ALL_SOC_TAG not in wanted_socs:
        devices = [candidate for candidate in devices
                   if candidate[YAMLKeyword.target_socs].lower()
                   in wanted_socs]

    for abi in configs[YAMLKeyword.target_abis]:
        for dev in devices:
            if abi not in dev[YAMLKeyword.target_abis]:
                if dev[YAMLKeyword.device_name] != SystemType.host:
                    six.print_('The device with soc %s do not support abi %s' %
                               (dev[YAMLKeyword.target_socs], abi),
                               file=sys.stderr)
                continue

            toolchain = infer_toolchain(abi)
            use_openmp = not flags.disable_openmp
            if flags.example:
                build_example(configs, abi, toolchain, use_openmp,
                              flags.mace_lib_type)
            else:
                build_mace_run(configs, abi, toolchain, use_openmp,
                               flags.address_sanitizer,
                               flags.mace_lib_type)

            wrapper = DeviceWrapper(dev)
            with wrapper.lock():
                wrapper.run_specify_abi(flags, configs, abi)

    # Bundle everything produced under the build output directory.
    package_path = sh_commands.packaging_lib(
        BUILD_OUTPUT_DIR, configs[YAMLKeyword.library_name])
    print_package_summary(package_path)

971 972 973 974

################################
#  benchmark model
################################
L
liuqi 已提交
975 976 977 978 979
def build_benchmark_model(configs,
                          target_abi,
                          toolchain,
                          enable_openmp,
                          mace_lib_type):
    """Build the benchmark_model binary for `target_abi`.

    The static or dynamic benchmark target is built (symbols are hidden
    only for the static variant). With the `code` model graph format the
    converted model must already exist, and benchmark_model.cc is compiled
    with -DMODEL_GRAPH_FORMAT_CODE. The built binary is then copied into a
    freshly recreated per-ABI tmp binary directory.
    """
    library_name = configs[YAMLKeyword.library_name]
    hexagon_mode = get_hexagon_mode(configs)

    if mace_lib_type == MACELibType.dynamic:
        bm_target, hide_symbols = BM_MODEL_DYNAMIC_TARGET, False
    else:
        bm_target, hide_symbols = BM_MODEL_STATIC_TARGET, True

    extra_copts = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                   ModuleName.BENCHMARK,
                   "You should convert model first.")
        extra_copts = "--per_file_copt=mace/benchmark/benchmark_model.cc@-DMODEL_GRAPH_FORMAT_CODE"  # noqa

    sh_commands.bazel_build(
        bm_target,
        abi=target_abi,
        toolchain=toolchain,
        enable_openmp=enable_openmp,
        enable_opencl=get_opencl_mode(configs),
        enable_quantize=get_quantize_mode(configs),
        hexagon_mode=hexagon_mode,
        symbol_hidden=hide_symbols,
        extra_args=extra_copts
    )

    # Start from an empty tmp binary dir, then drop the fresh binary in it.
    binary_dir = get_build_binary_dir(library_name, target_abi)
    if os.path.exists(binary_dir):
        sh.rm("-rf", binary_dir)
    os.makedirs(binary_dir)

    built_binary = "/".join(sh_commands.bazel_target_to_bin(bm_target))
    sh.cp("-f", built_binary, binary_dir)


1017
def benchmark_model(flags):
    """Handler of the `benchmark` sub-command.

    Clears the build directories, then, for every target ABI and every
    selected device supporting that ABI, builds benchmark_model and runs
    the benchmark on the device while holding the device lock.

    Devices that do not support an ABI are reported on stderr, except the
    host device; this matches the policy of run_mace().
    """
    configs = format_model_config(flags)

    clear_build_dirs(configs[YAMLKeyword.library_name])

    target_socs = configs[YAMLKeyword.target_socs]
    device_list = DeviceManager.list_devices(flags.device_yml)
    if target_socs and ALL_SOC_TAG not in target_socs:
        device_list = [dev for dev in device_list
                       if dev[YAMLKeyword.target_socs].lower() in target_socs]

    for target_abi in configs[YAMLKeyword.target_abis]:
        for dev in device_list:
            if target_abi in dev[YAMLKeyword.target_abis]:
                # build benchmark_model binary, then benchmark on device
                toolchain = infer_toolchain(target_abi)
                build_benchmark_model(configs,
                                      target_abi,
                                      toolchain,
                                      not flags.disable_openmp,
                                      flags.mace_lib_type)
                device = DeviceWrapper(dev)
                with device.lock():
                    device.bm_specific_target(flags, configs, target_abi)
            elif dev[YAMLKeyword.device_name] != SystemType.host:
                # Same as run_mace(): do not warn about the host device.
                six.print_('There is no abi %s with soc %s' %
                           (target_abi, dev[YAMLKeyword.target_socs]),
                           file=sys.stderr)
L
liuqi 已提交
1045

1046

L
liuqi 已提交
1047
################################
Y
yejianwu 已提交
1048
# parsing arguments
L
liuqi 已提交
1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060
################################
def str2bool(v):
    """argparse ``type`` converter turning a yes/no style string into bool.

    Matching is case-insensitive. Raises argparse.ArgumentTypeError for
    any unrecognized value.
    """
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def str_to_caffe_env_type(v):
    """argparse ``type`` converter: 'docker' / 'local' -> CaffeEnvType.

    Matching is case-insensitive; other values raise
    argparse.ArgumentTypeError.
    """
    choice = v.lower()
    if choice == 'docker':
        return CaffeEnvType.DOCKER
    if choice == 'local':
        return CaffeEnvType.LOCAL
    raise argparse.ArgumentTypeError('[docker | local] expected.')


1068 1069 1070 1071 1072 1073 1074 1075 1076
def str_to_mace_lib_type(v):
    """argparse ``type`` converter: 'dynamic' / 'static' -> MACELibType.

    Matching is case-insensitive; other values raise
    argparse.ArgumentTypeError.
    """
    choice = v.lower()
    if choice == 'dynamic':
        return MACELibType.dynamic
    if choice == 'static':
        return MACELibType.static
    raise argparse.ArgumentTypeError('[dynamic| static] expected.')


1077
def parse_args():
    """Parses command line arguments.

    Three sub-commands are registered -- ``convert``, ``run`` and
    ``benchmark`` -- and each stores its handler in the ``func`` attribute
    of the resulting namespace. Options shared between sub-commands live on
    parent parsers:

    * all_type_parent_parser: options common to every sub-command;
    * convert_run_parent_parser: options shared by ``convert`` and ``run``;
    * run_bm_parent_parser: options shared by ``run`` and ``benchmark``.

    Returns:
        The ``(namespace, unrecognized_args)`` tuple returned by
        ``argparse.ArgumentParser.parse_known_args``.
    """
    all_type_parent_parser = argparse.ArgumentParser(add_help=False)
    all_type_parent_parser.add_argument(
        '--config',
        type=str,
        default="",
        required=True,
        help="the path of model yaml configuration file.")
    all_type_parent_parser.add_argument(
        "--model_graph_format",
        type=str,
        default="",
        help="[file, code], MACE Model graph format.")
    all_type_parent_parser.add_argument(
        "--model_data_format",
        type=str,
        default="",
        help="['file', 'code'], MACE Model data format.")
    all_type_parent_parser.add_argument(
        "--target_abis",
        type=str,
        default="",
        help="Target ABIs, comma separated list.")
    all_type_parent_parser.add_argument(
        "--target_socs",
        type=str,
        default="",
        help="Target SOCs, comma separated list.")

    convert_run_parent_parser = argparse.ArgumentParser(add_help=False)
    convert_run_parent_parser.add_argument(
        '--address_sanitizer',
        action="store_true",
        help="Whether to use address sanitizer to check memory error")

    run_bm_parent_parser = argparse.ArgumentParser(add_help=False)
    run_bm_parent_parser.add_argument(
        "--mace_lib_type",
        type=str_to_mace_lib_type,
        default=DefaultValues.mace_lib_type,
        help="[static | dynamic], Which type of MACE library to use.")
    run_bm_parent_parser.add_argument(
        "--disable_openmp",
        action="store_true",
        help="Disable openmp for multiple threads.")
    run_bm_parent_parser.add_argument(
        "--omp_num_threads",
        type=int,
        default=DefaultValues.omp_num_threads,
        help="num of openmp threads")
    run_bm_parent_parser.add_argument(
        "--cpu_affinity_policy",
        type=int,
        default=DefaultValues.cpu_affinity_policy,
        help="0:AFFINITY_NONE/1:AFFINITY_BIG_ONLY/2:AFFINITY_LITTLE_ONLY")
    run_bm_parent_parser.add_argument(
        "--gpu_perf_hint",
        type=int,
        default=DefaultValues.gpu_perf_hint,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    run_bm_parent_parser.add_argument(
        "--gpu_priority_hint",
        type=int,
        default=DefaultValues.gpu_priority_hint,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    run_bm_parent_parser.add_argument(
        "--device_yml",
        type=str,
        default='',
        help='embedded linux device config yml file'
    )

    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()

    # ---- convert ----
    convert = subparsers.add_parser(
        'convert',
        parents=[all_type_parent_parser, convert_run_parent_parser],
        help='convert to mace model (file or code)')
    convert.add_argument(
        "--cl_mem_type",
        type=str,
        default=None,
        help="Which type of OpenCL memory type to use [image | buffer].")
    convert.set_defaults(func=convert_func)

    # ---- run ----
    run = subparsers.add_parser(
        'run',
        parents=[all_type_parent_parser, run_bm_parent_parser,
                 convert_run_parent_parser],
        help='run model in command line')
    run.set_defaults(func=run_mace)
    run.add_argument(
        "--disable_tuning",
        action="store_true",
        help="Disable tuning for specific thread.")
    run.add_argument(
        "--round",
        type=int,
        default=1,
        help="The model running round.")
    run.add_argument(
        "--validate",
        action="store_true",
        help="whether to verify the results are consistent with "
             "the frameworks.")
    run.add_argument(
        "--caffe_env",
        type=str_to_caffe_env_type,
        default='docker',
        help="[docker | local] caffe environment used for validation:"
             " a local caffe installation or the caffe docker image.")
    run.add_argument(
        "--vlog_level",
        type=int,
        default=0,
        help="[1~5]. Verbose log level for debug.")
    run.add_argument(
        "--gpu_out_of_range_check",
        action="store_true",
        help="Enable out of range check for gpu.")
    run.add_argument(
        "--restart_round",
        type=int,
        default=1,
        help="restart round between run.")
    run.add_argument(
        "--report",
        action="store_true",
        help="print run statistics report.")
    run.add_argument(
        "--report_dir",
        type=str,
        default="",
        help="directory of the run statistics report.")
    run.add_argument(
        "--runtime_failure_ratio",
        type=float,
        default=0.0,
        help="[mock runtime failure ratio].")
    run.add_argument(
        "--example",
        action="store_true",
        help="whether to run example.")
    run.add_argument(
        "--quantize_stat",
        action="store_true",
        help="whether to stat quantization range.")
    run.add_argument(
        "--input_dir",
        type=str,
        default="",
        help="quantize stat input dir.")
    run.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="quantize stat output dir.")

    # ---- benchmark ----
    benchmark = subparsers.add_parser(
        'benchmark',
        parents=[all_type_parent_parser, run_bm_parent_parser],
        help='benchmark model for detail information')
    benchmark.set_defaults(func=benchmark_model)

    return parser.parse_known_args()

1238

Y
yejianwu 已提交
1239
if __name__ == "__main__":
    # `flags` is intentionally a module-level name: build_example() reads
    # the global `flags.address_sanitizer`.
    flags, unparsed = parse_args()
    if not hasattr(flags, "func"):
        # Python 3 argparse treats sub-commands as optional, so running the
        # script without one leaves `func` unset instead of erroring out.
        sys.exit("Please specify a sub-command (convert, run or benchmark);"
                 " use -h for help.")
    flags.func(flags)