converter.py 51.2 KB
Newer Older
L
Liangliang He 已提交
1
# Copyright 2018 The MACE Authors. All Rights Reserved.
Y
yejianwu 已提交
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15

import argparse
L
liuqi 已提交
16
import glob
L
Liangliang He 已提交
17
import sh
18
import sys
19
import time
Y
yejianwu 已提交
20
import yaml
L
liuqi 已提交
21

22
from enum import Enum
23
import six
24

25
import sh_commands
L
Liangliang He 已提交
26

L
liuqi 已提交
27 28
from common import *
from device import DeviceWrapper, DeviceManager
29

Y
yejianwu 已提交
30 31 32 33
################################
# set environment
################################
# Presumably quiets TensorFlow's C++ logging while TF-based converters are
# imported ('2' hides INFO and WARNING messages) -- TODO confirm intent.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

################################
# common definitions
################################

# Accepted values for the `target_abis` config key / --target_abis flag.
ABITypeStrs = [
    'armeabi-v7a',
    'arm64-v8a',
    'arm64',
    'armhf',
    'host',
]

# Accepted values for model_graph_format / model_data_format.
ModelFormatStrs = [
    "file",
    "code",
]

# Accepted values for a model's `platform` key.
PlatformTypeStrs = [
    "tensorflow",
    "caffe",
    "onnx",
]
PlatformType = Enum('PlatformType', [(ele, ele) for ele in PlatformTypeStrs],
                    type=str)

# Accepted values for a model's `runtime` key.
RuntimeTypeStrs = [
    "cpu",
    "gpu",
    "dsp",
    "hta",
    "apu",
    "cpu+gpu"
]

# Accepted input/output tensor data types of a subgraph.
InOutDataTypeStrs = [
    "int32",
    "float32",
]

# The functional-API class name must equal the variable name: pickle
# resolves an Enum class by that name, and repr() shows it.
InOutDataType = Enum('InOutDataType',
                     [(ele, ele) for ele in InOutDataTypeStrs],
                     type=str)

# Accepted `data_type` values for non-dsp/apu runtimes.
FPDataTypeStrs = [
    "fp16_fp32",
    "fp32_fp32",
]

# Class name matches the variable name for the same reason as above.
FPDataType = Enum('FPDataType', [(ele, ele) for ele in FPDataTypeStrs],
                  type=str)

# Accepted `data_type` values for the dsp runtime.
DSPDataTypeStrs = [
    "uint8",
]

DSPDataType = Enum('DSPDataType', [(ele, ele) for ele in DSPDataTypeStrs],
                   type=str)

# Accepted `data_type` values for the apu runtime.
APUDataTypeStrs = [
    "uint8",
]

APUDataType = Enum('APUDataType', [(ele, ele) for ele in APUDataTypeStrs],
                   type=str)

# Allowed `winograd` settings; 0 disables winograd convolution.
WinogradParameters = [0, 2, 4]

# Accepted input/output data formats of a subgraph.
DataFormatStrs = [
    "NONE",
    "NHWC",
    "NCHW",
    "OIHW",
]


110
class DefaultValues(object):
    """Tool-wide default settings.

    Every value is a plain scalar. Previously a trailing comma after each
    assignment turned these into one-element tuples such as ``(-1,)``.
    """
    mace_lib_type = MACELibType.static
    omp_num_threads = -1
    cpu_affinity_policy = 1
    gpu_perf_hint = 3
    gpu_priority_hint = 3


118 119 120 121 122 123 124
class ValidationThreshold(object):
    """Default per-device validation thresholds.

    ``format_model_config`` copies these into each subgraph's
    ``validation_threshold`` dict. User overrides in that dict are plain
    numbers, so the defaults must be plain floats as well (a stray
    trailing comma previously made each one a one-element tuple).
    """
    cpu_threshold = 0.999
    gpu_threshold = 0.995
    hexagon_threshold = 0.930
    cpu_quantize_threshold = 0.980


125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
# C++ keywords, alternative operator tokens and preprocessor directive names.
# format_model_config rejects model names found in this list. It stays a
# list, in this exact order; 'if' and 'else' each appear twice.
CPP_KEYWORDS = """
    alignas alignof and and_eq asm atomic_cancel
    atomic_commit atomic_noexcept auto bitand bitor
    bool break case catch char char16_t char32_t
    class compl concept const constexpr const_cast
    continue co_await co_return co_yield decltype default
    delete do double dynamic_cast else enum explicit
    export extern false float for friend goto if
    import inline int long module mutable namespace
    new noexcept not not_eq nullptr operator or or_eq
    private protected public register reinterpret_cast
    requires return short signed sizeof static
    static_assert static_cast struct switch synchronized
    template this thread_local throw true try typedef
    typeid typename union unsigned using virtual void
    volatile wchar_t while xor xor_eq override final
    transaction_safe transaction_safe_dynamic if elif else
    endif defined ifdef ifndef define undef include
    line error pragma
""".split()
Y
yejianwu 已提交
145

146

147 148 149
################################
# common functions
################################
150
def parse_device_type(runtime):
    """Translate a model ``runtime`` value into the matching DeviceType.

    Runtimes without a branch here (for example RuntimeType.cpu_gpu) or
    unknown values map to the empty string.
    """
    if runtime == RuntimeType.dsp:
        return DeviceType.HEXAGON
    if runtime == RuntimeType.hta:
        return DeviceType.HTA
    if runtime == RuntimeType.gpu:
        return DeviceType.GPU
    if runtime == RuntimeType.cpu:
        return DeviceType.CPU
    if runtime == RuntimeType.apu:
        return DeviceType.APU
    return ""
165

Y
yejianwu 已提交
166 167

def _get_model_runtimes(configs):
    """Return the lower-cased ``runtime`` of every model in ``configs``.

    A model with no runtime, or with an explicit null one, contributes an
    empty string instead of crashing on ``None.lower()``.
    """
    models = configs[YAMLKeyword.models]
    return [(models[name].get(YAMLKeyword.runtime, "") or "").lower()
            for name in models]


def get_hexagon_mode(configs):
    """Return True if any model in ``configs`` uses the dsp runtime."""
    return RuntimeType.dsp in _get_model_runtimes(configs)


def get_hta_mode(configs):
    """Return True if any model in ``configs`` uses the hta runtime."""
    return RuntimeType.hta in _get_model_runtimes(configs)


def get_opencl_mode(configs):
    """Return True if any model uses the gpu or the cpu_gpu runtime."""
    runtimes = _get_model_runtimes(configs)
    return RuntimeType.gpu in runtimes or RuntimeType.cpu_gpu in runtimes


206 207 208 209 210 211 212 213 214 215 216
def get_quantize_mode(configs):
    """Report whether at least one model has ``quantize`` set to 1.

    A model without a ``quantize`` key counts as not quantized.
    """
    models = configs[YAMLKeyword.models]
    return any(models[name].get(YAMLKeyword.quantize, 0) == 1
               for name in models)


217 218 219 220 221 222 223 224 225
def get_symbol_hidden_mode(debug_mode, mace_lib_type=None):
    """Decide whether library symbols should be hidden.

    Symbols stay visible (False) only when a library type is given and
    either ``debug_mode`` is truthy or the library type is dynamic.
    Otherwise, including when no library type is given, returns True.
    """
    if mace_lib_type:
        return not (debug_mode or mace_lib_type == MACELibType.dynamic)
    return True


226 227
def md5sum(str):
    """Return the hexadecimal MD5 digest of ``str`` encoded as UTF-8.

    Callers in this file use it to derive cache file names from URLs.
    The parameter keeps its historical name (which shadows the builtin)
    so keyword callers are unaffected.
    """
    digest = hashlib.md5(str.encode('utf-8'))
    return digest.hexdigest()
230

Y
yejianwu 已提交
231

232 233 234 235 236 237
def sha256_checksum(fname):
    """Return the hexadecimal SHA-256 digest of the file at ``fname``.

    The file is consumed in 4096-byte pieces rather than read whole.
    """
    digest = hashlib.sha256()
    with open(fname, "rb") as stream:
        while True:
            piece = stream.read(4096)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()
Y
yejianwu 已提交
238

W
wuchenghui 已提交
239

240 241
def _format_subgraph(subgraph):
    """Validate one subgraph entry of a model config and fill defaults.

    The dict is modified in place:
    - Scalar values of list-valued keys are wrapped into lists.
    - Tensor names, shapes and input ranges are converted to strings.
    - Optional keys receive default values.

    Raises:
        argparse.ArgumentTypeError: if ``validation_threshold`` is not a
            dict or names an unsupported runtime.
    """
    for key in [YAMLKeyword.input_tensors,
                YAMLKeyword.input_shapes,
                YAMLKeyword.output_tensors,
                YAMLKeyword.output_shapes]:
        value = subgraph.get(key, "")
        mace_check(value != "", ModuleName.YAML_CONFIG,
                   "'%s' is necessary in subgraph" % key)
        if not isinstance(value, list):
            subgraph[key] = [value]
        subgraph[key] = [str(v) for v in subgraph[key]]
    input_size = len(subgraph[YAMLKeyword.input_tensors])
    output_size = len(subgraph[YAMLKeyword.output_tensors])

    mace_check(len(subgraph[YAMLKeyword.input_shapes]) == input_size,
               ModuleName.YAML_CONFIG,
               "input shapes' size not equal inputs' size.")
    mace_check(len(subgraph[YAMLKeyword.output_shapes]) == output_size,
               ModuleName.YAML_CONFIG,
               "output shapes' size not equal outputs' size.")

    # Optional tensors to check; normalized to (possibly empty) str lists.
    for key in [YAMLKeyword.check_tensors,
                YAMLKeyword.check_shapes]:
        value = subgraph.get(key, "")
        if value != "":
            if not isinstance(value, list):
                subgraph[key] = [value]
            subgraph[key] = [str(v) for v in subgraph[key]]
        else:
            subgraph[key] = []

    # A scalar data type applies to every input/output; default float32.
    for key in [YAMLKeyword.input_data_types,
                YAMLKeyword.output_data_types]:
        if key == YAMLKeyword.input_data_types:
            count = input_size
        else:
            count = output_size
        data_types = subgraph.get(key, "")
        if data_types:
            if not isinstance(data_types, list):
                subgraph[key] = [data_types] * count
            for dtype in subgraph[key]:
                mace_check(dtype in InOutDataTypeStrs,
                           ModuleName.YAML_CONFIG,
                           key + " must be in "
                           + str(InOutDataTypeStrs))
        else:
            subgraph[key] = [InOutDataType.float32] * count

    input_data_formats = subgraph.get(YAMLKeyword.input_data_formats, [])
    if input_data_formats:
        if not isinstance(input_data_formats, list):
            subgraph[YAMLKeyword.input_data_formats] = \
                [input_data_formats] * input_size
        else:
            mace_check(len(input_data_formats) == input_size,
                       ModuleName.YAML_CONFIG,
                       "input_data_formats should match"
                       " the size of input.")
        for input_data_format in subgraph[YAMLKeyword.input_data_formats]:
            # str(): the message is built eagerly, and YAML may yield a
            # non-str value, which would raise TypeError on concatenation.
            mace_check(input_data_format in DataFormatStrs,
                       ModuleName.YAML_CONFIG,
                       "'input_data_formats' must be in "
                       + str(DataFormatStrs) + ", but got "
                       + str(input_data_format))
    else:
        subgraph[YAMLKeyword.input_data_formats] = \
            [DataFormat.NHWC] * input_size

    output_data_formats = subgraph.get(YAMLKeyword.output_data_formats, [])
    if output_data_formats:
        if not isinstance(output_data_formats, list):
            subgraph[YAMLKeyword.output_data_formats] = \
                [output_data_formats] * output_size
        else:
            mace_check(len(output_data_formats) == output_size,
                       ModuleName.YAML_CONFIG,
                       "output_data_formats should match"
                       " the size of output")
        for output_data_format in subgraph[YAMLKeyword.output_data_formats]:
            mace_check(output_data_format in DataFormatStrs,
                       ModuleName.YAML_CONFIG,
                       "'output_data_formats' must be in "
                       + str(DataFormatStrs) + ", but got "
                       + str(output_data_format))
    else:
        subgraph[YAMLKeyword.output_data_formats] = \
            [DataFormat.NHWC] * output_size

    validation_threshold = subgraph.get(
        YAMLKeyword.validation_threshold, {})
    if not isinstance(validation_threshold, dict):
        raise argparse.ArgumentTypeError(
            'similarity threshold must be a dict.')

    threshold_dict = {
        DeviceType.CPU: ValidationThreshold.cpu_threshold,
        DeviceType.GPU: ValidationThreshold.gpu_threshold,
        DeviceType.HEXAGON + "_QUANTIZE":
            ValidationThreshold.hexagon_threshold,
        DeviceType.HTA + "_QUANTIZE":
            ValidationThreshold.hexagon_threshold,
        DeviceType.CPU + "_QUANTIZE":
            ValidationThreshold.cpu_quantize_threshold,
    }
    # NOTE(review): the hexagon/hta defaults are keyed "<DEV>_QUANTIZE",
    # but user overrides for DSP/HEXAGON/HTA are stored under the bare
    # device name. Confirm which key the consumer reads.
    for k, v in six.iteritems(validation_threshold):
        if k.upper() == 'DSP':
            k = DeviceType.HEXAGON
        if k.upper() not in (DeviceType.CPU,
                             DeviceType.GPU,
                             DeviceType.HEXAGON,
                             DeviceType.HTA,
                             DeviceType.CPU + "_QUANTIZE"):
            raise argparse.ArgumentTypeError(
                'Unsupported validation threshold runtime: %s' % k)
        threshold_dict[k.upper()] = v

    subgraph[YAMLKeyword.validation_threshold] = threshold_dict

    validation_inputs_data = subgraph.get(
        YAMLKeyword.validation_inputs_data, [])
    if not isinstance(validation_inputs_data, list):
        subgraph[YAMLKeyword.validation_inputs_data] = [
            validation_inputs_data]
    else:
        subgraph[YAMLKeyword.validation_inputs_data] = \
            validation_inputs_data

    onnx_backend = subgraph.get(
        YAMLKeyword.backend, "tensorflow")
    subgraph[YAMLKeyword.backend] = onnx_backend

    validation_outputs_data = subgraph.get(
        YAMLKeyword.validation_outputs_data, [])
    if not isinstance(validation_outputs_data, list):
        subgraph[YAMLKeyword.validation_outputs_data] = [
            validation_outputs_data]
    else:
        subgraph[YAMLKeyword.validation_outputs_data] = \
            validation_outputs_data

    input_ranges = subgraph.get(
        YAMLKeyword.input_ranges, [])
    if not isinstance(input_ranges, list):
        subgraph[YAMLKeyword.input_ranges] = [input_ranges]
    else:
        subgraph[YAMLKeyword.input_ranges] = input_ranges
    subgraph[YAMLKeyword.input_ranges] = \
        [str(v) for v in subgraph[YAMLKeyword.input_ranges]]

    accuracy_validation_script = subgraph.get(
        YAMLKeyword.accuracy_validation_script, "")
    if isinstance(accuracy_validation_script, list):
        mace_check(len(accuracy_validation_script) == 1,
                   ModuleName.YAML_CONFIG,
                   "Only support one accuracy validation script")
        accuracy_validation_script = accuracy_validation_script[0]
    subgraph[YAMLKeyword.accuracy_validation_script] = \
        accuracy_validation_script


def format_model_config(flags):
    """Load the YAML deployment config ``flags.config``, then validate and
    normalize it.

    Command-line values on ``flags`` (target_abis, target_socs,
    model_graph_format, model_data_format) take precedence over the YAML
    entries. Each model and each of its subgraphs is checked for
    required keys. Optional keys are filled with defaults.

    Args:
        flags: parsed command-line namespace. Reads ``config``,
            ``target_abis``, ``target_socs``, ``model_graph_format`` and
            ``model_data_format``.

    Returns:
        The normalized configuration dict.

    Raises:
        argparse.ArgumentTypeError: from subgraph validation-threshold
            checks. Other failures are reported through ``mace_check``.
    """
    with open(flags.config) as f:
        # The config is plain YAML data. yaml.load() without an explicit
        # Loader is deprecated since PyYAML 5.1 and fails on PyYAML 6.
        configs = yaml.safe_load(f)
    mace_check(isinstance(configs, dict), ModuleName.YAML_CONFIG,
               "config file should be a YAML mapping")

    library_name = configs.get(YAMLKeyword.library_name, "")
    mace_check(len(library_name) > 0,
               ModuleName.YAML_CONFIG, "library name should not be empty")

    if flags.target_abis:
        target_abis = flags.target_abis.split(',')
    else:
        target_abis = configs.get(YAMLKeyword.target_abis, [])
    mace_check((isinstance(target_abis, list) and len(target_abis) > 0),
               ModuleName.YAML_CONFIG, "target_abis list is needed")
    configs[YAMLKeyword.target_abis] = target_abis
    for abi in target_abis:
        mace_check(abi in ABITypeStrs,
                   ModuleName.YAML_CONFIG,
                   "target_abis must be in " + str(ABITypeStrs))

    # target_socs: an explicit flag (other than the random/all tags) wins;
    # otherwise the YAML value is normalized to a list.
    target_socs = configs.get(YAMLKeyword.target_socs, "")
    if flags.target_socs and flags.target_socs != TargetSOCTag.random \
            and flags.target_socs != TargetSOCTag.all:
        configs[YAMLKeyword.target_socs] = \
            [soc.lower() for soc in flags.target_socs.split(',')]
    elif not target_socs:
        configs[YAMLKeyword.target_socs] = []
    elif not isinstance(target_socs, list):
        configs[YAMLKeyword.target_socs] = [target_socs]

    configs[YAMLKeyword.target_socs] = \
        [soc.lower() for soc in configs[YAMLKeyword.target_socs]]

    # Android builds need the requested SoCs to be reachable over adb.
    if ABIType.armeabi_v7a in target_abis \
            or ABIType.arm64_v8a in target_abis:
        available_socs = sh_commands.adb_get_all_socs()
        target_socs = configs[YAMLKeyword.target_socs]
        if TargetSOCTag.all in target_socs:
            mace_check(available_socs,
                       ModuleName.YAML_CONFIG,
                       "Android abi is listed in config file and "
                       "build for all SOCs plugged in computer, "
                       "But no android phone found, "
                       "you at least plug in one phone")
        else:
            for soc in target_socs:
                mace_check(soc in available_socs,
                           ModuleName.YAML_CONFIG,
                           "Build specified SOC library, "
                           "you must plug in a phone using the SOC")

    if flags.model_graph_format:
        model_graph_format = flags.model_graph_format
    else:
        model_graph_format = configs.get(YAMLKeyword.model_graph_format, "")
    mace_check(model_graph_format in ModelFormatStrs,
               ModuleName.YAML_CONFIG,
               'You must set model_graph_format and '
               "model_graph_format must be in " + str(ModelFormatStrs))
    configs[YAMLKeyword.model_graph_format] = model_graph_format
    if flags.model_data_format:
        model_data_format = flags.model_data_format
    else:
        model_data_format = configs.get(YAMLKeyword.model_data_format, "")
    configs[YAMLKeyword.model_data_format] = model_data_format
    mace_check(model_data_format in ModelFormatStrs,
               ModuleName.YAML_CONFIG,
               'You must set model_data_format and '
               "model_data_format must be in " + str(ModelFormatStrs))

    mace_check(not (model_graph_format == ModelFormat.file
                    and model_data_format == ModelFormat.code),
               ModuleName.YAML_CONFIG,
               "If model_graph format is 'file',"
               " the model_data_format must be 'file' too")

    model_names = configs.get(YAMLKeyword.models, [])
    mace_check(len(model_names) > 0, ModuleName.YAML_CONFIG,
               "no model found in config file")

    # First character '_' or an ASCII letter, then [a-zA-Z0-9_]. '\Z'
    # instead of '$' so a trailing newline cannot pass. Checking the
    # whole name with the regex avoids indexing an empty or non-str name.
    model_name_reg = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*\Z')
    for model_name in model_names:
        # check model_name legality
        mace_check(model_name not in CPP_KEYWORDS,
                   ModuleName.YAML_CONFIG,
                   "model name should not be c++ keyword.")
        mace_check(isinstance(model_name, six.string_types)
                   and bool(model_name_reg.match(model_name)),
                   ModuleName.YAML_CONFIG,
                   "model name should Meet the c++ naming convention"
                   " which start with '_' or alpha"
                   " and only contain alpha, number and '_'")

        model_config = configs[YAMLKeyword.models][model_name]
        platform = model_config.get(YAMLKeyword.platform, "")
        mace_check(platform in PlatformTypeStrs,
                   ModuleName.YAML_CONFIG,
                   "'platform' must be in " + str(PlatformTypeStrs))

        for key in [YAMLKeyword.model_file_path,
                    YAMLKeyword.model_sha256_checksum]:
            value = model_config.get(key, "")
            mace_check(value != "", ModuleName.YAML_CONFIG,
                       "'%s' is necessary" % key)

        weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "")
        if weight_file_path:
            weight_checksum = \
                model_config.get(YAMLKeyword.weight_sha256_checksum, "")
            mace_check(weight_checksum != "", ModuleName.YAML_CONFIG,
                       "'%s' is necessary" %
                       YAMLKeyword.weight_sha256_checksum)
        else:
            model_config[YAMLKeyword.weight_sha256_checksum] = ""

        runtime = model_config.get(YAMLKeyword.runtime, "")
        mace_check(runtime in RuntimeTypeStrs,
                   ModuleName.YAML_CONFIG,
                   "'runtime' must be in " + str(RuntimeTypeStrs))
        if ABIType.host in target_abis:
            mace_check(runtime == RuntimeType.cpu,
                       ModuleName.YAML_CONFIG,
                       "host only support cpu runtime now.")

        # data_type defaults: uint8 for dsp/apu, fp32_fp32 for cpu,
        # fp16_fp32 for the remaining runtimes.
        data_type = model_config.get(YAMLKeyword.data_type, "")
        if runtime == RuntimeType.dsp:
            if len(data_type) > 0:
                mace_check(data_type in DSPDataTypeStrs,
                           ModuleName.YAML_CONFIG,
                           "'data_type' must be in " + str(DSPDataTypeStrs)
                           + " for dsp runtime")
            else:
                model_config[YAMLKeyword.data_type] = \
                    DSPDataType.uint8.value
        elif runtime == RuntimeType.apu:
            if len(data_type) > 0:
                mace_check(data_type in APUDataTypeStrs,
                           ModuleName.YAML_CONFIG,
                           "'data_type' must be in " + str(APUDataTypeStrs)
                           + " for apu runtime")
            else:
                model_config[YAMLKeyword.data_type] = \
                    APUDataType.uint8.value
        else:
            if len(data_type) > 0:
                mace_check(data_type in FPDataTypeStrs,
                           ModuleName.YAML_CONFIG,
                           "'data_type' must be in " + str(FPDataTypeStrs)
                           + " for cpu runtime")
            else:
                if runtime == RuntimeType.cpu:
                    model_config[YAMLKeyword.data_type] = \
                        FPDataType.fp32_fp32.value
                else:
                    model_config[YAMLKeyword.data_type] = \
                        FPDataType.fp16_fp32.value

        subgraphs = model_config.get(YAMLKeyword.subgraphs, "")
        mace_check(len(subgraphs) > 0, ModuleName.YAML_CONFIG,
                   "at least one subgraph is needed")

        for subgraph in subgraphs:
            _format_subgraph(subgraph)

        # Integer switches default to 0 when absent.
        for key in [YAMLKeyword.limit_opencl_kernel_time,
                    YAMLKeyword.nnlib_graph_mode,
                    YAMLKeyword.obfuscate,
                    YAMLKeyword.winograd,
                    YAMLKeyword.quantize,
                    YAMLKeyword.change_concat_ranges]:
            value = model_config.get(key, "")
            if value == "":
                model_config[key] = 0

        mace_check(model_config[YAMLKeyword.winograd] in WinogradParameters,
                   ModuleName.YAML_CONFIG,
                   "'winograd' parameters must be in "
                   + str(WinogradParameters) +
                   ". 0 for disable winograd convolution")

        weight_file_path = model_config.get(YAMLKeyword.weight_file_path, "")
        model_config[YAMLKeyword.weight_file_path] = weight_file_path

    return configs
Y
yejianwu 已提交
584

W
wuchenghui 已提交
585

586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606
def clear_build_dirs(library_name):
    """Prepare the build output tree for ``library_name``.

    Steps:
    - Make sure BUILD_OUTPUT_DIR exists.
    - Recreate the library's temporary build directory empty.
    - Delete the library's previous output-library directory; it is not
      recreated here.
    """
    def remove_if_present(path):
        if os.path.exists(path):
            sh.rm('-rf', path)

    if not os.path.exists(BUILD_OUTPUT_DIR):
        os.makedirs(BUILD_OUTPUT_DIR)

    library_root = os.path.join(BUILD_OUTPUT_DIR, library_name)

    scratch_dir = os.path.join(library_root, BUILD_TMP_DIR_NAME)
    remove_if_present(scratch_dir)
    os.makedirs(scratch_dir)

    remove_if_present(os.path.join(library_root, OUTPUT_LIBRARY_DIR_NAME))


################################
# convert
################################
def print_configuration(configs):
    """Log a summary table of the library-wide settings in ``configs``."""
    summary_keys = (
        YAMLKeyword.library_name,
        YAMLKeyword.target_abis,
        YAMLKeyword.target_socs,
        YAMLKeyword.model_graph_format,
        YAMLKeyword.model_data_format,
    )
    rows = [[key, configs[key]] for key in summary_keys]
    table = StringFormatter.table(["key", "value"], rows,
                                  "Common Configuration")
    MaceLogger.summary(table)
L
Liangliang He 已提交
621

Y
yejianwu 已提交
622

623 624 625 626
def download_file(url, dst, num_retries=3):
    """Fetch ``url`` into the local path ``dst``.

    Each failed attempt (ContentTooShortError, HTTPError or URLError) is
    logged as a warning. The download is retried while retries remain,
    giving at most ``max(num_retries, 0) + 1`` attempts.

    Returns:
        True once an attempt succeeds, False after the last one fails.
    """
    from six.moves import urllib

    retries_left = num_retries
    while True:
        try:
            urllib.request.urlretrieve(url, dst)
            MaceLogger.info('\nDownloaded successfully.')
        except (urllib.error.ContentTooShortError, urllib.error.HTTPError,
                urllib.error.URLError) as err:
            MaceLogger.warning('Download error:' + str(err))
            if retries_left <= 0:
                return False
            retries_left -= 1
        else:
            return True


B
Bin Li 已提交
639 640 641 642
def get_model_files(model_file_path,
                    model_sha256_checksum,
                    model_output_dir,
                    weight_file_path="",
                    weight_sha256_checksum="",
                    quantize_range_file_path=""):
    """Resolve the model, weight and quantize-range files to local paths.

    Any of the three paths may be an http(s) URL. Model and weight URLs
    are downloaded into ``model_output_dir`` under an MD5-derived name. A
    download is skipped when a cached copy already has the expected
    SHA-256 checksum. The model file is always checksum-verified; the
    weight file is verified only when one is given. A range-file URL is
    downloaded on every call, without a checksum.

    Returns:
        Tuple ``(model_file, weight_file, quantize_range_file)``.
    """
    def is_url(path):
        return path.startswith(("http://", "https://"))

    def fetch_cached(url, suffix, checksum, notice):
        # Reuse an existing download when its checksum still matches.
        local_path = model_output_dir + "/" + md5sum(url) + suffix
        needs_download = not os.path.exists(local_path) or \
            sha256_checksum(local_path) != checksum
        if needs_download:
            MaceLogger.info(notice)
            if not download_file(url, local_path):
                MaceLogger.error(ModuleName.MODEL_CONVERTER,
                                 "Model download failed.")
        return local_path

    model_file = model_file_path
    if is_url(model_file_path):
        model_file = fetch_cached(model_file_path, ".pb",
                                  model_sha256_checksum,
                                  "Downloading model, please wait ...")
    if sha256_checksum(model_file) != model_sha256_checksum:
        MaceLogger.error(ModuleName.MODEL_CONVERTER,
                         "model file sha256checksum not match")

    weight_file = weight_file_path
    if is_url(weight_file_path):
        weight_file = fetch_cached(weight_file_path, ".caffemodel",
                                   weight_sha256_checksum,
                                   "Downloading model weight, please wait ...")
    if weight_file and \
            sha256_checksum(weight_file) != weight_sha256_checksum:
        MaceLogger.error(ModuleName.MODEL_CONVERTER,
                         "weight file sha256checksum not match")

    quantize_range_file = quantize_range_file_path
    if is_url(quantize_range_file_path):
        quantize_range_file = model_output_dir + "/" \
            + md5sum(quantize_range_file_path) + ".range"
        if not download_file(quantize_range_file_path, quantize_range_file):
            MaceLogger.error(ModuleName.MODEL_CONVERTER,
                             "Model range file download failed.")
    return model_file, weight_file, quantize_range_file
L
Liangliang He 已提交
688

L
liuqi 已提交
689

690
def convert_model(configs, cl_mem_type):
    """Convert every model listed in ``configs`` into MACE artifacts.

    The library's directory under ``BUILD_OUTPUT_DIR`` and the model/engine
    codegen directories are wiped and recreated first.  When the graph
    format is ``code`` the engine factory source is regenerated and its
    headers are copied into the model header directory.  Each model is then
    fetched/verified via ``get_model_files``, converted with
    ``sh_commands.gen_model_code``, and its generated files are moved or
    copied into the library output directories.

    Args:
        configs: Formatted library configuration (keyed by YAMLKeyword).
        cl_mem_type: OpenCL memory type written into every model config;
            ``"image"`` is used when this is empty/None.
    """
    def _wipe(path):
        # Recursively delete ``path`` if it is present on disk.
        if os.path.exists(path):
            sh.rm("-rf", path)

    library_name = configs[YAMLKeyword.library_name]
    library_dir = os.path.join(BUILD_OUTPUT_DIR, library_name)

    # Start from a fresh library output directory.
    if os.path.exists(BUILD_OUTPUT_DIR):
        _wipe(library_dir)
    else:
        os.makedirs(BUILD_OUTPUT_DIR)
    os.makedirs(library_dir)
    if not os.path.exists(BUILD_DOWNLOADS_DIR):
        os.makedirs(BUILD_DOWNLOADS_DIR)

    model_output_dir = \
        '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_OUTPUT_DIR_NAME)
    model_header_dir = \
        '%s/%s/%s' % (BUILD_OUTPUT_DIR, library_name, MODEL_HEADER_DIR_PATH)

    # The output dir is recreated now; the header dir only for code graphs.
    _wipe(model_output_dir)
    os.makedirs(model_output_dir)
    _wipe(model_header_dir)

    embed_model_data = \
        configs[YAMLKeyword.model_data_format] == ModelFormat.code

    _wipe(MODEL_CODEGEN_DIR)
    _wipe(ENGINE_CODEGEN_DIR)

    models = configs[YAMLKeyword.models]
    graph_format = configs[YAMLKeyword.model_graph_format]

    if graph_format == ModelFormat.code:
        os.makedirs(model_header_dir)
        sh_commands.gen_mace_engine_factory_source(models.keys(),
                                                   embed_model_data)
        sh.cp("-f", glob.glob("mace/codegen/engine/*.h"), model_header_dir)

    for model_name in models:
        MaceLogger.header(
            StringFormatter.block("Convert %s model" % model_name))
        model_config = models[model_name]
        runtime = model_config[YAMLKeyword.runtime]
        # Fall back to image memory when no OpenCL memory type was given.
        model_config[YAMLKeyword.cl_mem_type] = cl_mem_type or "image"

        model_file_path, weight_file_path, quantize_range_file_path = \
            get_model_files(
                model_config[YAMLKeyword.model_file_path],
                model_config[YAMLKeyword.model_sha256_checksum],
                BUILD_DOWNLOADS_DIR,
                model_config[YAMLKeyword.weight_file_path],
                model_config[YAMLKeyword.weight_sha256_checksum],
                model_config.get(YAMLKeyword.quantize_range_file, ""))

        data_type = model_config[YAMLKeyword.data_type]
        # TODO(liuqi): support multiple subgraphs
        graph = model_config[YAMLKeyword.subgraphs][0]

        model_codegen_dir = "%s/%s" % (MODEL_CODEGEN_DIR, model_name)
        # NOTE: gen_model_code takes its arguments positionally; the order
        # below must match its signature exactly.
        sh_commands.gen_model_code(
            model_codegen_dir,
            model_config[YAMLKeyword.platform],
            model_file_path,
            weight_file_path,
            model_config[YAMLKeyword.model_sha256_checksum],
            model_config[YAMLKeyword.weight_sha256_checksum],
            ",".join(graph[YAMLKeyword.input_tensors]),
            ",".join(graph[YAMLKeyword.input_data_types]),
            ",".join(graph[YAMLKeyword.input_data_formats]),
            ",".join(graph[YAMLKeyword.output_tensors]),
            ",".join(graph[YAMLKeyword.output_data_types]),
            ",".join(graph[YAMLKeyword.output_data_formats]),
            ",".join(graph[YAMLKeyword.check_tensors]),
            runtime,
            model_name,
            ":".join(graph[YAMLKeyword.input_shapes]),
            ":".join(graph[YAMLKeyword.input_ranges]),
            ":".join(graph[YAMLKeyword.output_shapes]),
            ":".join(graph[YAMLKeyword.check_shapes]),
            model_config[YAMLKeyword.nnlib_graph_mode],
            embed_model_data,
            model_config[YAMLKeyword.winograd],
            model_config[YAMLKeyword.quantize],
            quantize_range_file_path,
            model_config[YAMLKeyword.change_concat_ranges],
            model_config[YAMLKeyword.obfuscate],
            graph_format,
            data_type,
            model_config[YAMLKeyword.cl_mem_type],
            ",".join(model_config.get(YAMLKeyword.graph_optimize_options, [])))

        # Relocate the generated artifacts into the library output dirs:
        # file graphs ship .pb + .data; code graphs ship headers, plus .data
        # only when the weights are not embedded in the code.
        artifact_prefix = "%s/%s" % (model_codegen_dir, model_name)
        graph_is_file = graph_format == ModelFormat.file
        if graph_is_file:
            sh.mv("-f", artifact_prefix + ".pb", model_output_dir)
        if graph_is_file or not embed_model_data:
            sh.mv("-f", artifact_prefix + ".data", model_output_dir)
        if not graph_is_file:
            sh.cp("-f", glob.glob("mace/codegen/models/*/*.h"),
                  model_header_dir)

        MaceLogger.summary(
            StringFormatter.block("Model %s converted" % model_name))


803
def build_model_lib(configs, address_sanitizer, debug_mode):
    """Build the model library bazel target once per target ABI.

    For every ABI in ``configs[YAMLKeyword.target_abis]`` the output
    directory is created if needed, ``MODEL_LIB_TARGET`` is built with the
    ABI's toolchain, and ``MODEL_LIB_PATH`` is copied to the ABI-specific
    path from ``get_model_lib_output_path``.

    Args:
        configs: Formatted library configuration.
        address_sanitizer: Forwarded to ``sh_commands.bazel_build``.
        debug_mode: Forwarded to the build and used to pick symbol hiding.
    """
    MaceLogger.header(StringFormatter.block("Building model library"))

    library_name = configs[YAMLKeyword.library_name]
    for abi in configs[YAMLKeyword.target_abis]:
        lib_path = get_model_lib_output_path(library_name, abi)
        # Make sure the per-ABI library output directory exists.
        lib_dir = os.path.dirname(lib_path)
        if not os.path.exists(lib_dir):
            os.makedirs(lib_dir)

        build_options = dict(
            abi=abi,
            toolchain=infer_toolchain(abi),
            enable_hexagon=get_hexagon_mode(configs),
            enable_hta=get_hta_mode(configs),
            enable_opencl=get_opencl_mode(configs),
            enable_quantize=get_quantize_mode(configs),
            address_sanitizer=address_sanitizer,
            symbol_hidden=get_symbol_hidden_mode(debug_mode),
            debug_mode=debug_mode,
        )
        sh_commands.bazel_build(MODEL_LIB_TARGET, **build_options)

        sh.cp("-f", MODEL_LIB_PATH, lib_path)
829 830 831 832 833 834 835


def print_library_summary(configs):
    """Log a "Library" table listing where the converted model files live.

    The model header path row is only included when the graph format is
    ``code``.
    """
    library_name = configs[YAMLKeyword.library_name]
    library_dir = "%s/%s" % (BUILD_OUTPUT_DIR, library_name)

    rows = [["MACE Model Path",
             "%s/%s" % (library_dir, MODEL_OUTPUT_DIR_NAME)]]
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        rows.append(["MACE Model Header Path",
                     "%s/%s" % (library_dir, MODEL_HEADER_DIR_PATH)])

    MaceLogger.summary(
        StringFormatter.table(["key", "value"], rows, "Library"))


848
def convert_func(flags):
    """Entry point of the ``convert`` sub-command.

    Loads and prints the configuration, converts all models, builds the
    model library when the graph format is ``code``, then prints a summary
    of the output paths.
    """
    library_configs = format_model_config(flags)
    print_configuration(library_configs)

    convert_model(library_configs, flags.cl_mem_type)

    graph_is_code = \
        library_configs[YAMLKeyword.model_graph_format] == ModelFormat.code
    if graph_is_code:
        build_model_lib(library_configs, flags.address_sanitizer,
                        flags.debug_mode)

    print_library_summary(library_configs)


################################
# run
################################
L
liuqi 已提交
864
def build_mace_run(configs, target_abi, toolchain, enable_openmp,
                   address_sanitizer, mace_lib_type, debug_mode):
    """Build the ``mace_run`` tool for ``target_abi`` and stage its binary.

    A fresh temporary binary directory is created for the library/ABI pair.
    The static or dynamic ``mace_run`` target (chosen by ``mace_lib_type``)
    is built, and ``sh_commands.update_mace_run_binary`` is asked to place
    the result in that directory.  For ``code`` graph formats the engine
    codegen directory must already exist (i.e. ``convert`` has run).
    """
    library_name = configs[YAMLKeyword.library_name]

    binary_dir = get_build_binary_dir(library_name, target_abi)
    if os.path.exists(binary_dir):
        sh.rm("-rf", binary_dir)
    os.makedirs(binary_dir)

    is_dynamic = mace_lib_type == MACELibType.dynamic
    target = MACE_RUN_DYNAMIC_TARGET if is_dynamic else MACE_RUN_STATIC_TARGET

    extra_args = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                   ModuleName.RUN,
                   "You should convert model first.")
        # Compile mace_run.cc with the code-graph define only.
        extra_args = ("--per_file_copt="
                      "mace/tools/validation/mace_run.cc"
                      "@-DMODEL_GRAPH_FORMAT_CODE")

    sh_commands.bazel_build(
        target,
        abi=target_abi,
        toolchain=toolchain,
        enable_hexagon=get_hexagon_mode(configs),
        enable_hta=get_hta_mode(configs),
        enable_openmp=enable_openmp,
        enable_opencl=get_opencl_mode(configs),
        enable_quantize=get_quantize_mode(configs),
        address_sanitizer=address_sanitizer,
        symbol_hidden=get_symbol_hidden_mode(debug_mode, mace_lib_type),
        debug_mode=debug_mode,
        extra_args=extra_args,
    )
    sh_commands.update_mace_run_binary(binary_dir, is_dynamic)


901 902
def build_example(configs, target_abi, toolchain, enable_openmp, mace_lib_type,
                  cl_binary_to_code, device, debug_mode,
                  address_sanitizer=None):
    """Build libmace and the CLI example binary, then stage the binary.

    Steps: recreate the temporary binary directory, generate the OpenCL
    binary/parameter sources, build libmace (static or dynamic), copy the
    model library (code graphs only) and libmace into ``LIB_CODEGEN_DIR``,
    build the example target, copy its binary into the temporary binary
    directory, and finally remove ``LIB_CODEGEN_DIR``.

    Args:
        configs: Formatted library configuration.
        target_abi: ABI to build for.
        toolchain: Toolchain name (from ``infer_toolchain``).
        enable_openmp: Forwarded to the bazel builds.
        mace_lib_type: ``MACELibType.dynamic`` or static; selects the
            libmace and example targets.
        cl_binary_to_code: When true, generate the OpenCL sources from the
            binary/parameter files recorded for ``device``; otherwise
            generate them from empty paths.
        device: Device whose OpenCL binary/parameter output paths are used.
        debug_mode: Forwarded to the bazel builds.
        address_sanitizer: Whether to build with address sanitizer.  When
            None (the default), falls back to the module-level ``flags``
            parsed in ``__main__``, preserving the previous behaviour for
            callers that do not pass it.
    """
    if address_sanitizer is None:
        # Backward compatibility: this function used to read the global
        # ``flags`` directly; explicit callers no longer depend on it.
        address_sanitizer = flags.address_sanitizer

    library_name = configs[YAMLKeyword.library_name]

    build_tmp_binary_dir = get_build_binary_dir(library_name, target_abi)
    if os.path.exists(build_tmp_binary_dir):
        sh.rm("-rf", build_tmp_binary_dir)
    os.makedirs(build_tmp_binary_dir)

    if cl_binary_to_code:
        sh_commands.gen_opencl_binary_cpps(
            get_opencl_binary_output_path(
                library_name, target_abi, device),
            get_opencl_parameter_output_path(
                library_name, target_abi, device),
            OPENCL_CODEGEN_DIR + '/opencl_binary.cc',
            OPENCL_CODEGEN_DIR + '/opencl_parameter.cc')
    else:
        sh_commands.gen_opencl_binary_cpps(
            "", "",
            OPENCL_CODEGEN_DIR + '/opencl_binary.cc',
            OPENCL_CODEGEN_DIR + '/opencl_parameter.cc')

    libmace_target = LIBMACE_STATIC_TARGET
    if mace_lib_type == MACELibType.dynamic:
        libmace_target = LIBMACE_SO_TARGET

    sh_commands.bazel_build(libmace_target,
                            abi=target_abi,
                            toolchain=toolchain,
                            enable_openmp=enable_openmp,
                            enable_opencl=get_opencl_mode(configs),
                            enable_quantize=get_quantize_mode(configs),
                            enable_hexagon=get_hexagon_mode(configs),
                            enable_hta=get_hta_mode(configs),
                            address_sanitizer=address_sanitizer,
                            symbol_hidden=get_symbol_hidden_mode(
                                debug_mode, mace_lib_type),
                            debug_mode=debug_mode)

    # LIB_CODEGEN_DIR collects the libraries the example target links.
    if os.path.exists(LIB_CODEGEN_DIR):
        sh.rm("-rf", LIB_CODEGEN_DIR)
    sh.mkdir("-p", LIB_CODEGEN_DIR)

    build_arg = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                   ModuleName.RUN,
                   "You should convert model first.")
        model_lib_path = get_model_lib_output_path(library_name,
                                                   target_abi)
        sh.cp("-f", model_lib_path, LIB_CODEGEN_DIR)
        build_arg = ("--per_file_copt="
                     "examples/cli/example.cc"
                     "@-DMODEL_GRAPH_FORMAT_CODE")

    if mace_lib_type == MACELibType.dynamic:
        example_target = EXAMPLE_DYNAMIC_TARGET
        sh.cp("-f", LIBMACE_DYNAMIC_PATH, LIB_CODEGEN_DIR)
    else:
        example_target = EXAMPLE_STATIC_TARGET
        sh.cp("-f", LIBMACE_STATIC_PATH, LIB_CODEGEN_DIR)

    sh_commands.bazel_build(example_target,
                            abi=target_abi,
                            toolchain=toolchain,
                            enable_openmp=enable_openmp,
                            enable_opencl=get_opencl_mode(configs),
                            enable_quantize=get_quantize_mode(configs),
                            enable_hexagon=get_hexagon_mode(configs),
                            enable_hta=get_hta_mode(configs),
                            address_sanitizer=address_sanitizer,
                            debug_mode=debug_mode,
                            extra_args=build_arg)

    target_bin = "/".join(sh_commands.bazel_target_to_bin(example_target))
    sh.cp("-f", target_bin, build_tmp_binary_dir)
    if os.path.exists(LIB_CODEGEN_DIR):
        sh.rm("-rf", LIB_CODEGEN_DIR)


L
liuqi 已提交
979 980 981 982 983 984 985 986 987 988
def print_package_summary(package_path):
    """Log a one-row "Library" table with the packaged output's path."""
    rows = [["MACE Model package Path", package_path]]
    MaceLogger.summary(
        StringFormatter.table(["key", "value"], rows, "Library"))


989
def run_mace(flags):
    """Entry point of the ``run`` sub-command.

    For each target ABI and every selected device that supports it, builds
    either the CLI example (``--example``) or ``mace_run``, runs it on the
    device while holding the device lock, and prints the elapsed time.
    Devices lacking the ABI are reported on stderr, except the host entry.
    The library output is packaged at the end.
    """
    configs = format_model_config(flags)
    clear_build_dirs(configs[YAMLKeyword.library_name])

    soc_filter = configs[YAMLKeyword.target_socs]
    devices = DeviceManager.list_devices(flags.device_yml)
    if soc_filter and TargetSOCTag.all not in soc_filter:
        devices = [d for d in devices
                   if d[YAMLKeyword.target_socs].lower() in soc_filter]

    pick_random = flags.target_socs == TargetSOCTag.random
    for target_abi in configs[YAMLKeyword.target_abis]:
        if pick_random:
            candidates = sh_commands.choose_a_random_device(devices,
                                                            target_abi)
        else:
            candidates = devices

        for dev in candidates:
            if target_abi not in dev[YAMLKeyword.target_abis]:
                # No warning is emitted for the host entry.
                if dev[YAMLKeyword.device_name] != SystemType.host:
                    six.print_('The device with soc %s do not support abi %s' %
                               (dev[YAMLKeyword.target_socs], target_abi),
                               file=sys.stderr)
                continue

            toolchain = infer_toolchain(target_abi)
            device = DeviceWrapper(dev)
            if flags.example:
                build_example(configs, target_abi, toolchain,
                              flags.enable_openmp, flags.mace_lib_type,
                              flags.cl_binary_to_code, device,
                              flags.debug_mode)
            else:
                build_mace_run(configs, target_abi, toolchain,
                               flags.enable_openmp, flags.address_sanitizer,
                               flags.mace_lib_type, flags.debug_mode)

            started = time.time()
            with device.lock():
                device.run_specify_abi(flags, configs, target_abi)
            elapse_minutes = (time.time() - started) / 60
            print("Elapse time: %f minutes." % elapse_minutes)

    # package the output files
    package_path = sh_commands.packaging_lib(
        BUILD_OUTPUT_DIR, configs[YAMLKeyword.library_name])
    print_package_summary(package_path)

1044 1045 1046 1047

################################
#  benchmark model
################################
L
liuqi 已提交
1048 1049 1050 1051
def build_benchmark_model(configs,
                          target_abi,
                          toolchain,
                          enable_openmp,
                          mace_lib_type,
                          debug_mode):
    """Build the ``benchmark_model`` binary and stage it for ``target_abi``.

    The static or dynamic benchmark target (chosen by ``mace_lib_type``) is
    built, then the temporary binary directory for the library/ABI pair is
    recreated and the built binary copied into it.  For ``code`` graph
    formats the engine codegen directory must already exist.
    """
    library_name = configs[YAMLKeyword.library_name]

    target = (BM_MODEL_DYNAMIC_TARGET
              if mace_lib_type == MACELibType.dynamic
              else BM_MODEL_STATIC_TARGET)

    copts = ""
    if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
        mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
                   ModuleName.BENCHMARK,
                   "You should convert model first.")
        # Compile benchmark_model.cc with the code-graph define only.
        copts = ("--per_file_copt="
                 "mace/tools/benchmark/benchmark_model.cc"
                 "@-DMODEL_GRAPH_FORMAT_CODE")

    sh_commands.bazel_build(
        target,
        abi=target_abi,
        toolchain=toolchain,
        enable_openmp=enable_openmp,
        enable_opencl=get_opencl_mode(configs),
        enable_quantize=get_quantize_mode(configs),
        enable_hexagon=get_hexagon_mode(configs),
        enable_hta=get_hta_mode(configs),
        symbol_hidden=get_symbol_hidden_mode(debug_mode, mace_lib_type),
        debug_mode=debug_mode,
        extra_args=copts,
    )

    # Recreate the temporary binary directory and copy the binary into it.
    binary_dir = get_build_binary_dir(library_name, target_abi)
    if os.path.exists(binary_dir):
        sh.rm("-rf", binary_dir)
    os.makedirs(binary_dir)

    built_binary = "/".join(sh_commands.bazel_target_to_bin(target))
    sh.cp("-f", built_binary, binary_dir)


1090
def benchmark_model(flags):
    """Entry point of the ``benchmark`` sub-command.

    For each target ABI and every selected device that supports it, builds
    ``benchmark_model``, runs it on the device while holding the device
    lock, and prints the elapsed time.  Devices lacking the ABI are reported
    on stderr; like ``run_mace``, the host entry is not reported.
    """
    configs = format_model_config(flags)

    clear_build_dirs(configs[YAMLKeyword.library_name])

    target_socs = configs[YAMLKeyword.target_socs]
    device_list = DeviceManager.list_devices(flags.device_yml)
    if target_socs and TargetSOCTag.all not in target_socs:
        device_list = [dev for dev in device_list
                       if dev[YAMLKeyword.target_socs].lower() in target_socs]
    for target_abi in configs[YAMLKeyword.target_abis]:
        if flags.target_socs == TargetSOCTag.random:
            target_devices = sh_commands.choose_a_random_device(
                device_list, target_abi)
        else:
            target_devices = device_list
        # build benchmark_model binary
        for dev in target_devices:
            if target_abi in dev[YAMLKeyword.target_abis]:
                toolchain = infer_toolchain(target_abi)
                build_benchmark_model(configs,
                                      target_abi,
                                      toolchain,
                                      flags.enable_openmp,
                                      flags.mace_lib_type,
                                      flags.debug_mode)
                device = DeviceWrapper(dev)
                start_time = time.time()
                with device.lock():
                    device.bm_specific_target(flags, configs, target_abi)
                elapse_minutes = (time.time() - start_time) / 60
                print("Elapse time: %f minutes." % elapse_minutes)
            elif dev[YAMLKeyword.device_name] != SystemType.host:
                # Same filter as run_mace: the host entry would otherwise
                # produce a spurious warning for every non-host ABI.
                six.print_('There is no abi %s with soc %s' %
                           (target_abi, dev[YAMLKeyword.target_socs]),
                           file=sys.stderr)
L
liuqi 已提交
1126

1127

L
liuqi 已提交
1128
################################
Y
yejianwu 已提交
1129
# parsing arguments
L
liuqi 已提交
1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141
################################
def str2bool(v):
    """argparse ``type=`` converter from a yes/no-style string to a bool.

    Matching is case-insensitive.  Raises ``argparse.ArgumentTypeError``
    for anything that is not a recognised true/false spelling.
    """
    normalized = v.lower()
    if normalized in ('yes', 'true', 't', 'y', '1'):
        return True
    if normalized in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def str_to_caffe_env_type(v):
    """argparse ``type=`` converter from 'docker'/'local' to CaffeEnvType.

    Matching is case-insensitive; any other value raises
    ``argparse.ArgumentTypeError``.
    """
    choice = v.lower()
    if choice == 'docker':
        return CaffeEnvType.DOCKER
    if choice == 'local':
        return CaffeEnvType.LOCAL
    raise argparse.ArgumentTypeError('[docker | local] expected.')


1149 1150 1151 1152 1153 1154 1155 1156 1157
def str_to_mace_lib_type(v):
    """argparse ``type=`` converter from 'dynamic'/'static' to MACELibType.

    Matching is case-insensitive; any other value raises
    ``argparse.ArgumentTypeError``.
    """
    choice = v.lower()
    if choice == 'dynamic':
        return MACELibType.dynamic
    if choice == 'static':
        return MACELibType.static
    raise argparse.ArgumentTypeError('[dynamic| static] expected.')


1158
def parse_args():
    """Parses command line arguments.

    Builds the ``convert``, ``run`` and ``benchmark`` sub-commands from
    three shared parent parsers and parses ``sys.argv`` leniently.

    Returns:
        The ``(namespace, unparsed_args)`` tuple from
        ``parse_known_args``.  The chosen sub-command sets
        ``namespace.func`` to convert_func, run_mace or benchmark_model.
    """
    # Options shared by every sub-command.
    all_type_parent_parser = argparse.ArgumentParser(add_help=False)
    all_type_parent_parser.add_argument(
        '--config',
        type=str,
        default="",
        required=True,
        help="the path of model yaml configuration file.")
    all_type_parent_parser.add_argument(
        "--model_graph_format",
        type=str,
        default="",
        help="[file, code], MACE Model graph format.")
    all_type_parent_parser.add_argument(
        "--model_data_format",
        type=str,
        default="",
        help="['file', 'code'], MACE Model data format.")
    all_type_parent_parser.add_argument(
        "--target_abis",
        type=str,
        default="",
        help="Target ABIs, comma separated list.")
    all_type_parent_parser.add_argument(
        "--target_socs",
        type=str,
        default="",
        help="Target SOCs, comma separated list.")
    all_type_parent_parser.add_argument(
        "--debug_mode",
        action="store_true",
        help="Reserve debug symbols.")

    # Options shared by `convert` and `run`.
    convert_run_parent_parser = argparse.ArgumentParser(add_help=False)
    convert_run_parent_parser.add_argument(
        '--address_sanitizer',
        action="store_true",
        help="Whether to use address sanitizer to check memory error")

    # Options shared by `run` and `benchmark`.
    run_bm_parent_parser = argparse.ArgumentParser(add_help=False)
    run_bm_parent_parser.add_argument(
        "--mace_lib_type",
        type=str_to_mace_lib_type,
        default=DefaultValues.mace_lib_type,
        help="[static | dynamic], Which type of MACE library to use.")
    run_bm_parent_parser.add_argument(
        "--enable_openmp",
        action="store_true",
        help="Enable openmp for multiple thread.")
    run_bm_parent_parser.add_argument(
        "--omp_num_threads",
        type=int,
        default=DefaultValues.omp_num_threads,
        help="num of openmp threads")
    run_bm_parent_parser.add_argument(
        "--cpu_affinity_policy",
        type=int,
        default=DefaultValues.cpu_affinity_policy,
        help="0:AFFINITY_NONE/1:AFFINITY_BIG_ONLY/2:AFFINITY_LITTLE_ONLY")
    run_bm_parent_parser.add_argument(
        "--gpu_perf_hint",
        type=int,
        default=DefaultValues.gpu_perf_hint,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    run_bm_parent_parser.add_argument(
        "--gpu_priority_hint",
        type=int,
        default=DefaultValues.gpu_priority_hint,
        help="0:DEFAULT/1:LOW/2:NORMAL/3:HIGH")
    run_bm_parent_parser.add_argument(
        "--device_yml",
        type=str,
        default='',
        help='embedded linux device config yml file'
    )

    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()

    convert = subparsers.add_parser(
        'convert',
        parents=[all_type_parent_parser, convert_run_parent_parser],
        help='convert to mace model (file or code)')
    convert.add_argument(
        "--cl_mem_type",
        type=str,
        default=None,
        help="Which type of OpenCL memory type to use [image | buffer].")
    convert.set_defaults(func=convert_func)

    run = subparsers.add_parser(
        'run',
        parents=[all_type_parent_parser, run_bm_parent_parser,
                 convert_run_parent_parser],
        help='run model in command line')
    run.set_defaults(func=run_mace)
    run.add_argument(
        "--disable_tuning",
        action="store_true",
        help="Disable tuning for specific thread.")
    run.add_argument(
        "--round",
        type=int,
        default=1,
        help="The model running round.")
    run.add_argument(
        "--validate",
        action="store_true",
        help="whether to verify the results are consistent with "
             "the frameworks.")
    run.add_argument(
        "--layers",
        type=str,
        default="-1",
        help="'start_layer:end_layer' or 'layer', similar to python slice."
             " Use with --validate flag.")
    run.add_argument(
        "--caffe_env",
        type=str_to_caffe_env_type,
        default='docker',
        help="[docker | local] specify the caffe environment used for"
             " validation: local environment or caffe docker image.")
    run.add_argument(
        "--vlog_level",
        type=int,
        default=0,
        help="[1~5]. Verbose log level for debug.")
    run.add_argument(
        "--gpu_out_of_range_check",
        action="store_true",
        help="Enable out of range check for gpu.")
    run.add_argument(
        "--restart_round",
        type=int,
        default=1,
        help="restart round between run.")
    run.add_argument(
        "--report",
        action="store_true",
        help="print run statistics report.")
    run.add_argument(
        "--report_dir",
        type=str,
        default="",
        help="directory for the run statistics report.")
    run.add_argument(
        "--runtime_failure_ratio",
        type=float,
        default=0.0,
        help="[mock runtime failure ratio].")
    run.add_argument(
        "--example",
        action="store_true",
        help="whether to run example.")
    run.add_argument(
        "--quantize_stat",
        action="store_true",
        help="whether to stat quantization range.")
    run.add_argument(
        "--input_dir",
        type=str,
        default="",
        help="quantize stat input dir.")
    run.add_argument(
        "--output_dir",
        type=str,
        default="",
        help="quantize stat output dir.")
    run.add_argument(
        "--cl_binary_to_code",
        action="store_true",
        help="convert OpenCL binaries to cpp.")

    benchmark = subparsers.add_parser(
        'benchmark',
        parents=[all_type_parent_parser, run_bm_parent_parser],
        help='benchmark model for detail information')
    benchmark.set_defaults(func=benchmark_model)
    benchmark.add_argument(
        "--max_num_runs",
        type=int,
        default=100,
        help="max number of runs.")
    benchmark.add_argument(
        "--max_seconds",
        type=float,
        default=10.0,
        help="max number of seconds to run.")
    return parser.parse_known_args()

1343

Y
yejianwu 已提交
1344
if __name__ == "__main__":
    # ``flags`` is bound at module scope on purpose: build_example() may
    # read ``flags.address_sanitizer`` from this module-level name.
    flags, unparsed = parse_args()
    if not hasattr(flags, "func"):
        # On Python 3 argparse sub-commands are optional, so invoking the
        # script without one leaves no handler and would crash with an
        # AttributeError below.
        sys.exit("Please specify a sub-command: convert, run or benchmark.")
    flags.func(flags)