converter.py 42.5 KB
Newer Older
L
Liangliang He 已提交
1
# Copyright 2018 The MACE Authors. All Rights Reserved.
Y
yejianwu 已提交
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

15 16 17 18
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

19
import argparse
L
liuqi 已提交
20
import glob
L
Liangliang He 已提交
21
import sh
22
import sys
23
import time
Y
yejianwu 已提交
24
import yaml
25
import sh_commands
26
from enum import Enum
L
Liangliang He 已提交
27

28
sys.path.insert(0, "tools/python")  # noqa
L
liuqi 已提交
29 30
from common import *
from device import DeviceWrapper, DeviceManager
31 32 33
from utils import config_parser
import convert
import encrypt
34

Y
yejianwu 已提交
35 36 37 38
################################
# set environment
################################
# Raise TensorFlow's C++ minimum log level for this process (presumably to
# quiet its INFO/WARNING chatter while converting models).
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '2'})
A
Allen 已提交
39

40 41 42 43 44
################################
# common definitions
################################

# Accepted values of the target_abis option.
ABITypeStrs = [
    'armeabi-v7a',
    'arm64-v8a',
    'arm64',
    'armhf',
    'host',
]

# Accepted values of model_graph_format / model_data_format.
ModelFormatStrs = [
    "file",
    "code",
]

# Accepted values of a model's 'platform' (the source framework).
PlatformTypeStrs = [
    "tensorflow",
    "caffe",
    "onnx",
]
PlatformType = Enum('PlatformType', [(ele, ele) for ele in PlatformTypeStrs],
                    type=str)

# Accepted values of a model's 'runtime'.
RuntimeTypeStrs = [
    "cpu",
    "gpu",
    "dsp",
    "hta",
    "apu",
    "cpu+gpu"
]

# Accepted input/output tensor data types of a subgraph.
InOutDataTypeStrs = [
    "int32",
    "float32",
]

# Each Enum's class name matches the variable it is bound to, so its repr
# is accurate and its members can be pickled by reference.
InOutDataType = Enum('InOutDataType',
                     [(ele, ele) for ele in InOutDataTypeStrs],
                     type=str)

# Accepted 'data_type' values for runtimes other than dsp/apu.
FPDataTypeStrs = [
    "fp16_fp32",
    "fp32_fp32",
]

FPDataType = Enum('FPDataType', [(ele, ele) for ele in FPDataTypeStrs],
                  type=str)

# Accepted 'data_type' values for the dsp runtime (uint8 is the default).
DSPDataTypeStrs = [
    "uint8",
]

DSPDataType = Enum('DSPDataType', [(ele, ele) for ele in DSPDataTypeStrs],
                   type=str)

# Accepted 'data_type' values for the apu runtime (uint8 is the default).
APUDataTypeStrs = [
    "uint8",
]

APUDataType = Enum('APUDataType', [(ele, ele) for ele in APUDataTypeStrs],
                   type=str)

# Accepted 'winograd' values; 0 disables winograd convolution.
WinogradParameters = [0, 2, 4]

# Accepted input/output data format names.
DataFormatStrs = [
    "NONE",
    "NHWC",
    "NCHW",
    "OIHW",
]


115
class DefaultValues(object):
    """Default values for library type and run-time tuning options.

    Every value is a plain scalar. A trailing comma after an assignment
    here would silently turn the value into a one-element tuple.
    """
    mace_lib_type = MACELibType.static
    omp_num_threads = -1  # NOTE(review): -1 presumably means "let MACE decide"
    cpu_affinity_policy = 1
    gpu_perf_hint = 3
    gpu_priority_hint = 3


123 124 125
class ValidationThreshold(object):
    """Default output-similarity thresholds used when validating a model.

    cpu_threshold and gpu_threshold apply to the CPU and GPU devices.
    quantize_threshold is shared by the quantized devices (HEXAGON, HTA,
    APU, QUANTIZE). Values are plain floats; a trailing comma after an
    assignment would make them one-element tuples.
    """
    cpu_threshold = 0.999
    gpu_threshold = 0.995
    quantize_threshold = 0.980
127 128


129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
# Names a model may not take: C++ keywords followed by preprocessor
# directive names. Model names must also follow C++ identifier rules (see
# the model name checks in format_model_config). Each word appears once.
CPP_KEYWORDS = [
    'alignas', 'alignof', 'and', 'and_eq', 'asm', 'atomic_cancel',
    'atomic_commit', 'atomic_noexcept', 'auto', 'bitand', 'bitor',
    'bool', 'break', 'case', 'catch', 'char', 'char16_t', 'char32_t',
    'class', 'compl', 'concept', 'const', 'constexpr', 'const_cast',
    'continue', 'co_await', 'co_return', 'co_yield', 'decltype', 'default',
    'delete', 'do', 'double', 'dynamic_cast', 'else', 'enum', 'explicit',
    'export', 'extern', 'false', 'float', 'for', 'friend', 'goto', 'if',
    'import', 'inline', 'int', 'long', 'module', 'mutable', 'namespace',
    'new', 'noexcept', 'not', 'not_eq', 'nullptr', 'operator', 'or', 'or_eq',
    'private', 'protected', 'public', 'register', 'reinterpret_cast',
    'requires', 'return', 'short', 'signed', 'sizeof', 'static',
    'static_assert', 'static_cast', 'struct', 'switch', 'synchronized',
    'template', 'this', 'thread_local', 'throw', 'true', 'try', 'typedef',
    'typeid', 'typename', 'union', 'unsigned', 'using', 'virtual', 'void',
    'volatile', 'wchar_t', 'while', 'xor', 'xor_eq', 'override', 'final',
    'transaction_safe', 'transaction_safe_dynamic',
    # Preprocessor directives ('if' and 'else' are already listed above).
    'elif', 'endif', 'defined', 'ifdef', 'ifndef', 'define', 'undef',
    'include', 'line', 'error', 'pragma',
]
Y
yejianwu 已提交
149

150

151 152 153
################################
# common functions
################################
154
def parse_device_type(runtime):
    """Map a runtime name to its DeviceType.

    dsp -> HEXAGON, hta -> HTA, gpu -> GPU, cpu -> CPU, apu -> APU.
    Any other runtime (including "cpu+gpu") yields the empty string.
    """
    runtime_to_device = (
        (RuntimeType.dsp, DeviceType.HEXAGON),
        (RuntimeType.hta, DeviceType.HTA),
        (RuntimeType.gpu, DeviceType.GPU),
        (RuntimeType.cpu, DeviceType.CPU),
        (RuntimeType.apu, DeviceType.APU),
    )
    for candidate, device in runtime_to_device:
        if runtime == candidate:
            return device
    return ""
169

Y
yejianwu 已提交
170 171

def get_hexagon_mode(configs):
    """Return True if any configured model uses the dsp runtime.

    Runtime names are compared case-insensitively; a model without a
    runtime entry counts as "".
    """
    models = configs[YAMLKeyword.models]
    runtimes = [models[name].get(YAMLKeyword.runtime, "").lower()
                for name in models]
    return RuntimeType.dsp in runtimes


B
Bin Li 已提交
184 185 186 187 188 189 190 191 192 193 194 195 196
def get_hta_mode(configs):
    """Return True if any configured model uses the hta runtime.

    Runtime names are compared case-insensitively; a model without a
    runtime entry counts as "".
    """
    models = configs[YAMLKeyword.models]
    runtimes = [models[name].get(YAMLKeyword.runtime, "").lower()
                for name in models]
    return RuntimeType.hta in runtimes


L
lichao18 已提交
197 198 199 200 201 202 203 204 205 206 207 208 209
def get_apu_mode(configs):
    """Return True if any configured model uses the apu runtime.

    Runtime names are compared case-insensitively; a model without a
    runtime entry counts as "".
    """
    models = configs[YAMLKeyword.models]
    runtimes = [models[name].get(YAMLKeyword.runtime, "").lower()
                for name in models]
    return RuntimeType.apu in runtimes


Y
yejianwu 已提交
210 211 212
def get_opencl_mode(configs):
    """Return True if any configured model uses the gpu or cpu+gpu runtime.

    Runtime names are compared case-insensitively; a model without a
    runtime entry counts as "".
    """
    models = configs[YAMLKeyword.models]
    runtimes = [models[name].get(YAMLKeyword.runtime, "").lower()
                for name in models]
    return (RuntimeType.gpu in runtimes
            or RuntimeType.cpu_gpu in runtimes)


223 224 225 226 227 228 229 230 231 232 233
def get_quantize_mode(configs):
    """Return True if at least one configured model sets quantize to 1."""
    models = configs[YAMLKeyword.models]
    return any(models[name].get(YAMLKeyword.quantize, 0) == 1
               for name in models)


234 235 236 237 238 239 240 241 242
def get_symbol_hidden_mode(debug_mode, mace_lib_type=None):
    """Decide whether built library symbols should be hidden.

    Without a mace_lib_type, symbols are always hidden. Otherwise they are
    kept visible in debug mode or for a dynamic library, and hidden in
    every other case.
    """
    if not mace_lib_type:
        return True
    return not (debug_mode or mace_lib_type == MACELibType.dynamic)


243 244
def md5sum(str):
    """Return the hex MD5 digest of the UTF-8 encoding of ``str``.

    Used to build local download file names from URLs. The parameter
    shadows the builtin ``str`` but keeps its name for callers.
    """
    return hashlib.md5(str.encode('utf-8')).hexdigest()
247

Y
yejianwu 已提交
248

249 250 251 252 253 254
def sha256_checksum(fname):
    """Return the hex SHA-256 digest of the file at ``fname``.

    The file is read in 4096-byte chunks rather than all at once.
    """
    digest = hashlib.sha256()
    with open(fname, "rb") as stream:
        chunk = stream.read(4096)
        while chunk:
            digest.update(chunk)
            chunk = stream.read(4096)
    return digest.hexdigest()
Y
yejianwu 已提交
255

W
wuchenghui 已提交
256

257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296
def download_file(url, dst, num_retries=3):
    """Download ``url`` to the local path ``dst``.

    A failed attempt (ContentTooShortError, HTTPError or URLError) is
    logged as a warning. It is retried while positive retries remain, so
    up to ``num_retries + 1`` attempts are made.

    Returns:
        True after a successful download, False once retries run out.
    """
    from six.moves import urllib

    retries_left = num_retries
    while True:
        try:
            urllib.request.urlretrieve(url, dst)
            MaceLogger.info('\nDownloaded successfully.')
            return True
        except (urllib.error.ContentTooShortError, urllib.error.HTTPError,
                urllib.error.URLError) as e:
            MaceLogger.warning('Download error:' + str(e))
        if not retries_left > 0:
            return False
        retries_left -= 1


def get_model_files(model_config, model_output_dir):
    """Resolve a model's model/weight/range files to local paths.

    ``model_output_dir`` is created if missing. http(s) paths are downloaded
    into it under an md5-of-URL name (".pb", ".caffemodel" or ".range").
    The model file, and the weight file when one is configured, must match
    its configured sha256 checksum. The resolved local paths are written
    back into ``model_config``. Failures are reported via MaceLogger.error.
    """
    def is_remote(path):
        return path.startswith("http://") or path.startswith("https://")

    if not os.path.exists(model_output_dir):
        os.makedirs(model_output_dir)

    model_url = model_config[YAMLKeyword.model_file_path]
    model_checksum = model_config[YAMLKeyword.model_sha256_checksum]
    weight_url = model_config.get(YAMLKeyword.weight_file_path, "")
    weight_checksum = model_config.get(YAMLKeyword.weight_sha256_checksum, "")
    range_url = model_config.get(YAMLKeyword.quantize_range_file, "")

    # Model graph file: re-download only if absent or its checksum differs.
    model_file = model_url
    if is_remote(model_url):
        model_file = model_output_dir + "/" + md5sum(model_url) + ".pb"
        if not os.path.exists(model_file) or \
                sha256_checksum(model_file) != model_checksum:
            MaceLogger.info("Downloading model, please wait ...")
            if not download_file(model_url, model_file):
                MaceLogger.error(ModuleName.MODEL_CONVERTER,
                                 "Model download failed.")
        model_config[YAMLKeyword.model_file_path] = model_file

    if sha256_checksum(model_file) != model_checksum:
        MaceLogger.error(ModuleName.MODEL_CONVERTER,
                         model_url
                         + " model file sha256checksum not match "
                         + model_checksum)

    # Weight file (optional): same download/verify policy as the model.
    weight_file = weight_url
    if is_remote(weight_url):
        weight_file = \
            model_output_dir + "/" + md5sum(weight_url) + ".caffemodel"
        if not os.path.exists(weight_file) or \
                sha256_checksum(weight_file) != weight_checksum:
            MaceLogger.info("Downloading model weight, please wait ...")
            if not download_file(weight_url, weight_file):
                MaceLogger.error(ModuleName.MODEL_CONVERTER,
                                 "Model download failed.")
    model_config[YAMLKeyword.weight_file_path] = weight_file

    if weight_file and sha256_checksum(weight_file) != weight_checksum:
        MaceLogger.error(ModuleName.MODEL_CONVERTER,
                         weight_url
                         + " weight file sha256checksum not match "
                         + weight_checksum)

    # Quantize range file (optional): always downloaded, no checksum.
    range_file = range_url
    if is_remote(range_url):
        range_file = \
            model_output_dir + "/" + md5sum(range_url) + ".range"
        if not download_file(range_url, range_file):
            MaceLogger.error(ModuleName.MODEL_CONVERTER,
                             "Model range file download failed.")
    model_config[YAMLKeyword.quantize_range_file] = range_file


332 333
def _format_subgraph(subgraph):
    """Validate one subgraph entry of a model config, normalizing in place.

    Scalars are wrapped into lists, tensor names/shapes/ranges are turned
    into strings, and optional keys receive defaults. Problems are reported
    through mace_check (ModuleName.YAML_CONFIG). A malformed
    validation_threshold raises argparse.ArgumentTypeError.

    Args:
        subgraph: dict from a model's ``subgraphs`` list; mutated in place.
    """
    # Mandatory tensor/shape lists.
    for key in [YAMLKeyword.input_tensors,
                YAMLKeyword.input_shapes,
                YAMLKeyword.output_tensors,
                YAMLKeyword.output_shapes]:
        value = subgraph.get(key, "")
        mace_check(value != "", ModuleName.YAML_CONFIG,
                   "'%s' is necessary in subgraph" % key)
        if not isinstance(value, list):
            subgraph[key] = [value]
        subgraph[key] = [str(v) for v in subgraph[key]]
    input_size = len(subgraph[YAMLKeyword.input_tensors])
    output_size = len(subgraph[YAMLKeyword.output_tensors])

    mace_check(len(subgraph[YAMLKeyword.input_shapes]) == input_size,
               ModuleName.YAML_CONFIG,
               "input shapes' size not equal inputs' size.")
    mace_check(len(subgraph[YAMLKeyword.output_shapes]) == output_size,
               ModuleName.YAML_CONFIG,
               "output shapes' size not equal outputs' size.")

    # Optional check tensors/shapes; default to empty lists.
    for key in [YAMLKeyword.check_tensors,
                YAMLKeyword.check_shapes]:
        value = subgraph.get(key, "")
        if value != "":
            if not isinstance(value, list):
                subgraph[key] = [value]
            subgraph[key] = [str(v) for v in subgraph[key]]
        else:
            subgraph[key] = []

    # Data types: a scalar applies to every tensor. The float32 default is
    # stored as its plain string value, like the other defaults filled in
    # by format_model_config (e.g. DSPDataType.uint8.value).
    for key, count in [(YAMLKeyword.input_data_types, input_size),
                       (YAMLKeyword.output_data_types, output_size)]:
        data_types = subgraph.get(key, "")
        if data_types:
            if not isinstance(data_types, list):
                subgraph[key] = [data_types] * count
            for data_type in subgraph[key]:
                mace_check(data_type in InOutDataTypeStrs,
                           ModuleName.YAML_CONFIG,
                           key + " must be in "
                           + str(InOutDataTypeStrs))
        else:
            subgraph[key] = [InOutDataType.float32.value] * count

    # Data formats: a scalar applies to every tensor; default is NHWC.
    input_data_formats = subgraph.get(YAMLKeyword.input_data_formats, [])
    if input_data_formats:
        if not isinstance(input_data_formats, list):
            subgraph[YAMLKeyword.input_data_formats] = \
                [input_data_formats] * input_size
        else:
            mace_check(len(input_data_formats) == input_size,
                       ModuleName.YAML_CONFIG,
                       "input_data_formats should match"
                       " the size of input.")
        for input_data_format in subgraph[YAMLKeyword.input_data_formats]:
            # str(): the message is built eagerly, so a non-string value
            # must not raise TypeError before mace_check can report it.
            mace_check(input_data_format in DataFormatStrs,
                       ModuleName.YAML_CONFIG,
                       "'input_data_formats' must be in "
                       + str(DataFormatStrs) + ", but got "
                       + str(input_data_format))
    else:
        subgraph[YAMLKeyword.input_data_formats] = \
            [DataFormat.NHWC] * input_size

    output_data_formats = subgraph.get(YAMLKeyword.output_data_formats, [])
    if output_data_formats:
        if not isinstance(output_data_formats, list):
            subgraph[YAMLKeyword.output_data_formats] = \
                [output_data_formats] * output_size
        else:
            mace_check(len(output_data_formats) == output_size,
                       ModuleName.YAML_CONFIG,
                       "output_data_formats should match"
                       " the size of output")
        for output_data_format in subgraph[YAMLKeyword.output_data_formats]:
            mace_check(output_data_format in DataFormatStrs,
                       ModuleName.YAML_CONFIG,
                       "'output_data_formats' must be in "
                       + str(DataFormatStrs) + ", but got "
                       + str(output_data_format))
    else:
        subgraph[YAMLKeyword.output_data_formats] = \
            [DataFormat.NHWC] * output_size

    # Similarity thresholds per device, with user overrides applied.
    validation_threshold = subgraph.get(
        YAMLKeyword.validation_threshold, {})
    if not isinstance(validation_threshold, dict):
        raise argparse.ArgumentTypeError(
            'similarity threshold must be a dict.')

    threshold_dict = {
        DeviceType.CPU: ValidationThreshold.cpu_threshold,
        DeviceType.GPU: ValidationThreshold.gpu_threshold,
        DeviceType.HEXAGON: ValidationThreshold.quantize_threshold,
        DeviceType.HTA: ValidationThreshold.quantize_threshold,
        DeviceType.APU: ValidationThreshold.quantize_threshold,
        DeviceType.QUANTIZE: ValidationThreshold.quantize_threshold,
    }
    for k, v in six.iteritems(validation_threshold):
        if k.upper() == 'DSP':
            k = DeviceType.HEXAGON
        # Every device with a default above (including APU) may be
        # overridden.
        if k.upper() not in (DeviceType.CPU,
                             DeviceType.GPU,
                             DeviceType.HEXAGON,
                             DeviceType.HTA,
                             DeviceType.APU,
                             DeviceType.QUANTIZE):
            raise argparse.ArgumentTypeError(
                'Unsupported validation threshold runtime: %s' % k)
        threshold_dict[k.upper()] = v

    subgraph[YAMLKeyword.validation_threshold] = threshold_dict

    validation_inputs_data = subgraph.get(
        YAMLKeyword.validation_inputs_data, [])
    if not isinstance(validation_inputs_data, list):
        subgraph[YAMLKeyword.validation_inputs_data] = [
            validation_inputs_data]
    else:
        subgraph[YAMLKeyword.validation_inputs_data] = \
            validation_inputs_data

    onnx_backend = subgraph.get(
        YAMLKeyword.backend, "tensorflow")
    subgraph[YAMLKeyword.backend] = onnx_backend

    validation_outputs_data = subgraph.get(
        YAMLKeyword.validation_outputs_data, [])
    if not isinstance(validation_outputs_data, list):
        subgraph[YAMLKeyword.validation_outputs_data] = [
            validation_outputs_data]
    else:
        subgraph[YAMLKeyword.validation_outputs_data] = \
            validation_outputs_data

    input_ranges = subgraph.get(
        YAMLKeyword.input_ranges, [])
    if not isinstance(input_ranges, list):
        subgraph[YAMLKeyword.input_ranges] = [input_ranges]
    else:
        subgraph[YAMLKeyword.input_ranges] = input_ranges
    subgraph[YAMLKeyword.input_ranges] = \
        [str(v) for v in subgraph[YAMLKeyword.input_ranges]]

    accuracy_validation_script = subgraph.get(
        YAMLKeyword.accuracy_validation_script, "")
    if isinstance(accuracy_validation_script, list):
        mace_check(len(accuracy_validation_script) == 1,
                   ModuleName.YAML_CONFIG,
                   "Only support one accuracy validation script")
        accuracy_validation_script = accuracy_validation_script[0]
    subgraph[YAMLKeyword.accuracy_validation_script] = \
        accuracy_validation_script


def format_model_config(flags):
    """Load the YAML deployment config ``flags.config``, then validate and
    normalize it.

    Non-empty ``flags.target_abis``, ``flags.target_socs`` (other than the
    random/all tags), ``flags.model_graph_format`` and
    ``flags.model_data_format`` override the file's values. For each model,
    its files are resolved via get_model_files and its settings and
    subgraphs are checked, with defaults filled in. Problems are reported
    through mace_check.

    Returns:
        The normalized config dict.
    """
    with open(flags.config) as f:
        # safe_load: yaml.load without a Loader fails on PyYAML >= 6 and
        # may construct arbitrary Python objects. The config is plain data.
        configs = yaml.safe_load(f)
    mace_check(isinstance(configs, dict), ModuleName.YAML_CONFIG,
               "config file must contain a YAML mapping")

    library_name = configs.get(YAMLKeyword.library_name, "")
    mace_check(len(library_name) > 0,
               ModuleName.YAML_CONFIG, "library name should not be empty")

    if flags.target_abis:
        target_abis = flags.target_abis.split(',')
    else:
        target_abis = configs.get(YAMLKeyword.target_abis, [])
    mace_check((isinstance(target_abis, list) and len(target_abis) > 0),
               ModuleName.YAML_CONFIG, "target_abis list is needed")
    configs[YAMLKeyword.target_abis] = target_abis
    for abi in target_abis:
        mace_check(abi in ABITypeStrs,
                   ModuleName.YAML_CONFIG,
                   "target_abis must be in " + str(ABITypeStrs))

    target_socs = configs.get(YAMLKeyword.target_socs, "")
    if flags.target_socs and flags.target_socs != TargetSOCTag.random \
            and flags.target_socs != TargetSOCTag.all:
        configs[YAMLKeyword.target_socs] = \
            [soc.lower() for soc in flags.target_socs.split(',')]
    elif not target_socs:
        configs[YAMLKeyword.target_socs] = []
    elif not isinstance(target_socs, list):
        configs[YAMLKeyword.target_socs] = [target_socs]

    configs[YAMLKeyword.target_socs] = \
        [soc.lower() for soc in configs[YAMLKeyword.target_socs]]

    # Android targets need the requested SoCs to be attached via adb.
    if ABIType.armeabi_v7a in target_abis \
            or ABIType.arm64_v8a in target_abis:
        available_socs = sh_commands.adb_get_all_socs()
        target_socs = configs[YAMLKeyword.target_socs]
        if TargetSOCTag.all in target_socs:
            mace_check(available_socs,
                       ModuleName.YAML_CONFIG,
                       "Android abi is listed in config file and "
                       "build for all SOCs plugged in computer, "
                       "But no android phone found, "
                       "you at least plug in one phone")
        else:
            for soc in target_socs:
                mace_check(soc in available_socs,
                           ModuleName.YAML_CONFIG,
                           "Build specified SOC library, "
                           "you must plug in a phone using the SOC")

    if flags.model_graph_format:
        model_graph_format = flags.model_graph_format
    else:
        model_graph_format = configs.get(YAMLKeyword.model_graph_format, "")
    mace_check(model_graph_format in ModelFormatStrs,
               ModuleName.YAML_CONFIG,
               'You must set model_graph_format and '
               "model_graph_format must be in " + str(ModelFormatStrs))
    configs[YAMLKeyword.model_graph_format] = model_graph_format
    if flags.model_data_format:
        model_data_format = flags.model_data_format
    else:
        model_data_format = configs.get(YAMLKeyword.model_data_format, "")
    configs[YAMLKeyword.model_data_format] = model_data_format
    mace_check(model_data_format in ModelFormatStrs,
               ModuleName.YAML_CONFIG,
               'You must set model_data_format and '
               "model_data_format must be in " + str(ModelFormatStrs))

    mace_check(not (model_graph_format == ModelFormat.file
                    and model_data_format == ModelFormat.code),
               ModuleName.YAML_CONFIG,
               "If model_graph format is 'file',"
               " the model_data_format must be 'file' too")

    model_names = configs.get(YAMLKeyword.models, [])
    mace_check(len(model_names) > 0, ModuleName.YAML_CONFIG,
               "no model found in config file")

    model_name_reg = re.compile(r'^[a-zA-Z0-9_]+$')
    for model_name in model_names:
        # check model_name legality
        mace_check(model_name not in CPP_KEYWORDS,
                   ModuleName.YAML_CONFIG,
                   "model name should not be c++ keyword.")
        mace_check((model_name[0] == '_' or model_name[0].isalpha())
                   and bool(model_name_reg.match(model_name)),
                   ModuleName.YAML_CONFIG,
                   "model name should Meet the c++ naming convention"
                   " which start with '_' or alpha"
                   " and only contain alpha, number and '_'")

        model_config = configs[YAMLKeyword.models][model_name]
        platform = model_config.get(YAMLKeyword.platform, "")
        mace_check(platform in PlatformTypeStrs,
                   ModuleName.YAML_CONFIG,
                   "'platform' must be in " + str(PlatformTypeStrs))

        for key in [YAMLKeyword.model_file_path,
                    YAMLKeyword.model_sha256_checksum]:
            value = model_config.get(key, "")
            mace_check(value != "", ModuleName.YAML_CONFIG,
                       "'%s' is necessary" % key)

        weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "")
        if weight_file_path:
            weight_checksum = \
                model_config.get(YAMLKeyword.weight_sha256_checksum, "")
            mace_check(weight_checksum != "", ModuleName.YAML_CONFIG,
                       "'%s' is necessary" %
                       YAMLKeyword.weight_sha256_checksum)
        else:
            model_config[YAMLKeyword.weight_sha256_checksum] = ""

        get_model_files(model_config, BUILD_DOWNLOADS_DIR)

        runtime = model_config.get(YAMLKeyword.runtime, "")
        mace_check(runtime in RuntimeTypeStrs,
                   ModuleName.YAML_CONFIG,
                   "'runtime' must be in " + str(RuntimeTypeStrs))
        if ABIType.host in target_abis:
            mace_check(runtime == RuntimeType.cpu,
                       ModuleName.YAML_CONFIG,
                       "host only support cpu runtime now.")

        # data_type: validate per runtime family, or fill in its default.
        data_type = model_config.get(YAMLKeyword.data_type, "")
        if runtime == RuntimeType.dsp:
            if len(data_type) > 0:
                mace_check(data_type in DSPDataTypeStrs,
                           ModuleName.YAML_CONFIG,
                           "'data_type' must be in " + str(DSPDataTypeStrs)
                           + " for dsp runtime")
            else:
                model_config[YAMLKeyword.data_type] = \
                    DSPDataType.uint8.value
        elif runtime == RuntimeType.apu:
            if len(data_type) > 0:
                mace_check(data_type in APUDataTypeStrs,
                           ModuleName.YAML_CONFIG,
                           "'data_type' must be in " + str(APUDataTypeStrs)
                           + " for apu runtime")
            else:
                model_config[YAMLKeyword.data_type] = \
                    APUDataType.uint8.value
        else:
            if len(data_type) > 0:
                mace_check(data_type in FPDataTypeStrs,
                           ModuleName.YAML_CONFIG,
                           "'data_type' must be in " + str(FPDataTypeStrs)
                           + " for %s runtime" % runtime)
            else:
                if runtime == RuntimeType.cpu:
                    model_config[YAMLKeyword.data_type] = \
                        FPDataType.fp32_fp32.value
                else:
                    model_config[YAMLKeyword.data_type] = \
                        FPDataType.fp16_fp32.value

        subgraphs = model_config.get(YAMLKeyword.subgraphs, "")
        mace_check(len(subgraphs) > 0, ModuleName.YAML_CONFIG,
                   "at least one subgraph is needed")
        for subgraph in subgraphs:
            _format_subgraph(subgraph)

        # Integer switches default to 0 when absent or empty.
        for key in [YAMLKeyword.limit_opencl_kernel_time,
                    YAMLKeyword.nnlib_graph_mode,
                    YAMLKeyword.obfuscate,
                    YAMLKeyword.winograd,
                    YAMLKeyword.quantize,
                    YAMLKeyword.quantize_large_weights,
                    YAMLKeyword.change_concat_ranges]:
            value = model_config.get(key, "")
            if value == "":
                model_config[key] = 0

        mace_check(model_config[YAMLKeyword.quantize] == 0 or
                   model_config[YAMLKeyword.quantize_large_weights] == 0,
                   ModuleName.YAML_CONFIG,
                   "quantize and quantize_large_weights should not be set to 1"
                   " at the same time.")

        mace_check(model_config[YAMLKeyword.winograd] in WinogradParameters,
                   ModuleName.YAML_CONFIG,
                   "'winograd' parameters must be in "
                   + str(WinogradParameters) +
                   ". 0 for disable winograd convolution")

    return configs
Y
yejianwu 已提交
680

W
wuchenghui 已提交
681

682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702
def clear_build_dirs(library_name):
    """Prepare the build output tree for ``library_name``.

    Creates BUILD_OUTPUT_DIR if needed, and recreates an empty temporary
    build directory for the library. Deletes the library's previous output
    library directory.
    """
    if not os.path.exists(BUILD_OUTPUT_DIR):
        os.makedirs(BUILD_OUTPUT_DIR)

    library_dir = os.path.join(BUILD_OUTPUT_DIR, library_name)

    # Start every build from an empty temp dir.
    tmp_dir = os.path.join(library_dir, BUILD_TMP_DIR_NAME)
    if os.path.exists(tmp_dir):
        sh.rm('-rf', tmp_dir)
    os.makedirs(tmp_dir)

    # Drop stale library artifacts; the directory is not recreated here.
    lib_dir = os.path.join(library_dir, OUTPUT_LIBRARY_DIR_NAME)
    if os.path.exists(lib_dir):
        sh.rm('-rf', lib_dir)


################################
# convert
################################
def print_configuration(configs):
    """Log a key/value summary table of the library-wide settings."""
    shown_keys = (
        YAMLKeyword.library_name,
        YAMLKeyword.target_abis,
        YAMLKeyword.target_socs,
        YAMLKeyword.model_graph_format,
        YAMLKeyword.model_data_format,
    )
    rows = [[key, configs[key]] for key in shown_keys]
    MaceLogger.summary(StringFormatter.table(["key", "value"], rows,
                                             "Common Configuration"))
L
Liangliang He 已提交
717

Y
yejianwu 已提交
718

719
def build_model_lib(configs, address_sanitizer, debug_mode):
    """Build the model library for each target ABI and copy it into place.

    For every ABI listed in ``configs[YAMLKeyword.target_abis]`` this
    creates the ABI's output directory when missing, invokes
    ``sh_commands.bazel_build`` on MODEL_LIB_TARGET with feature switches
    derived from ``configs``, and copies MODEL_LIB_PATH to the path given by
    ``get_model_lib_output_path``.

    Args:
        configs: parsed YAML configuration dict.
        address_sanitizer: forwarded to the build.
        debug_mode: forwarded to the build; also selects symbol hiding.
    """
    MaceLogger.header(StringFormatter.block("Building model library"))

    library_name = configs[YAMLKeyword.library_name]
    for abi in configs[YAMLKeyword.target_abis]:
        output_path = get_model_lib_output_path(library_name, abi)
        output_dir = os.path.dirname(output_path)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # Keyword values are evaluated in the same order as a direct call.
        build_options = dict(
            abi=abi,
            toolchain=infer_toolchain(abi),
            enable_hexagon=get_hexagon_mode(configs),
            enable_hta=get_hta_mode(configs),
            enable_apu=get_apu_mode(configs),
            enable_opencl=get_opencl_mode(configs),
            enable_quantize=get_quantize_mode(configs),
            address_sanitizer=address_sanitizer,
            symbol_hidden=get_symbol_hidden_mode(debug_mode),
            debug_mode=debug_mode,
        )
        sh_commands.bazel_build(MODEL_LIB_TARGET, **build_options)

        sh.cp("-f", MODEL_LIB_PATH, output_path)
746 747 748 749 750 751 752


def print_library_summary(configs):
    """Log a "Library" table with the converted model's output paths.

    The header-directory row is included only when
    ``configs[YAMLKeyword.model_graph_format]`` is ``ModelFormat.code``.
    """
    prefix = "%s/%s" % (BUILD_OUTPUT_DIR, configs[YAMLKeyword.library_name])
    rows = [["MACE Model Path", "%s/%s" % (prefix, MODEL_OUTPUT_DIR_NAME)]]
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        rows.append(["MACE Model Header Path",
                     "%s/%s" % (prefix, MODEL_HEADER_DIR_PATH)])
    table = StringFormatter.table(["key", "value"], rows, "Library")
    MaceLogger.summary(table)


765
def _resolve_model_format(flag_value, configs, key):
    """Return ``flag_value`` if non-empty, else ``configs.get(key, "file")``."""
    if flag_value:
        return flag_value
    return configs.get(key, "file")


def convert_func(flags):
    """Handle the ``convert`` sub-command.

    Parses the YAML config named by ``flags.config`` and recreates the
    library's output directories. Stale codegen directories are removed.
    Every model is then converted and encrypted, and the resulting
    ``.pb``/``.data`` files or generated headers are moved into
    ``<BUILD_OUTPUT_DIR>/<library_name>``. For the ``code`` graph format,
    the model library is also built. Finally, a summary of the output
    locations is logged.

    Args:
        flags: parsed CLI namespace. Reads ``config``, ``model_graph_format``,
            ``model_data_format``, ``address_sanitizer`` and ``debug_mode``.
            Non-empty ``model_graph_format``/``model_data_format`` override
            the YAML values.
    """
    configs = config_parser.parse(flags.config)
    library_name = configs[YAMLKeyword.library_name]
    library_dir = os.path.join(BUILD_OUTPUT_DIR, library_name)
    if not os.path.exists(BUILD_OUTPUT_DIR):
        os.makedirs(BUILD_OUTPUT_DIR)
    elif os.path.exists(library_dir):
        sh.rm("-rf", library_dir)
    os.makedirs(library_dir)
    if not os.path.exists(BUILD_DOWNLOADS_DIR):
        os.makedirs(BUILD_DOWNLOADS_DIR)

    model_output_dir = \
        '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_OUTPUT_DIR_NAME)
    model_header_dir = \
        '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_HEADER_DIR_PATH)
    # Start from an empty output dir; the header dir is only recreated
    # below when the graph format is ``code``.
    if os.path.exists(model_output_dir):
        sh.rm("-rf", model_output_dir)
    os.makedirs(model_output_dir)
    if os.path.exists(model_header_dir):
        sh.rm("-rf", model_header_dir)

    # Drop code generated by previous runs.
    for codegen_dir in (MODEL_CODEGEN_DIR, ENGINE_CODEGEN_DIR):
        if os.path.exists(codegen_dir):
            sh.rm("-rf", codegen_dir)

    model_data_format = _resolve_model_format(
        flags.model_data_format, configs, YAMLKeyword.model_data_format)
    model_graph_format = _resolve_model_format(
        flags.model_graph_format, configs, YAMLKeyword.model_graph_format)
    embed_model_data = model_data_format == ModelFormat.code

    if model_graph_format == ModelFormat.code:
        os.makedirs(model_header_dir)
        sh_commands.gen_mace_engine_factory_source(
            configs[YAMLKeyword.models].keys(),
            embed_model_data)
        sh.cp("-f", glob.glob("mace/codegen/engine/*.h"),
              model_header_dir)

    convert.convert(configs, MODEL_CODEGEN_DIR)

    for model_name, model_config in configs[YAMLKeyword.models].items():
        model_codegen_dir = "%s/%s" % (MODEL_CODEGEN_DIR, model_name)
        encrypt.encrypt(model_name,
                        "%s/%s.pb" % (model_codegen_dir, model_name),
                        "%s/%s.data" % (model_codegen_dir, model_name),
                        model_config[YAMLKeyword.runtime],
                        model_codegen_dir,
                        bool(model_config[YAMLKeyword.obfuscate]))

        if model_graph_format == ModelFormat.file:
            sh.mv("-f",
                  '%s/file/%s.pb' % (model_codegen_dir, model_name),
                  model_output_dir)
            sh.mv("-f",
                  '%s/file/%s.data' % (model_codegen_dir, model_name),
                  model_output_dir)
            sh.rm("-rf", '%s/code' % model_codegen_dir)
        else:
            if not embed_model_data:
                # Data is shipped as a separate .data file, so the generated
                # tensor_data.cc is removed from the code output.
                sh.mv("-f",
                      '%s/file/%s.data' % (model_codegen_dir, model_name),
                      model_output_dir)
                sh.rm('%s/code/tensor_data.cc' % model_codegen_dir)

            sh.cp("-f", glob.glob("mace/codegen/models/*/code/*.h"),
                  model_header_dir)

        MaceLogger.summary(
            StringFormatter.block("Model %s converted" % model_name))

    if model_graph_format == ModelFormat.code:
        build_model_lib(configs, flags.address_sanitizer, flags.debug_mode)

    # Record the effective (possibly CLI-overridden) formats so the summary
    # describes what was actually produced rather than the raw YAML values.
    configs[YAMLKeyword.model_graph_format] = model_graph_format
    configs[YAMLKeyword.model_data_format] = model_data_format
    print_library_summary(configs)


################################
# run
################################
L
liuqi 已提交
853
def build_mace_run(configs, target_abi, toolchain, enable_openmp,
                   address_sanitizer, mace_lib_type, debug_mode):
    """Build the ``mace_run`` binary for ``target_abi``.

    The ABI's build binary directory is deleted and recreated. Either the
    static or the dynamic ``mace_run`` target is built via
    ``sh_commands.bazel_build``. The directory is then handed to
    ``sh_commands.update_mace_run_binary``. When the graph format is
    ``code``, ENGINE_CODEGEN_DIR must already exist (``mace_check`` fails
    with "You should convert model first." otherwise), and mace_run.cc is
    compiled with ``-DMODEL_GRAPH_FORMAT_CODE``.
    """
    binary_dir = get_build_binary_dir(configs[YAMLKeyword.library_name],
                                      target_abi)
    if os.path.exists(binary_dir):
        sh.rm("-rf", binary_dir)
    os.makedirs(binary_dir)

    link_dynamic = mace_lib_type == MACELibType.dynamic
    target = (MACE_RUN_DYNAMIC_TARGET if link_dynamic
              else MACE_RUN_STATIC_TARGET)

    extra_args = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                   ModuleName.RUN,
                   "You should convert model first.")
        # Adjacent literals join at compile time into the same flag string.
        extra_args = ("--per_file_copt=mace/tools/mace_run.cc@"
                      "-DMODEL_GRAPH_FORMAT_CODE")

    sh_commands.bazel_build(
        target,
        abi=target_abi,
        toolchain=toolchain,
        enable_hexagon=get_hexagon_mode(configs),
        enable_hta=get_hta_mode(configs),
        enable_apu=get_apu_mode(configs),
        enable_openmp=enable_openmp,
        enable_opencl=get_opencl_mode(configs),
        enable_quantize=get_quantize_mode(configs),
        address_sanitizer=address_sanitizer,
        symbol_hidden=get_symbol_hidden_mode(debug_mode, mace_lib_type),
        debug_mode=debug_mode,
        extra_args=extra_args)
    sh_commands.update_mace_run_binary(binary_dir, link_dynamic)


L
liuqi 已提交
891 892 893 894 895 896 897 898 899 900
def print_package_summary(package_path):
    """Log ``package_path`` as the single row of a "Library" table."""
    summary = StringFormatter.table(
        ["key", "value"],
        [["MACE Model package Path", package_path]],
        "Library")
    MaceLogger.summary(summary)


901
def run_mace(flags):
    """Handle the ``run`` sub-command.

    Formats the model config and clears the library's build dirs. For each
    target ABI, ``mace_run`` is built and executed on every selected device
    that lists that ABI, and the elapsed time is printed. Devices are
    filtered by the configured target SOCs. With ``--target_socs`` equal to
    ``TargetSOCTag.random``, the device list comes from
    ``sh_commands.choose_a_random_device``. Non-host devices that lack the
    ABI are reported on stderr. Finally, the build output is packaged and
    its path is logged.
    """
    configs = format_model_config(flags)

    clear_build_dirs(configs[YAMLKeyword.library_name])

    target_socs = configs[YAMLKeyword.target_socs]
    device_list = DeviceManager.list_devices(flags.device_yml)
    if target_socs and TargetSOCTag.all not in target_socs:
        device_list = [dev for dev in device_list
                       if dev[YAMLKeyword.target_socs].lower() in target_socs]
    for target_abi in configs[YAMLKeyword.target_abis]:
        if flags.target_socs == TargetSOCTag.random:
            target_devices = sh_commands.choose_a_random_device(
                device_list, target_abi)
        else:
            target_devices = device_list
        # build target
        for dev in target_devices:
            if target_abi in dev[YAMLKeyword.target_abis]:
                # get toolchain
                toolchain = infer_toolchain(target_abi)
                device = DeviceWrapper(dev)
                build_mace_run(configs,
                               target_abi,
                               toolchain,
                               flags.enable_openmp,
                               flags.address_sanitizer,
                               flags.mace_lib_type,
                               flags.debug_mode)
                # run
                start_time = time.time()
                with device.lock():
                    device.run_specify_abi(flags, configs, target_abi)
                elapse_minutes = (time.time() - start_time) / 60
                print("Elapse time: %f minutes." % elapse_minutes)
            elif dev[YAMLKeyword.device_name] != SystemType.host:
                # Builtin print (print_function is imported at file top);
                # avoids relying on ``six`` leaking in via a star import.
                print('The device with soc %s do not support abi %s' %
                      (dev[YAMLKeyword.target_socs], target_abi),
                      file=sys.stderr)

    # package the output files
    package_path = sh_commands.packaging_lib(BUILD_OUTPUT_DIR,
                                             configs[YAMLKeyword.library_name])
    print_package_summary(package_path)

946

L
liuqi 已提交
947
################################
Y
yejianwu 已提交
948
# parsing arguments
L
liuqi 已提交
949 950 951 952 953 954 955 956 957 958 959 960
################################
def str2bool(v):
    """Convert a yes/no style string to bool, for use as an argparse type.

    Case-insensitively accepts yes/true/t/y/1 (True) and no/false/f/n/0
    (False).

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    answer = v.lower()
    if answer in ('yes', 'true', 't', 'y', '1'):
        return True
    if answer in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def str_to_caffe_env_type(v):
    """Parse a ``--caffe_env`` value (case-insensitive) into CaffeEnvType.

    Raises:
        argparse.ArgumentTypeError: unless ``v`` is "docker" or "local".
    """
    env = v.lower()
    if env == 'docker':
        return CaffeEnvType.DOCKER
    if env == 'local':
        return CaffeEnvType.LOCAL
    raise argparse.ArgumentTypeError('[docker | local] expected.')


968 969 970 971 972 973 974 975 976
def str_to_mace_lib_type(v):
    """Parse a ``--mace_lib_type`` value (case-insensitive) into MACELibType.

    Raises:
        argparse.ArgumentTypeError: unless ``v`` is "dynamic" or "static".
    """
    lib_type = v.lower()
    if lib_type == 'dynamic':
        return MACELibType.dynamic
    if lib_type == 'static':
        return MACELibType.static
    raise argparse.ArgumentTypeError('[dynamic| static] expected.')


977
def parse_args():
    """Parses command line arguments.

    Defines two required sub-commands:
      * ``convert``: handled by ``convert_func``.
      * ``run``: handled by ``run_mace``.
    Each stores its handler in the namespace's ``func`` attribute.

    Returns:
        The ``(namespace, remaining_args)`` tuple from
        ``ArgumentParser.parse_known_args``. Unrecognized arguments are
        returned in the second element instead of causing an error.
    """
    # Options shared by every sub-command.
    all_type_parent_parser = argparse.ArgumentParser(add_help=False)
    all_type_parent_parser.add_argument(
        '--config',
        type=str,
        required=True,
        help="the path of model yaml configuration file.")
    all_type_parent_parser.add_argument(
        "--model_graph_format",
        type=str,
        default="",
        help="[file, code], MACE Model graph format.")
    all_type_parent_parser.add_argument(
        "--model_data_format",
        type=str,
        default="",
        help="['file', 'code'], MACE Model data format.")
    all_type_parent_parser.add_argument(
        "--target_abis",
        type=str,
        default="",
        help="Target ABIs, comma separated list.")
    all_type_parent_parser.add_argument(
        "--target_socs",
        type=str,
        default="",
        help="Target SOCs, comma separated list.")
    all_type_parent_parser.add_argument(
        "--debug_mode",
        action="store_true",
        help="Reserve debug symbols.")

    # Options shared by ``convert`` and ``run``.
    convert_run_parent_parser = argparse.ArgumentParser(add_help=False)
    convert_run_parent_parser.add_argument(
        '--address_sanitizer',
        action="store_true",
        help="Whether to use address sanitizer to check memory error")

    parser = argparse.ArgumentParser()
    # Make the sub-command mandatory: on Python 3 subparsers are optional
    # by default, and a bare invocation would otherwise fail later on the
    # missing ``func`` attribute. ``dest`` is needed for a proper error.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    # Local names avoid shadowing the imported ``convert`` module.
    convert_parser = subparsers.add_parser(
        'convert',
        parents=[all_type_parent_parser, convert_run_parent_parser],
        help='convert to mace model (file or code)')
    convert_parser.set_defaults(func=convert_func)

    run_parser = subparsers.add_parser(
        'run',
        parents=[all_type_parent_parser,
                 convert_run_parent_parser],
        help='run model in command line')
    run_parser.set_defaults(func=run_mace)
    run_parser.add_argument(
        "--mace_lib_type",
        type=str_to_mace_lib_type,
        default=DefaultValues.mace_lib_type,
        help="[static | dynamic], which type of MACE library to use.")
    run_parser.add_argument(
        "--enable_openmp",
        action="store_true",
        help="Enable openmp for multiple thread.")
    run_parser.add_argument(
        "--omp_num_threads",
        type=int,
        default=DefaultValues.omp_num_threads,
        help="num of openmp threads")
    run_parser.add_argument(
        "--cpu_affinity_policy",
        type=int,
        default=DefaultValues.cpu_affinity_policy,
        help="0:AFFINITY_NONE/1:AFFINITY_BIG_ONLY/2:AFFINITY_LITTLE_ONLY")
    run_parser.add_argument(
        "--gpu_perf_hint",
        type=int,
        default=DefaultValues.gpu_perf_hint,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    run_parser.add_argument(
        "--gpu_priority_hint",
        type=int,
        default=DefaultValues.gpu_priority_hint,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    run_parser.add_argument(
        "--device_yml",
        type=str,
        default='',
        help='embedded linux device config yml file'
    )
    run_parser.add_argument(
        "--disable_tuning",
        action="store_true",
        help="Disable tuning for specific thread.")
    run_parser.add_argument(
        "--round",
        type=int,
        default=1,
        help="The model running round.")
    run_parser.add_argument(
        "--validate",
        action="store_true",
        help="whether to verify the results are consistent with "
             "the frameworks.")
    run_parser.add_argument(
        "--layers",
        type=str,
        default="-1",
        help="'start_layer:end_layer' or 'layer', similar to python slice."
             " Use with --validate flag.")
    # The string default is passed through str_to_caffe_env_type by argparse.
    run_parser.add_argument(
        "--caffe_env",
        type=str_to_caffe_env_type,
        default='docker',
        help="[docker | local] caffe environment used for validation:"
             " a local installation or the caffe docker image.")
    run_parser.add_argument(
        "--vlog_level",
        type=int,
        default=0,
        help="[1~5]. Verbose log level for debug.")
    run_parser.add_argument(
        "--gpu_out_of_range_check",
        action="store_true",
        help="Enable out of range check for gpu.")
    run_parser.add_argument(
        "--restart_round",
        type=int,
        default=1,
        help="restart round between run.")
    run_parser.add_argument(
        "--report",
        action="store_true",
        help="print run statistics report.")
    run_parser.add_argument(
        "--report_dir",
        type=str,
        default="",
        help="directory for the run statistics report.")
    run_parser.add_argument(
        "--runtime_failure_ratio",
        type=float,
        default=0.0,
        help="[mock runtime failure ratio].")
    run_parser.add_argument(
        "--quantize_stat",
        action="store_true",
        help="whether to stat quantization range.")
    run_parser.add_argument(
        "--input_dir",
        type=str,
        default="",
        help="quantize stat input dir.")
    run_parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="quantize stat output dir.")
    run_parser.add_argument(
        "--cl_binary_to_code",
        action="store_true",
        help="convert OpenCL binaries to cpp.")
    run_parser.add_argument(
        "--benchmark",
        action="store_true",
        help="enable op benchmark.")
    return parser.parse_known_args()

1143

Y
yejianwu 已提交
1144
if __name__ == "__main__":
    # parse_args() returns (namespace, leftover args); leftovers are unused.
    flags, unparsed = parse_args()
    # Each sub-command registers its handler as ``func``.
    handler = flags.func
    handler(flags)