converter.py 45.3 KB
Newer Older
L
Liangliang He 已提交
1
# Copyright 2018 The MACE Authors. All Rights Reserved.
Y
yejianwu 已提交
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

15 16 17 18
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

19
import argparse
L
liuqi 已提交
20
import glob
L
Liangliang He 已提交
21
import sh
22
import sys
23
import time
Y
yejianwu 已提交
24
import yaml
25
import sh_commands
26
from enum import Enum
L
Liangliang He 已提交
27

L
liyin 已提交
28
sys.path.insert(0, "tools/python")  # noqa
L
liuqi 已提交
29 30
from common import *
from device import DeviceWrapper, DeviceManager
L
liyin 已提交
31 32 33 34 35
from utils import config_parser
import convert
import encrypt

from dana.dana_util import DanaUtil
36

Y
yejianwu 已提交
37 38 39 40
################################
# set environment
################################
# Presumably silences TensorFlow's C++ INFO/WARNING log output while the
# converter runs (level '2' keeps errors) -- confirm against the TF version.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
A
Allen 已提交
41

42 43 44 45 46
################################
# common definitions
################################

# ABIs accepted in the ``target_abis`` config entry / --target_abis flag.
ABITypeStrs = [
    'armeabi-v7a', 'arm64-v8a',
    'arm64', 'armhf',
    'host',
]

# Accepted values of ``model_graph_format`` and ``model_data_format``.
ModelFormatStrs = ["file", "code"]

# Source frameworks accepted in a model's ``platform`` entry.
PlatformTypeStrs = [
    "tensorflow", "caffe", "onnx", "megengine", "keras", "pytorch",
]
# str-valued Enum whose member names equal their values.
PlatformType = Enum('PlatformType',
                    list(zip(PlatformTypeStrs, PlatformTypeStrs)),
                    type=str)

# Runtimes accepted in a model's ``runtime`` entry.
RuntimeTypeStrs = ["cpu", "gpu", "dsp", "hta", "apu", "cpu+gpu"]

# Data types accepted for subgraph inputs/outputs.
InOutDataTypeStrs = ["int32", "float32"]

# NOTE: the Enum display name ('InputDataType') differs from the variable
# name; it is kept because it appears in repr()/str() of the members.
InOutDataType = Enum('InputDataType',
                     list(zip(InOutDataTypeStrs, InOutDataTypeStrs)),
                     type=str)

# Floating-point data types accepted for non-dsp/apu runtimes.
FPDataTypeStrs = ["fp16_fp32", "fp32_fp32", "bf16_fp32", "fp16_fp16"]

# NOTE: display name 'GPUDataType' kept for the same reason as above.
FPDataType = Enum('GPUDataType',
                  list(zip(FPDataTypeStrs, FPDataTypeStrs)),
                  type=str)

# Data types accepted for the dsp runtime.
DSPDataTypeStrs = ["uint8"]

DSPDataType = Enum('DSPDataType',
                   list(zip(DSPDataTypeStrs, DSPDataTypeStrs)),
                   type=str)

# Data types accepted for the apu runtime.
APUDataTypeStrs = ["uint8"]

APUDataType = Enum('APUDataType',
                   list(zip(APUDataTypeStrs, APUDataTypeStrs)),
                   type=str)

# Allowed ``winograd`` values; 0 disables winograd convolution.
WinogradParameters = [0, 2, 4]

# Allowed entries of ``input_data_formats`` / ``output_data_formats``.
DataFormatStrs = ["NONE", "NHWC", "NCHW", "OIHW"]


122
class DefaultValues(object):
    """Default values for runtime options.

    NOTE(review): presumably consumed as command-line defaults elsewhere in
    this file -- confirm against the argument parser.
    """
    mace_lib_type = MACELibType.static
    # Plain scalars on purpose: a trailing comma (``num_threads = -1,``)
    # would silently make each value a one-element tuple such as ``(-1,)``.
    num_threads = -1
    cpu_affinity_policy = 1
    gpu_perf_hint = 3
    gpu_priority_hint = 3
    apu_cache_policy = 0
129 130


131 132 133
class ValidationThreshold(object):
    """Default per-device similarity thresholds for output validation.

    These seed the ``validation_threshold`` dict built for each subgraph;
    the config file may override them per device.
    """
    # Plain floats: a trailing comma would turn each threshold into a
    # one-element tuple, which cannot be compared with a float similarity.
    cpu_threshold = 0.999
    gpu_threshold = 0.995
    quantize_threshold = 0.980
135 136


137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
# Names a model may not use, because model names become C++ identifiers in
# generated code. Each entry appears exactly once.
CPP_KEYWORDS = [
    # C++ keywords (including C++20 additions and TS keywords).
    'alignas', 'alignof', 'and', 'and_eq', 'asm', 'atomic_cancel',
    'atomic_commit', 'atomic_noexcept', 'auto', 'bitand', 'bitor',
    'bool', 'break', 'case', 'catch', 'char', 'char8_t', 'char16_t',
    'char32_t', 'class', 'compl', 'concept', 'const', 'consteval',
    'constexpr', 'constinit', 'const_cast', 'continue', 'co_await',
    'co_return', 'co_yield', 'decltype', 'default', 'delete', 'do',
    'double', 'dynamic_cast', 'else', 'enum', 'explicit', 'export',
    'extern', 'false', 'float', 'for', 'friend', 'goto', 'if', 'import',
    'inline', 'int', 'long', 'module', 'mutable', 'namespace', 'new',
    'noexcept', 'not', 'not_eq', 'nullptr', 'operator', 'or', 'or_eq',
    'private', 'protected', 'public', 'register', 'reinterpret_cast',
    'requires', 'return', 'short', 'signed', 'sizeof', 'static',
    'static_assert', 'static_cast', 'struct', 'switch', 'synchronized',
    'template', 'this', 'thread_local', 'throw', 'true', 'try', 'typedef',
    'typeid', 'typename', 'union', 'unsigned', 'using', 'virtual', 'void',
    'volatile', 'wchar_t', 'while', 'xor', 'xor_eq',
    # Identifiers with special meaning.
    'override', 'final', 'transaction_safe', 'transaction_safe_dynamic',
    # Preprocessor directives/operators ('if'/'else' are listed above).
    'elif', 'elifdef', 'elifndef', 'endif', 'defined', 'ifdef', 'ifndef',
    'define', 'undef', 'include', 'line', 'error', 'pragma',
]
Y
yejianwu 已提交
157

158

159 160 161
################################
# common functions
################################
162
def parse_device_type(runtime):
    """Map a runtime value to the matching DeviceType.

    dsp maps to HEXAGON; hta, gpu, cpu and apu map to the same-named device.
    Any other runtime (e.g. cpu+gpu) yields an empty string.
    """
    if runtime == RuntimeType.dsp:
        return DeviceType.HEXAGON
    if runtime == RuntimeType.hta:
        return DeviceType.HTA
    if runtime == RuntimeType.gpu:
        return DeviceType.GPU
    if runtime == RuntimeType.cpu:
        return DeviceType.CPU
    if runtime == RuntimeType.apu:
        return DeviceType.APU
    return ""
177

Y
yejianwu 已提交
178

L
luxuhui 已提交
179 180 181 182 183 184 185 186 187
def bfloat16_enabled(configs):
    """Return True if any model sets ``data_type`` to bf16_fp32.

    Models without a ``data_type`` entry are treated as fp16_fp32.
    """
    models = configs[YAMLKeyword.models]
    return any(
        models[name].get(YAMLKeyword.data_type, FPDataType.fp16_fp32)
        == FPDataType.bf16_fp32
        for name in models)


L
lichao18 已提交
188 189 190 191 192 193 194 195 196
def fp16_enabled(configs):
    """Return True if any model sets ``data_type`` to fp16_fp16.

    Models without a ``data_type`` entry are treated as fp16_fp32.
    """
    models = configs[YAMLKeyword.models]
    return any(
        models[name].get(YAMLKeyword.data_type, FPDataType.fp16_fp32)
        == FPDataType.fp16_fp16
        for name in models)


L
luxuhui 已提交
197
def hexagon_enabled(configs):
    """Return True if any model's (lower-cased) runtime is dsp."""
    models = configs[YAMLKeyword.models]
    runtimes = [models[name].get(YAMLKeyword.runtime, "").lower()
                for name in models]
    return RuntimeType.dsp in runtimes


L
luxuhui 已提交
210
def hta_enabled(configs):
    """Return True if any model's (lower-cased) runtime is hta."""
    models = configs[YAMLKeyword.models]
    runtimes = [models[name].get(YAMLKeyword.runtime, "").lower()
                for name in models]
    return RuntimeType.hta in runtimes


L
luxuhui 已提交
223
def apu_enabled(configs):
    """Return True if any model's (lower-cased) runtime is apu."""
    models = configs[YAMLKeyword.models]
    runtimes = [models[name].get(YAMLKeyword.runtime, "").lower()
                for name in models]
    return RuntimeType.apu in runtimes


L
luxuhui 已提交
236
def opencl_enabled(configs):
    """Return True if any model's (lower-cased) runtime is gpu, cpu+gpu or hta.

    Models without a ``runtime`` entry never match.
    """
    models = configs[YAMLKeyword.models]
    runtimes = [models[name].get(YAMLKeyword.runtime, "").lower()
                for name in models]
    opencl_runtimes = (RuntimeType.gpu, RuntimeType.cpu_gpu, RuntimeType.hta)
    return any(wanted in runtimes for wanted in opencl_runtimes)


L
luxuhui 已提交
250
def quantize_enabled(configs):
    """Return True if any model has ``quantize`` set to 1 (missing means 0)."""
    models = configs[YAMLKeyword.models]
    return any(models[name].get(YAMLKeyword.quantize, 0) == 1
               for name in models)


261 262 263 264 265 266 267 268 269
def get_symbol_hidden_mode(debug_mode, mace_lib_type=None):
    """Decide whether symbols should be hidden in the built library.

    Symbols are hidden unless a library type is given and either debug mode
    is on or the library type is dynamic.
    """
    if not mace_lib_type:
        return True
    if debug_mode:
        return False
    return not (mace_lib_type == MACELibType.dynamic)


270 271
def md5sum(str):
    """Return the hex MD5 digest of ``str`` encoded as UTF-8.

    Used here to derive cache file names from download URLs; the parameter
    name shadows the builtin but is kept for caller compatibility.
    """
    return hashlib.md5(str.encode('utf-8')).hexdigest()
274

Y
yejianwu 已提交
275

276 277 278 279 280 281
def sha256_checksum(fname):
    """Return the hex SHA-256 digest of the file at ``fname``.

    The file is read in 4096-byte chunks rather than all at once.
    """
    digest = hashlib.sha256()
    with open(fname, "rb") as stream:
        while True:
            block = stream.read(4096)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
Y
yejianwu 已提交
282

W
wuchenghui 已提交
283

284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305
def download_file(url, dst, num_retries=3):
    """Download ``url`` to ``dst``, retrying on urllib download errors.

    At most ``num_retries + 1`` attempts are made (always at least one).
    Returns True on success and False once every attempt has failed with
    ContentTooShortError, HTTPError or URLError; other exceptions propagate.
    """
    from six.moves import urllib

    retries_left = num_retries
    while True:
        try:
            urllib.request.urlretrieve(url, dst)
            MaceLogger.info('\nDownloaded successfully.')
            return True
        except (urllib.error.ContentTooShortError, urllib.error.HTTPError,
                urllib.error.URLError) as err:
            MaceLogger.warning('Download error:' + str(err))
            if retries_left <= 0:
                return False
            retries_left -= 1


def get_model_files(model_config, model_output_dir):
    """Resolve the model, weight and quantize-range files of a model config.

    Paths starting with ``http://`` or ``https://`` are downloaded into
    ``model_output_dir``. Each cache file is named after the md5 of its URL.
    A cached model or weight file is reused when its sha256 matches. The
    model file, and the weight file when one is configured, must match the
    configured sha256 checksums.
    ``model_config`` is updated in place with the local paths.

    Failures (download errors, missing local files, checksum mismatches) are
    reported through MaceLogger.error.
    """
    if not os.path.exists(model_output_dir):
        os.makedirs(model_output_dir)
    remote_prefixes = ("http://", "https://")
    model_file_path = model_config[YAMLKeyword.model_file_path]
    model_sha256_checksum = model_config[YAMLKeyword.model_sha256_checksum]
    weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "")
    weight_sha256_checksum = \
        model_config.get(YAMLKeyword.weight_sha256_checksum, "")
    quantize_range_file_path = \
        model_config.get(YAMLKeyword.quantize_range_file, "")
    model_file = model_file_path
    weight_file = weight_file_path
    quantize_range_file = quantize_range_file_path

    if model_file_path.startswith(remote_prefixes):
        model_file = model_output_dir + "/" + md5sum(model_file_path) + ".pb"
        # Re-download only when there is no valid cached copy.
        if not os.path.exists(model_file) or \
                sha256_checksum(model_file) != model_sha256_checksum:
            MaceLogger.info("Downloading model, please wait ...")
            if not download_file(model_file_path, model_file):
                MaceLogger.error(ModuleName.MODEL_CONVERTER,
                                 "Model download failed.")
        model_config[YAMLKeyword.model_file_path] = model_file

    # Check existence first so a wrong local path gives a clear message
    # instead of an IOError traceback from sha256_checksum().
    if not os.path.exists(model_file):
        MaceLogger.error(ModuleName.MODEL_CONVERTER,
                         model_file_path + " model file does not exist")
    elif sha256_checksum(model_file) != model_sha256_checksum:
        error_info = model_file_path + \
                     " model file sha256checksum not match " + \
                     model_sha256_checksum
        MaceLogger.error(ModuleName.MODEL_CONVERTER, error_info)

    if weight_file_path.startswith(remote_prefixes):
        weight_file = \
            model_output_dir + "/" + md5sum(weight_file_path) + ".caffemodel"
        if not os.path.exists(weight_file) or \
                sha256_checksum(weight_file) != weight_sha256_checksum:
            MaceLogger.info("Downloading model weight, please wait ...")
            if not download_file(weight_file_path, weight_file):
                MaceLogger.error(ModuleName.MODEL_CONVERTER,
                                 "Model download failed.")
    model_config[YAMLKeyword.weight_file_path] = weight_file

    if weight_file:
        if not os.path.exists(weight_file):
            MaceLogger.error(ModuleName.MODEL_CONVERTER,
                             weight_file_path + " weight file does not exist")
        elif sha256_checksum(weight_file) != weight_sha256_checksum:
            error_info = weight_file_path + \
                         " weight file sha256checksum not match " + \
                         weight_sha256_checksum
            MaceLogger.error(ModuleName.MODEL_CONVERTER, error_info)

    if quantize_range_file_path.startswith(remote_prefixes):
        quantize_range_file = \
            model_output_dir + "/" + md5sum(quantize_range_file_path) \
            + ".range"
        if not download_file(quantize_range_file_path, quantize_range_file):
            MaceLogger.error(ModuleName.MODEL_CONVERTER,
                             "Model range file download failed.")
    model_config[YAMLKeyword.quantize_range_file] = quantize_range_file


361 362
def _format_subgraph_config(subgraph):
    """Validate and normalize one ``subgraphs`` entry of a model, in place.

    Wraps scalar entries in lists and stringifies tensor names and shapes.
    Checks that shape counts match tensor counts. Fills in defaults for data
    types, data formats, validation thresholds, validation data, input
    ranges, backend and the accuracy validation script.
    Raises argparse.ArgumentTypeError for a malformed validation_threshold.
    """
    for key in [YAMLKeyword.input_tensors,
                YAMLKeyword.input_shapes,
                YAMLKeyword.output_tensors,
                YAMLKeyword.output_shapes]:
        value = subgraph.get(key, "")
        mace_check(value != "", ModuleName.YAML_CONFIG,
                   "'%s' is necessary in subgraph" % key)
        if not isinstance(value, list):
            subgraph[key] = [value]
        subgraph[key] = [str(v) for v in subgraph[key]]
        # --input_shapes is passed to the `mace_run_static` binary, whose
        # mace_run.cc parses arguments with gflags. A shape written with
        # spaces, such as '1, 3, 224, 224', would arrive as just
        # `--input_shapes 1,`, so strip out the spaces here.
        if key in [YAMLKeyword.input_shapes,
                   YAMLKeyword.output_shapes]:
            subgraph[key] = [e.replace(' ', '') for e in subgraph[key]]

    input_size = len(subgraph[YAMLKeyword.input_tensors])
    output_size = len(subgraph[YAMLKeyword.output_tensors])

    mace_check(len(subgraph[YAMLKeyword.input_shapes]) == input_size,
               ModuleName.YAML_CONFIG,
               "input shapes' size not equal inputs' size.")
    mace_check(len(subgraph[YAMLKeyword.output_shapes]) == output_size,
               ModuleName.YAML_CONFIG,
               "output shapes' size not equal outputs' size.")

    # Optional check tensors/shapes: normalized to a (possibly empty) list.
    for key in [YAMLKeyword.check_tensors,
                YAMLKeyword.check_shapes]:
        value = subgraph.get(key, "")
        if value != "":
            if not isinstance(value, list):
                subgraph[key] = [value]
            subgraph[key] = [str(v) for v in subgraph[key]]
        else:
            subgraph[key] = []

    # A scalar data type applies to every tensor; missing means float32.
    for key, count in [(YAMLKeyword.input_data_types, input_size),
                       (YAMLKeyword.output_data_types, output_size)]:
        data_types = subgraph.get(key, "")
        if data_types:
            if not isinstance(data_types, list):
                subgraph[key] = [data_types] * count
            for data_type in subgraph[key]:
                mace_check(data_type in InOutDataTypeStrs,
                           ModuleName.YAML_CONFIG,
                           key + " must be in "
                           + str(InOutDataTypeStrs))
        else:
            subgraph[key] = [InOutDataType.float32] * count

    # A scalar data format applies to every input; missing means NHWC.
    input_data_formats = subgraph.get(YAMLKeyword.input_data_formats, [])
    if input_data_formats:
        if not isinstance(input_data_formats, list):
            subgraph[YAMLKeyword.input_data_formats] = \
                [input_data_formats] * input_size
        else:
            mace_check(len(input_data_formats) == input_size,
                       ModuleName.YAML_CONFIG,
                       "input_data_formats should match"
                       " the size of input.")
        for input_data_format in subgraph[YAMLKeyword.input_data_formats]:
            # str(): the message is built eagerly, so a non-string value
            # must not raise TypeError before the check reports it.
            mace_check(input_data_format in DataFormatStrs,
                       ModuleName.YAML_CONFIG,
                       "'input_data_formats' must be in "
                       + str(DataFormatStrs) + ", but got "
                       + str(input_data_format))
    else:
        subgraph[YAMLKeyword.input_data_formats] = \
            [DataFormat.NHWC] * input_size

    output_data_formats = subgraph.get(YAMLKeyword.output_data_formats, [])
    if output_data_formats:
        if not isinstance(output_data_formats, list):
            subgraph[YAMLKeyword.output_data_formats] = \
                [output_data_formats] * output_size
        else:
            mace_check(len(output_data_formats) == output_size,
                       ModuleName.YAML_CONFIG,
                       "output_data_formats should match"
                       " the size of output")
        for output_data_format in subgraph[YAMLKeyword.output_data_formats]:
            mace_check(output_data_format in DataFormatStrs,
                       ModuleName.YAML_CONFIG,
                       "'output_data_formats' must be in "
                       + str(DataFormatStrs) + ", but got "
                       + str(output_data_format))
    else:
        subgraph[YAMLKeyword.output_data_formats] = \
            [DataFormat.NHWC] * output_size

    validation_threshold = subgraph.get(
        YAMLKeyword.validation_threshold, {})
    if not isinstance(validation_threshold, dict):
        raise argparse.ArgumentTypeError(
            'similarity threshold must be a dict.')

    threshold_dict = {
        DeviceType.CPU: ValidationThreshold.cpu_threshold,
        DeviceType.GPU: ValidationThreshold.gpu_threshold,
        DeviceType.HEXAGON: ValidationThreshold.quantize_threshold,
        DeviceType.HTA: ValidationThreshold.quantize_threshold,
        DeviceType.APU: ValidationThreshold.quantize_threshold,
        DeviceType.QUANTIZE: ValidationThreshold.quantize_threshold,
    }
    # Every device with a default above may be overridden, including APU
    # (which used to be rejected despite having a default).
    supported_devices = (DeviceType.CPU,
                         DeviceType.GPU,
                         DeviceType.HEXAGON,
                         DeviceType.HTA,
                         DeviceType.APU,
                         DeviceType.QUANTIZE)
    for k, v in six.iteritems(validation_threshold):
        # 'DSP' is accepted as an alias of the HEXAGON device.
        if k.upper() == 'DSP':
            k = DeviceType.HEXAGON
        if k.upper() not in supported_devices:
            raise argparse.ArgumentTypeError(
                'Unsupported validation threshold runtime: %s' % k)
        threshold_dict[k.upper()] = v

    subgraph[YAMLKeyword.validation_threshold] = threshold_dict

    validation_inputs_data = subgraph.get(
        YAMLKeyword.validation_inputs_data, [])
    if not isinstance(validation_inputs_data, list):
        subgraph[YAMLKeyword.validation_inputs_data] = [
            validation_inputs_data]
    else:
        subgraph[YAMLKeyword.validation_inputs_data] = \
            validation_inputs_data

    onnx_backend = subgraph.get(
        YAMLKeyword.backend, "tensorflow")
    subgraph[YAMLKeyword.backend] = onnx_backend

    validation_outputs_data = subgraph.get(
        YAMLKeyword.validation_outputs_data, [])
    if not isinstance(validation_outputs_data, list):
        subgraph[YAMLKeyword.validation_outputs_data] = [
            validation_outputs_data]
    else:
        subgraph[YAMLKeyword.validation_outputs_data] = \
            validation_outputs_data

    input_ranges = subgraph.get(
        YAMLKeyword.input_ranges, [])
    if not isinstance(input_ranges, list):
        subgraph[YAMLKeyword.input_ranges] = [input_ranges]
    else:
        subgraph[YAMLKeyword.input_ranges] = input_ranges
    subgraph[YAMLKeyword.input_ranges] = \
        [str(v) for v in subgraph[YAMLKeyword.input_ranges]]

    accuracy_validation_script = subgraph.get(
        YAMLKeyword.accuracy_validation_script, "")
    if isinstance(accuracy_validation_script, list):
        mace_check(len(accuracy_validation_script) == 1,
                   ModuleName.YAML_CONFIG,
                   "Only support one accuracy validation script")
        accuracy_validation_script = accuracy_validation_script[0]
    subgraph[YAMLKeyword.accuracy_validation_script] = \
        accuracy_validation_script


def format_model_config(flags):
    """Load, validate and normalize the YAML deployment config.

    Args:
        flags: parsed command-line flags. ``flags.config`` is the YAML path.
            Non-empty ``target_abis``, ``target_socs``,
            ``model_graph_format`` and ``model_data_format`` override the
            values in the file.

    Returns:
        The config dict. Every model entry has been checked, its remote
        files resolved via get_model_files(), and optional keys defaulted.

    Validation failures are reported through mace_check(). A malformed
    ``validation_threshold`` raises argparse.ArgumentTypeError.
    """
    with open(flags.config) as f:
        # safe_load: yaml.load() without an explicit Loader is deprecated in
        # PyYAML 5.1, rejected by PyYAML 6, and older default loaders could
        # construct arbitrary Python objects from the file.
        configs = yaml.safe_load(f)
    # An empty file loads as None; report it instead of crashing on .get().
    mace_check(isinstance(configs, dict), ModuleName.YAML_CONFIG,
               "config file should contain a YAML mapping")

    library_name = configs.get(YAMLKeyword.library_name, "")
    mace_check(len(library_name) > 0,
               ModuleName.YAML_CONFIG, "library name should not be empty")

    if flags.target_abis:
        target_abis = flags.target_abis.split(',')
    else:
        target_abis = configs.get(YAMLKeyword.target_abis, [])
    mace_check((isinstance(target_abis, list) and len(target_abis) > 0),
               ModuleName.YAML_CONFIG, "target_abis list is needed")
    configs[YAMLKeyword.target_abis] = target_abis
    for abi in target_abis:
        mace_check(abi in ABITypeStrs,
                   ModuleName.YAML_CONFIG,
                   "target_abis must be in " + str(ABITypeStrs))

    # --target_socs (other than the random/all tags) overrides the file.
    target_socs = configs.get(YAMLKeyword.target_socs, "")
    if flags.target_socs and flags.target_socs != TargetSOCTag.random \
            and flags.target_socs != TargetSOCTag.all:
        configs[YAMLKeyword.target_socs] = \
            [soc.lower() for soc in flags.target_socs.split(',')]
    elif not target_socs:
        configs[YAMLKeyword.target_socs] = []
    elif not isinstance(target_socs, list):
        configs[YAMLKeyword.target_socs] = [target_socs]

    configs[YAMLKeyword.target_socs] = \
        [soc.lower() for soc in configs[YAMLKeyword.target_socs]]

    # Android ABIs require the requested SOCs to be attached via adb.
    if ABIType.armeabi_v7a in target_abis \
            or ABIType.arm64_v8a in target_abis:
        available_socs = sh_commands.adb_get_all_socs()
        target_socs = configs[YAMLKeyword.target_socs]
        if TargetSOCTag.all in target_socs:
            mace_check(available_socs,
                       ModuleName.YAML_CONFIG,
                       "Android abi is listed in config file and "
                       "build for all SOCs plugged in computer, "
                       "But no android phone found, "
                       "you at least plug in one phone")
        else:
            for soc in target_socs:
                mace_check(soc in available_socs,
                           ModuleName.YAML_CONFIG,
                           "Build specified SOC library, "
                           "you must plug in a phone using the SOC")

    if flags.model_graph_format:
        model_graph_format = flags.model_graph_format
    else:
        model_graph_format = configs.get(YAMLKeyword.model_graph_format, "")
    mace_check(model_graph_format in ModelFormatStrs,
               ModuleName.YAML_CONFIG,
               'You must set model_graph_format and '
               "model_graph_format must be in " + str(ModelFormatStrs))
    configs[YAMLKeyword.model_graph_format] = model_graph_format
    if flags.model_data_format:
        model_data_format = flags.model_data_format
    else:
        model_data_format = configs.get(YAMLKeyword.model_data_format, "")
    configs[YAMLKeyword.model_data_format] = model_data_format
    mace_check(model_data_format in ModelFormatStrs,
               ModuleName.YAML_CONFIG,
               'You must set model_data_format and '
               "model_data_format must be in " + str(ModelFormatStrs))

    mace_check(not (model_graph_format == ModelFormat.file
                    and model_data_format == ModelFormat.code),
               ModuleName.YAML_CONFIG,
               "If model_graph format is 'file',"
               " the model_data_format must be 'file' too")

    model_names = configs.get(YAMLKeyword.models, [])
    mace_check(len(model_names) > 0, ModuleName.YAML_CONFIG,
               "no model found in config file")

    model_name_reg = re.compile(r'^[a-zA-Z0-9_]+$')
    for model_name in model_names:
        # Model names become C++ identifiers in generated code.
        mace_check(model_name not in CPP_KEYWORDS,
                   ModuleName.YAML_CONFIG,
                   "model name should not be c++ keyword.")
        # Regex first so an empty name is rejected instead of IndexError.
        mace_check(bool(model_name_reg.match(model_name))
                   and (model_name[0] == '_' or model_name[0].isalpha()),
                   ModuleName.YAML_CONFIG,
                   "model name should Meet the c++ naming convention"
                   " which start with '_' or alpha"
                   " and only contain alpha, number and '_'")

        model_config = configs[YAMLKeyword.models][model_name]
        platform = model_config.get(YAMLKeyword.platform, "")
        mace_check(platform in PlatformTypeStrs,
                   ModuleName.YAML_CONFIG,
                   "'platform' must be in " + str(PlatformTypeStrs))

        for key in [YAMLKeyword.model_file_path,
                    YAMLKeyword.model_sha256_checksum]:
            value = model_config.get(key, "")
            mace_check(value != "", ModuleName.YAML_CONFIG,
                       "'%s' is necessary" % key)

        weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "")
        if weight_file_path:
            weight_checksum = \
                model_config.get(YAMLKeyword.weight_sha256_checksum, "")
            mace_check(weight_checksum != "", ModuleName.YAML_CONFIG,
                       "'%s' is necessary" %
                       YAMLKeyword.weight_sha256_checksum)
        else:
            model_config[YAMLKeyword.weight_sha256_checksum] = ""

        get_model_files(model_config, BUILD_DOWNLOADS_DIR)

        runtime = model_config.get(YAMLKeyword.runtime, "")
        mace_check(runtime in RuntimeTypeStrs,
                   ModuleName.YAML_CONFIG,
                   "'runtime' must be in " + str(RuntimeTypeStrs))
        if ABIType.host in target_abis:
            mace_check(runtime == RuntimeType.cpu,
                       ModuleName.YAML_CONFIG,
                       "host only support cpu runtime now.")

        # Validate data_type against the runtime, or fill in its default.
        data_type = model_config.get(YAMLKeyword.data_type, "")
        if runtime == RuntimeType.dsp:
            if len(data_type) > 0:
                mace_check(data_type in DSPDataTypeStrs,
                           ModuleName.YAML_CONFIG,
                           "'data_type' must be in " + str(DSPDataTypeStrs)
                           + " for dsp runtime")
            else:
                model_config[YAMLKeyword.data_type] = \
                    DSPDataType.uint8.value
        elif runtime == RuntimeType.apu:
            if len(data_type) > 0:
                mace_check(data_type in APUDataTypeStrs,
                           ModuleName.YAML_CONFIG,
                           "'data_type' must be in " + str(APUDataTypeStrs)
                           + " for apu runtime")
            else:
                model_config[YAMLKeyword.data_type] = \
                    APUDataType.uint8.value
        else:
            if len(data_type) > 0:
                # Name the actual runtime (not always "cpu") in the error.
                mace_check(data_type in FPDataTypeStrs,
                           ModuleName.YAML_CONFIG,
                           "'data_type' must be in " + str(FPDataTypeStrs)
                           + " for " + runtime + " runtime")
            elif runtime == RuntimeType.cpu:
                model_config[YAMLKeyword.data_type] = \
                    FPDataType.fp32_fp32.value
            else:
                model_config[YAMLKeyword.data_type] = \
                    FPDataType.fp16_fp32.value

        subgraphs = model_config.get(YAMLKeyword.subgraphs, "")
        mace_check(len(subgraphs) > 0, ModuleName.YAML_CONFIG,
                   "at least one subgraph is needed")

        for subgraph in subgraphs:
            _format_subgraph_config(subgraph)

        # Integer options default to 0 when absent or empty.
        for key in [YAMLKeyword.limit_opencl_kernel_time,
                    YAMLKeyword.opencl_queue_window_size,
                    YAMLKeyword.nnlib_graph_mode,
                    YAMLKeyword.obfuscate,
                    YAMLKeyword.winograd,
                    YAMLKeyword.quantize,
                    YAMLKeyword.quantize_large_weights,
                    YAMLKeyword.change_concat_ranges]:
            value = model_config.get(key, "")
            if value == "":
                model_config[key] = 0

        mace_check(model_config[YAMLKeyword.quantize] == 0 or
                   model_config[YAMLKeyword.quantize_large_weights] == 0,
                   ModuleName.YAML_CONFIG,
                   "quantize and quantize_large_weights should not be set to 1"
                   " at the same time.")

        mace_check(model_config[YAMLKeyword.winograd] in WinogradParameters,
                   ModuleName.YAML_CONFIG,
                   "'winograd' parameters must be in "
                   + str(WinogradParameters) +
                   ". 0 for disable winograd convolution")

    return configs
Y
yejianwu 已提交
717

W
wuchenghui 已提交
718

719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739
def clear_build_dirs(library_name):
    """Reset the per-library build directories before a run.

    Makes sure BUILD_OUTPUT_DIR exists, recreates an empty temporary build
    directory for ``library_name`` and deletes any previous output library
    directory (which is not recreated here).
    """
    if not os.path.exists(BUILD_OUTPUT_DIR):
        os.makedirs(BUILD_OUTPUT_DIR)

    library_root = os.path.join(BUILD_OUTPUT_DIR, library_name)

    # Start every build from an empty temp directory.
    tmp_dir = os.path.join(library_root, BUILD_TMP_DIR_NAME)
    if os.path.exists(tmp_dir):
        sh.rm('-rf', tmp_dir)
    os.makedirs(tmp_dir)

    # Drop stale libraries from a previous build.
    lib_dir = os.path.join(library_root, OUTPUT_LIBRARY_DIR_NAME)
    if os.path.exists(lib_dir):
        sh.rm('-rf', lib_dir)


################################
# convert
################################
def print_configuration(configs):
    """Log a key/value table of the library-wide settings in ``configs``."""
    shown_keys = (
        YAMLKeyword.library_name,
        YAMLKeyword.target_abis,
        YAMLKeyword.target_socs,
        YAMLKeyword.model_graph_format,
        YAMLKeyword.model_data_format,
    )
    rows = [[key, configs[key]] for key in shown_keys]
    table = StringFormatter.table(["key", "value"], rows,
                                  "Common Configuration")
    MaceLogger.summary(table)
L
Liangliang He 已提交
754

Y
yejianwu 已提交
755

756
def build_model_lib(configs, address_sanitizer, debug_mode):
    """Build the model library for every target ABI in ``configs``.

    For each ABI, MODEL_LIB_TARGET is built with bazel using the backend
    switches derived from ``configs``. The produced MODEL_LIB_PATH is then
    copied to that ABI's model library output path.
    """
    MaceLogger.header(StringFormatter.block("Building model library"))

    library_name = configs[YAMLKeyword.library_name]
    for abi in configs[YAMLKeyword.target_abis]:
        lib_path = get_model_lib_output_path(library_name, abi)
        lib_dir = os.path.dirname(lib_path)
        if not os.path.exists(lib_dir):
            os.makedirs(lib_dir)

        build_options = dict(
            abi=abi,
            toolchain=infer_toolchain(abi),
            enable_hexagon=hexagon_enabled(configs),
            enable_hta=hta_enabled(configs),
            enable_apu=apu_enabled(configs),
            enable_opencl=opencl_enabled(configs),
            enable_quantize=quantize_enabled(configs),
            enable_bfloat16=bfloat16_enabled(configs),
            enable_fp16=fp16_enabled(configs),
            address_sanitizer=address_sanitizer,
            symbol_hidden=get_symbol_hidden_mode(debug_mode),
            debug_mode=debug_mode,
        )
        sh_commands.bazel_build(MODEL_LIB_TARGET, **build_options)

        sh.cp("-f", MODEL_LIB_PATH, lib_path)
785 786 787 788 789 790 791


def print_library_summary(configs):
    """Log where the converted model files (and generated headers) live.

    The header path is only listed when the graph format is "code". The
    format key may be absent from the YAML; convert_func treats a missing
    value as "file", and the same default is used here instead of indexing
    ``configs`` directly, which would raise KeyError.
    """
    library_name = configs[YAMLKeyword.library_name]
    title = "Library"
    header = ["key", "value"]
    data = list()
    data.append(["MACE Model Path",
                 "%s/%s/%s"
                 % (BUILD_OUTPUT_DIR, library_name, MODEL_OUTPUT_DIR_NAME)])
    graph_format = configs.get(YAMLKeyword.model_graph_format, "file")
    if graph_format == ModelFormat.code:
        data.append(["MACE Model Header Path",
                     "%s/%s/%s"
                     % (BUILD_OUTPUT_DIR, library_name,
                        MODEL_HEADER_DIR_PATH)])

    MaceLogger.summary(StringFormatter.table(header, data, title))


804
def _resolve_model_format(flag_value, configs, key):
    """Return the effective model format for ``key`` and record it.

    A non-empty command-line value overrides the YAML value, and a missing
    YAML value defaults to "file". The result is written back into
    ``configs`` so later readers (e.g. print_library_summary, which indexes
    ``configs`` directly) see the format that was actually used.
    """
    if flag_value:
        resolved = flag_value
    else:
        resolved = configs.get(key, "file")
    configs[key] = resolved
    return resolved


def convert_func(flags):
    """Handle the ``convert`` sub-command.

    Parses the YAML config and resets the library's build output
    directories. It then converts all models with ``convert.convert`` and
    runs ``encrypt.encrypt`` on each model. Results go to the model output
    directory (file graph format) or the model header directory (code graph
    format). For the code graph format the model library is also built.
    """
    configs = config_parser.parse(flags.config)
    print(configs)
    library_name = configs[YAMLKeyword.library_name]
    library_dir = os.path.join(BUILD_OUTPUT_DIR, library_name)
    if not os.path.exists(BUILD_OUTPUT_DIR):
        os.makedirs(BUILD_OUTPUT_DIR)
    elif os.path.exists(library_dir):
        sh.rm("-rf", library_dir)
    os.makedirs(library_dir)
    if not os.path.exists(BUILD_DOWNLOADS_DIR):
        os.makedirs(BUILD_DOWNLOADS_DIR)

    model_output_dir = \
        '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_OUTPUT_DIR_NAME)
    model_header_dir = \
        '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_HEADER_DIR_PATH)
    # clear output dir
    if os.path.exists(model_output_dir):
        sh.rm("-rf", model_output_dir)
    os.makedirs(model_output_dir)
    if os.path.exists(model_header_dir):
        sh.rm("-rf", model_header_dir)

    # Remove code generated by a previous conversion.
    if os.path.exists(MODEL_CODEGEN_DIR):
        sh.rm("-rf", MODEL_CODEGEN_DIR)
    if os.path.exists(ENGINE_CODEGEN_DIR):
        sh.rm("-rf", ENGINE_CODEGEN_DIR)

    if flags.quantize_stat:
        configs[YAMLKeyword.quantize_stat] = flags.quantize_stat

    # Resolve flag-vs-YAML formats and store them back into configs so that
    # every later consumer agrees on them.
    model_data_format = _resolve_model_format(
        flags.model_data_format, configs, YAMLKeyword.model_data_format)
    embed_model_data = model_data_format == ModelFormat.code

    model_graph_format = _resolve_model_format(
        flags.model_graph_format, configs, YAMLKeyword.model_graph_format)
    embed_graph_def = model_graph_format == ModelFormat.code

    if flags.enable_micro:
        mace_check((not embed_model_data) and (not embed_graph_def),
                   ModuleName.YAML_CONFIG,
                   "You should specify file mode to convert micro model.")
    if embed_graph_def:
        os.makedirs(model_header_dir)
        sh_commands.gen_mace_engine_factory_source(
            configs[YAMLKeyword.models].keys(),
            embed_model_data)
        sh.cp("-f", glob.glob("mace/codegen/engine/*.h"),
              model_header_dir)

    convert.convert(configs, MODEL_CODEGEN_DIR, flags.enable_micro)

    for model_name, model_config in configs[YAMLKeyword.models].items():
        if flags.enable_micro:
            data_type = model_config.get(YAMLKeyword.data_type, "")
            mace_check(data_type == FPDataType.fp32_fp32.value or
                       data_type == FPDataType.bf16_fp32.value,
                       ModuleName.YAML_CONFIG,
                       "You should specify fp32_fp32 or bf16_fp32 data type "
                       "to convert micro model.")
        model_codegen_dir = "%s/%s" % (MODEL_CODEGEN_DIR, model_name)
        encrypt.encrypt(model_name,
                        "%s/model/%s.pb" % (model_codegen_dir, model_name),
                        "%s/model/%s.data" % (model_codegen_dir, model_name),
                        config_parser.parse_device_type(
                            model_config[YAMLKeyword.runtime]),
                        model_codegen_dir,
                        bool(model_config.get(YAMLKeyword.obfuscate, 1)),
                        model_graph_format == "code",
                        model_data_format == "code")

        if model_graph_format == ModelFormat.file:
            sh.mv("-f",
                  '%s/model/%s.pb' % (model_codegen_dir, model_name),
                  model_output_dir)
            sh.mv("-f",
                  '%s/model/%s.data' % (model_codegen_dir, model_name),
                  model_output_dir)
            if flags.enable_micro:
                sh.mv("-f", '%s/model/%s_micro.tar.gz' %
                      (model_codegen_dir, model_name), model_output_dir)
        else:
            if not embed_model_data:
                sh.mv("-f",
                      '%s/model/%s.data' % (model_codegen_dir, model_name),
                      model_output_dir)

            sh.cp("-f", glob.glob("mace/codegen/models/*/code/*.h"),
                  model_header_dir)

        MaceLogger.summary(
            StringFormatter.block("Model %s converted" % model_name))

    if model_graph_format == ModelFormat.code:
        build_model_lib(configs, flags.address_sanitizer, flags.debug_mode)

    print_library_summary(configs)


################################
# run
################################
912
def build_mace_run(configs, target_abi, toolchain,
                   address_sanitizer, mace_lib_type, debug_mode):
    """Build the mace_run binary for ``target_abi`` into a fresh directory.

    The static or dynamic mace_run target is chosen from ``mace_lib_type``.
    When the model graph is embedded as code, the engine codegen directory
    must already exist, and mace_run.cc is compiled with
    MODEL_GRAPH_FORMAT_CODE defined.
    """
    library_name = configs[YAMLKeyword.library_name]

    binary_dir = get_build_binary_dir(library_name, target_abi)
    if os.path.exists(binary_dir):
        sh.rm("-rf", binary_dir)
    os.makedirs(binary_dir)

    use_dynamic_lib = mace_lib_type == MACELibType.dynamic
    if use_dynamic_lib:
        run_target = MACE_RUN_DYNAMIC_TARGET
    else:
        run_target = MACE_RUN_STATIC_TARGET

    copt_arg = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR), ModuleName.RUN,
                   "You should convert model first.")
        copt_arg = "--per_file_copt=mace/tools/mace_run.cc@-DMODEL_GRAPH_FORMAT_CODE"  # noqa

    sh_commands.bazel_build(
        run_target,
        abi=target_abi,
        toolchain=toolchain,
        enable_hexagon=hexagon_enabled(configs),
        enable_hta=hta_enabled(configs),
        enable_apu=apu_enabled(configs),
        enable_opencl=opencl_enabled(configs),
        enable_quantize=quantize_enabled(configs),
        enable_bfloat16=bfloat16_enabled(configs),
        enable_fp16=fp16_enabled(configs),
        address_sanitizer=address_sanitizer,
        symbol_hidden=get_symbol_hidden_mode(debug_mode, mace_lib_type),
        debug_mode=debug_mode,
        extra_args=copt_arg,
    )
    sh_commands.update_mace_run_binary(binary_dir, use_dynamic_lib)


L
liuqi 已提交
951 952 953 954 955 956 957 958 959 960
def print_package_summary(package_path):
    """Log the path of the packaged library output."""
    rows = [["MACE Model package Path", package_path]]
    table = StringFormatter.table(["key", "value"], rows, "Library")
    MaceLogger.summary(table)


961
def run_mace(flags):
    """Handle the ``run`` sub-command.

    Formats the model config and resets the library build directories.
    Then, for every target ABI, builds mace_run and runs it on each selected
    device that supports that ABI; non-host devices that do not support the
    ABI are reported on stderr. Finally packages the library output and logs
    the package path.
    """
    configs = format_model_config(flags)
    clear_build_dirs(configs[YAMLKeyword.library_name])

    wanted_socs = configs[YAMLKeyword.target_socs]
    devices = DeviceManager.list_devices(flags.device_yml)
    if wanted_socs and TargetSOCTag.all not in wanted_socs:
        devices = [candidate for candidate in devices
                   if candidate[YAMLKeyword.target_socs].lower()
                   in wanted_socs]

    for abi in configs[YAMLKeyword.target_abis]:
        if flags.target_socs == TargetSOCTag.random:
            selected = sh_commands.choose_a_random_device(devices, abi)
        else:
            selected = devices

        for dev in selected:
            if abi not in dev[YAMLKeyword.target_abis]:
                if dev[YAMLKeyword.device_name] != SystemType.host:
                    print('The device with soc %s do not support abi %s' %
                          (dev[YAMLKeyword.target_socs], abi),
                          file=sys.stderr)
                continue

            toolchain = infer_toolchain(abi)
            device = DeviceWrapper(dev)
            build_mace_run(configs, abi, toolchain,
                           flags.address_sanitizer,
                           flags.mace_lib_type,
                           flags.debug_mode)

            started_at = time.time()
            with device.lock():
                device.run_specify_abi(flags, configs, abi)
            minutes = (time.time() - started_at) / 60
            print("Elapse time: %f minutes." % minutes)

    # package the output files
    package_path = sh_commands.packaging_lib(
        BUILD_OUTPUT_DIR, configs[YAMLKeyword.library_name])
    print_package_summary(package_path)

1005

L
liuqi 已提交
1006
################################
# parsing arguments
################################
def str2bool(v):
    """Convert a yes/no style string into a bool (argparse ``type=``).

    Matching is case-insensitive.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a recognised token.
    """
    token = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if token in truthy:
        return True
    if token in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def str_to_caffe_env_type(v):
    """argparse ``type=`` converter for ``--caffe_env``.

    Maps 'docker' / 'local' (case-insensitive) to the matching CaffeEnvType.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    env_by_name = {
        'docker': CaffeEnvType.DOCKER,
        'local': CaffeEnvType.LOCAL,
    }
    key = v.lower()
    if key not in env_by_name:
        raise argparse.ArgumentTypeError('[docker | local] expected.')
    return env_by_name[key]


1027 1028 1029 1030 1031 1032 1033 1034 1035
def str_to_mace_lib_type(v):
    """argparse ``type=`` converter for ``--mace_lib_type``.

    Maps 'dynamic' / 'static' (case-insensitive) to the matching MACELibType.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    lib_by_name = {
        'dynamic': MACELibType.dynamic,
        'static': MACELibType.static,
    }
    key = v.lower()
    if key not in lib_by_name:
        raise argparse.ArgumentTypeError('[dynamic| static] expected.')
    return lib_by_name[key]


1036
def _add_run_arguments(run_parser):
    """Register the options accepted only by the ``run`` sub-command."""
    run_parser.add_argument(
        "--mace_lib_type",
        type=str_to_mace_lib_type,
        default=DefaultValues.mace_lib_type,
        help="[static | dynamic], Which type MACE library to use.")
    run_parser.add_argument(
        "--num_threads",
        type=int,
        default=DefaultValues.num_threads,
        help="num of threads")
    run_parser.add_argument(
        "--cpu_affinity_policy",
        type=int,
        default=DefaultValues.cpu_affinity_policy,
        help="0:AFFINITY_NONE/1:AFFINITY_BIG_ONLY/2:AFFINITY_LITTLE_ONLY")
    run_parser.add_argument(
        "--gpu_perf_hint",
        type=int,
        default=DefaultValues.gpu_perf_hint,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    run_parser.add_argument(
        "--gpu_priority_hint",
        type=int,
        default=DefaultValues.gpu_priority_hint,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    run_parser.add_argument(
        "--device_yml",
        type=str,
        default='',
        help='embedded linux device config yml file'
    )
    run_parser.add_argument(
        "--disable_tuning",
        action="store_true",
        help="Disable tuning for specific thread.")
    run_parser.add_argument(
        "--round",
        type=int,
        default=1,
        help="The model running round.")
    run_parser.add_argument(
        "--validate",
        action="store_true",
        help="whether to verify the results are consistent with "
             "the frameworks.")
    run_parser.add_argument(
        "--layers",
        type=str,
        default="-1",
        help="'start_layer:end_layer' or 'layer', similar to python slice."
             " Use with --validate flag.")
    run_parser.add_argument(
        "--caffe_env",
        type=str_to_caffe_env_type,
        default='docker',
        help="[docker | local] you can specific caffe environment for"
             " validation. local environment or caffe docker image.")
    run_parser.add_argument(
        "--vlog_level",
        type=int,
        default=0,
        help="[1~5]. Verbose log level for debug.")
    run_parser.add_argument(
        "--gpu_out_of_range_check",
        action="store_true",
        help="Enable out of range check for gpu.")
    run_parser.add_argument(
        "--restart_round",
        type=int,
        default=1,
        help="restart round between run.")
    run_parser.add_argument(
        "--report",
        action="store_true",
        help="print run statistics report.")
    run_parser.add_argument(
        "--report_dir",
        type=str,
        default="",
        help="directory of the run statistics report.")
    run_parser.add_argument(
        "--runtime_failure_ratio",
        type=float,
        default=0.0,
        help="[mock runtime failure ratio].")
    run_parser.add_argument(
        "--input_dir",
        type=str,
        default="",
        help="quantize stat input dir.")
    run_parser.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="quantize stat output dir.")
    run_parser.add_argument(
        "--cl_binary_to_code",
        action="store_true",
        help="convert OpenCL binaries to cpp.")
    run_parser.add_argument(
        "--benchmark",
        action="store_true",
        help="enable op benchmark.")
    run_parser.add_argument(
        "--apu_cache_policy",
        type=int,
        default=DefaultValues.apu_cache_policy,
        help="0:NONE/1:STORE/2:LOAD")
    run_parser.add_argument(
        "--apu_binary_file",
        type=str,
        default="",
        help="apu cache file to load.")
    run_parser.add_argument(
        "--apu_storage_file",
        type=str,
        default="",
        help="apu cache file to store.")


def parse_args():
    """Parses command line arguments.

    Defines the ``convert`` and ``run`` sub-commands. Each one stores its
    handler as ``func`` on the resulting namespace. A sub-command is
    mandatory.

    Returns:
        The ``(namespace, unrecognised_args)`` tuple produced by
        ``ArgumentParser.parse_known_args``.
    """
    all_type_parent_parser = argparse.ArgumentParser(add_help=False)
    all_type_parent_parser.add_argument(
        '--config',
        type=str,
        default="",
        required=True,
        help="the path of model yaml configuration file.")
    all_type_parent_parser.add_argument(
        "--model_graph_format",
        type=str,
        default="",
        help="[file, code], MACE Model graph format.")
    all_type_parent_parser.add_argument(
        "--model_data_format",
        type=str,
        default="",
        help="['file', 'code'], MACE Model data format.")
    all_type_parent_parser.add_argument(
        "--target_abis",
        type=str,
        default="",
        help="Target ABIs, comma separated list.")
    all_type_parent_parser.add_argument(
        "--target_socs",
        type=str,
        default="",
        help="Target SOCs, comma separated list.")
    all_type_parent_parser.add_argument(
        "--debug_mode",
        action="store_true",
        help="Reserve debug symbols.")
    convert_run_parent_parser = argparse.ArgumentParser(add_help=False)
    convert_run_parent_parser.add_argument(
        '--address_sanitizer',
        action="store_true",
        help="Whether to use address sanitizer to check memory error")
    convert_run_parent_parser.add_argument(
        "--quantize_stat",
        action="store_true",
        help="whether to stat quantization range.")

    parser = argparse.ArgumentParser()
    # On Python 3 sub-commands are optional by default, so an empty command
    # line would yield a namespace without ``func`` and crash the caller.
    # ``dest`` is set so the "required" error can name the missing argument.
    subparsers = parser.add_subparsers(dest="command")
    subparsers.required = True

    # Named *_parser so the local does not shadow the imported ``convert``.
    convert_parser = subparsers.add_parser(
        'convert',
        parents=[all_type_parent_parser, convert_run_parent_parser],
        help='convert to mace model (file or code)')
    convert_parser.add_argument(
        "--enable_micro",
        action="store_true",
        help="enable convert micro.")
    convert_parser.set_defaults(func=convert_func)

    run_parser = subparsers.add_parser(
        'run',
        parents=[all_type_parent_parser,
                 convert_run_parent_parser],
        help='run model in command line')
    run_parser.set_defaults(func=run_mace)
    _add_run_arguments(run_parser)

    return parser.parse_known_args()

1217

Y
yejianwu 已提交
1218
if __name__ == "__main__":
    flags, unparsed = parse_args()
    # Dispatch to the handler the chosen sub-command registered through
    # set_defaults(func=...); unrecognised arguments are not used here.
    handler = flags.func
    handler(flags)