From 37c1726fc139665542427f852f6a0c0eceb64537 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 28 Oct 2021 19:47:37 +0800 Subject: [PATCH] refactor(sdk): refactor load and run with new framework GitOrigin-RevId: b092699dee49eab068e262327b078ce157e36f26 --- CMakeLists.txt | 7 +- cmake/gflags.cmake | 1 + lite/CMakeLists.txt | 3 + lite/load_and_run/BUILD | 38 ++ lite/load_and_run/CMakeLists.txt | 29 + lite/load_and_run/dump_with_testcase.py | 404 ++++++++++++ lite/load_and_run/dump_with_testcase_mge.py | 535 +++++++++++++++ lite/load_and_run/src/helpers/common.h | 74 +++ lite/load_and_run/src/helpers/data_parser.cpp | 266 ++++++++ lite/load_and_run/src/helpers/data_parser.h | 48 ++ lite/load_and_run/src/helpers/json_loader.cpp | 297 +++++++++ lite/load_and_run/src/helpers/json_loader.h | 183 ++++++ lite/load_and_run/src/helpers/npy.h | 615 ++++++++++++++++++ lite/load_and_run/src/helpers/outdumper.cpp | 48 ++ lite/load_and_run/src/helpers/outdumper.h | 42 ++ lite/load_and_run/src/helpers/text_table.cpp | 119 ++++ lite/load_and_run/src/helpers/text_table.h | 133 ++++ lite/load_and_run/src/main.cpp | 31 + lite/load_and_run/src/models/model.cpp | 60 ++ lite/load_and_run/src/models/model.h | 49 ++ lite/load_and_run/src/models/model_lite.cpp | 50 ++ lite/load_and_run/src/models/model_lite.h | 73 +++ lite/load_and_run/src/models/model_mdl.cpp | 105 +++ lite/load_and_run/src/models/model_mdl.h | 117 ++++ .../src/options/device_options.cpp | 200 ++++++ .../load_and_run/src/options/device_options.h | 49 ++ .../src/options/extern_c_opr_options.cpp | 216 ++++++ .../src/options/extern_c_opr_options.h | 64 ++ .../src/options/fastrun_options.cpp | 231 +++++++ .../src/options/fastrun_options.h | 57 ++ lite/load_and_run/src/options/io_options.cpp | 295 +++++++++ lite/load_and_run/src/options/io_options.h | 78 +++ .../src/options/layout_options.cpp | 171 +++++ .../load_and_run/src/options/layout_options.h | 56 ++ .../src/options/optimize_options.cpp | 600 +++++++++++++++++ .../src/options/optimize_options.h | 207 ++++++ lite/load_and_run/src/options/option_base.h | 87 +++ .../src/options/plugin_options.cpp | 401 ++++++++++++ .../load_and_run/src/options/plugin_options.h | 105 +++ .../src/options/strategy_options.cpp | 96 +++ .../src/options/strategy_options.h | 68 ++ lite/load_and_run/src/strategys/strategy.cpp | 24 + lite/load_and_run/src/strategys/strategy.h | 63 ++ .../src/strategys/strategy_fitting.cpp | 24 + .../src/strategys/strategy_normal.cpp | 167 +++++ 45 files changed, 6581 insertions(+), 5 deletions(-) create mode 100644 cmake/gflags.cmake create mode 100644 lite/load_and_run/BUILD create mode 100644 lite/load_and_run/CMakeLists.txt create mode 100755 lite/load_and_run/dump_with_testcase.py create mode 100755 lite/load_and_run/dump_with_testcase_mge.py create mode 100644 lite/load_and_run/src/helpers/common.h create mode 100644 lite/load_and_run/src/helpers/data_parser.cpp create mode 100644 lite/load_and_run/src/helpers/data_parser.h create mode 100644 lite/load_and_run/src/helpers/json_loader.cpp create mode 100644 lite/load_and_run/src/helpers/json_loader.h create mode 100644 lite/load_and_run/src/helpers/npy.h create mode 100644 lite/load_and_run/src/helpers/outdumper.cpp create mode 100644 lite/load_and_run/src/helpers/outdumper.h create mode 100644 lite/load_and_run/src/helpers/text_table.cpp create mode 100644 lite/load_and_run/src/helpers/text_table.h create mode 100644 lite/load_and_run/src/main.cpp create mode 100644 lite/load_and_run/src/models/model.cpp create mode 100644 
lite/load_and_run/src/models/model.h create mode 100644 lite/load_and_run/src/models/model_lite.cpp create mode 100644 lite/load_and_run/src/models/model_lite.h create mode 100644 lite/load_and_run/src/models/model_mdl.cpp create mode 100644 lite/load_and_run/src/models/model_mdl.h create mode 100644 lite/load_and_run/src/options/device_options.cpp create mode 100644 lite/load_and_run/src/options/device_options.h create mode 100644 lite/load_and_run/src/options/extern_c_opr_options.cpp create mode 100644 lite/load_and_run/src/options/extern_c_opr_options.h create mode 100644 lite/load_and_run/src/options/fastrun_options.cpp create mode 100644 lite/load_and_run/src/options/fastrun_options.h create mode 100644 lite/load_and_run/src/options/io_options.cpp create mode 100644 lite/load_and_run/src/options/io_options.h create mode 100644 lite/load_and_run/src/options/layout_options.cpp create mode 100644 lite/load_and_run/src/options/layout_options.h create mode 100644 lite/load_and_run/src/options/optimize_options.cpp create mode 100644 lite/load_and_run/src/options/optimize_options.h create mode 100644 lite/load_and_run/src/options/option_base.h create mode 100644 lite/load_and_run/src/options/plugin_options.cpp create mode 100644 lite/load_and_run/src/options/plugin_options.h create mode 100644 lite/load_and_run/src/options/strategy_options.cpp create mode 100644 lite/load_and_run/src/options/strategy_options.h create mode 100644 lite/load_and_run/src/strategys/strategy.cpp create mode 100644 lite/load_and_run/src/strategys/strategy.h create mode 100644 lite/load_and_run/src/strategys/strategy_fitting.cpp create mode 100644 lite/load_and_run/src/strategys/strategy_normal.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 5464eda9c..fb7635468 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -74,7 +74,6 @@ option(MGE_ENABLE_EXCEPTIONS "Build with exceptions" ON) option(MGE_WITH_TEST "Enable test for MegEngine." OFF) option(MGE_WITH_DISTRIBUTED "Build with distributed support" ON) option(MGE_BUILD_IMPERATIVE_RT "Build _imperative_rt Python Module " ON) -option(MGE_BUILD_SDK "Build load_and_run" ON) option(MGE_INFERENCE_ONLY "Build inference only library." 
OFF) option(MGE_WITH_MKLDNN "Enable Intel MKL_DNN support," ON) option(MGE_WITH_ROCM "Enable ROCM support" OFF) @@ -542,6 +541,8 @@ if(MGE_WITH_TEST) include(cmake/gtest.cmake) endif() +include(cmake/gflags.cmake) + if(MGE_BUILD_IMPERATIVE_RT) set(CMAKE_CXX_STANDARD 17) endif() @@ -1147,10 +1148,6 @@ endif() add_subdirectory(src) -if(MGE_BUILD_SDK) - add_subdirectory(sdk/load-and-run) -endif() - if(MGE_BUILD_IMPERATIVE_RT) add_subdirectory(imperative) message(STATUS "Enable imperative python wrapper runtime") diff --git a/cmake/gflags.cmake b/cmake/gflags.cmake new file mode 100644 index 000000000..9dbb80350 --- /dev/null +++ b/cmake/gflags.cmake @@ -0,0 +1 @@ +add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/gflags ${CMAKE_CURRENT_BINARY_DIR}/gflags) \ No newline at end of file diff --git a/lite/CMakeLists.txt b/lite/CMakeLists.txt index 6e422252b..7210506de 100644 --- a/lite/CMakeLists.txt +++ b/lite/CMakeLists.txt @@ -150,6 +150,9 @@ if(MGE_WITH_TEST) add_subdirectory(test) endif() +#load_and_run +add_subdirectory(load_and_run) + # tools and example add_executable(rc4_encryptor tools/rc4_encrypt.cpp) diff --git a/lite/load_and_run/BUILD b/lite/load_and_run/BUILD new file mode 100644 index 000000000..41bfb3b7b --- /dev/null +++ b/lite/load_and_run/BUILD @@ -0,0 +1,38 @@ +load("//brain/megbrain/lite:flags.bzl","pthread_select") + +cc_library( + name = "mgblar", + copts = ["-std=c++14"], + + srcs = glob(["src/**/*.cpp"], exclude = ["src/main.cpp"]), + hdrs = glob(["src/**/*.h"]), + includes = ["src"], + features = if_opt([ + "no_exceptions", + "no_rtti", + ]), + defines = [ + "LITE_BUILD_WITH_MGE=1", + ], + + deps = ["//brain/megbrain/lite:lite_static_test"]+ + pthread_select( + ["@com_github_gflags_gflags//:gflags_nothreads"], + ["//external:gflags"] + ), + alwayslink = 1, + visibility = ["//visibility:public"], +) + +cc_megvii_binary( + name = "load_and_run", + copts = ["-std=c++14"], + srcs = ["src/main.cpp"], + features = if_opt([ + "no_exceptions", + "no_rtti", + ]), + internal_deps = [":mgblar"], + visibility = ["//visibility:public"], +) + diff --git a/lite/load_and_run/CMakeLists.txt b/lite/load_and_run/CMakeLists.txt new file mode 100644 index 000000000..9a8ea8e6f --- /dev/null +++ b/lite/load_and_run/CMakeLists.txt @@ -0,0 +1,29 @@ +# BUILD the load and run for lite +include_directories(PUBLIC $) +file (GLOB_RECURSE SOURCES ./*.cpp) + +add_executable (load_and_run ${SOURCES}) + +target_link_libraries(load_and_run lite_static) +target_link_libraries(load_and_run megbrain) +target_link_libraries(load_and_run gflags) + +if(LITE_BUILD_WITH_RKNPU) + #rknn sdk1.0.0 depend on libc++_shared, use gold to remove NEEDED so symbol check + target_link_options(load_and_run PRIVATE "-fuse-ld=gold") +endif() + +if(MGE_WITH_ROCM) + # FIXME: hip obj can not find cpp obj only through lite_static + target_link_libraries(load_and_run megdnn) +endif() + +if(UNIX) + if(APPLE OR ANDROID) + target_link_libraries(load_and_run dl) + else() + target_link_libraries(load_and_run dl rt) + endif() +endif() + +install (TARGETS load_and_run EXPORT ${LITE_EXPORT_TARGETS} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) \ No newline at end of file diff --git a/lite/load_and_run/dump_with_testcase.py b/lite/load_and_run/dump_with_testcase.py new file mode 100755 index 000000000..013324c44 --- /dev/null +++ b/lite/load_and_run/dump_with_testcase.py @@ -0,0 +1,404 @@ +#!/usr/bin/env mdl +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 
Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +from megskull.graph import NodeFilter, FpropEnv +from megskull.opr.all import AssertEqual, DataProvider, BatchNormalization +from megskull.utils.logconf import get_logger +from meghair.utils import io +import megbrain as mgb + +import argparse +import struct +import re +import os + +import numpy as np +import cv2 + +logger = get_logger(__name__) + +def auto_reformat_image(args, path, data, dst_shape): + """reformat image to target shape + + :param data: image data as numpy array + :param dst_shape: target shape + """ + dim3_format = False # required input format does not contain batch + hwc_format = False # required input format is NHWC + + if len(dst_shape) == 3: + dst_shape = (1, ) + dst_shape + dim3_format = True + + assert len(dst_shape) == 4, 'bad dst_shape: {}'.format(dst_shape) + chl = dst_shape[1] + if chl in [1, 3]: + n, c, h, w = dst_shape + dst_shape = (n, h, w, c) + else: + chl = dst_shape[3] + assert chl in [1, 3], ( + 'can not infer input format from shape: {}'.format(dst_shape)) + hwc_format = True + + # dst_shape has now been normalized to NHWC format + + if args.resize_input: + h, w = dst_shape[1:3] + data = cv2.resize(data, (w, h)) + logger.info('input {} resized to {}'.format(path, data.shape)) + + if chl == 1: + data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) + data = data[:, :, np.newaxis] + + assert data.ndim == 3 + data = data[np.newaxis] + # data normalized to NHWC format + + if not hwc_format: + data = np.transpose(data, (0, 3, 1, 2)) + + if dim3_format: + data = np.squeeze(data, 0) + + return data + +def read_input_data(args, dst_shape, dtype, path, repeat): + def check_shape_equal(dst_shape, data_shape): + assert len(data_shape) == len(dst_shape) , ( + 'input/data shapes mismatch: {} vs {}'.format( + dst_shape, data_shape)) + + if data_shape[1:] != dst_shape[1:]: + logger.warning('dst_shape is {}; data_shape is {}'.format( + dst_shape, data_shape)) + + if path.startswith('#'): + assert not args.resize_input + assert not args.input_transform + spec = path + m = re.match( + r'^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$', spec) + assert m, 'bad spec {}'.format(spec) + + rng_min = float(m.group(1)) + rng_max = float(m.group(2)) + if m.group(3): + shape_str = m.group(3) + try: + shape = shape_str[1:].split(',') + if shape[-1].strip() == '...': + shape = shape[:-1] + shape.extend(list(dst_shape[len(shape):])) + data_shape = tuple(map(int, shape)) + except ValueError as e: + raise ValueError('bad spec {}: {}'.format(spec, e.args)) + else: + data_shape = dst_shape + + check_shape_equal(dst_shape, data_shape) + return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype) + + # try to load image + data = cv2.imread(path, cv2.IMREAD_COLOR) + if data is None: + assert not args.resize_input + data = io.load(path) + assert isinstance(data, np.ndarray) + else: + # load image succeeds, so we expect input format is image format + data = auto_reformat_image(args, path, data, dst_shape) + + data = np.repeat(data, repeat, axis=0) + if repeat > 1: + logger.info('repeat input for {} times, data shape is {}'.format( + repeat, data.shape)) + + check_shape_equal(dst_shape, data.shape) + + if args.input_transform: + data = eval(args.input_transform, {'data': data, 'np': np}) + + return data + + +def gen_one_testcase(args, 
inputs, spec): + paths = spec.split(';') + if len(paths) != len(inputs): + if len(paths) == 1 and paths[0].startswith('#'): + paths = ['{}:{}'.format(name, paths[0]) for name in inputs.keys()] + assert len(paths) == len(inputs), ( + 'required inputs: {}; data paths: {}'.format(inputs.keys(), paths)) + if len(paths) == 1 and ':' not in paths[0]: + paths[0] = next(iter(inputs.keys())) + ':' + paths[0] + + ret = {} + for path in paths: + var, path = path.split(':') + if args.repeat: + repeat = args.repeat + else: + repeat = 1 + ret[var] = read_input_data(args, inputs[var].imm_shape, + inputs[var].dtype, path, repeat) + return ret + + +def make_feeds(args): + outputs = io.load_network(args.input).outputs + if not args.no_assert: + env = FpropEnv(verbose_fprop=False) + # set flag so ExternCOprPlaceholder produce expected output + env.flags.user['extern_c_opr_eval'] = True + func = env.comp_graph.compile(None, [mgb.copy_output(env.get_mgbvar(i)) + for i in outputs]) + + def expect_name(var): return 'expect:{}'.format(var.name) + + nf = NodeFilter.make_all_deps(*outputs) + inputs = {i.name: i for i in nf.data_provider()} + if args.init_bn: + for i in nf: + if isinstance(i, BatchNormalization): + if i._iter.get_value() == 0: + i._iter.set_value(1) + i._variance.set_value(np.ones(i._variance.shape)) + + testcases = [] + + np.set_printoptions(precision=2, threshold=4, suppress=True) + + data_list = [] + for item in args.data: + if item.startswith('@'): + with open(item[1:], 'r') as f: + data_list.extend([ line.rstrip() for line in f if line.rstrip() != '']) + else: + data_list.append(item) + + for inp_spec in data_list: + cur_testcase = gen_one_testcase(args, inputs, inp_spec) + assert len(cur_testcase) == len(inputs), ( + 'required inputs: {}; given data: {}'.format( + inputs.keys(), cur_testcase.keys())) + + if not args.no_assert: + outputs_get = func(**cur_testcase) + for var, val in zip(outputs, outputs_get): + cur_testcase[expect_name(var)] = val + logger.info( + 'generate test groundtruth: var={} shape={} range=({}, {})' + ' mean={} var={}'.format( + var, val.shape, val.min(), val.max(), + np.mean(val), np.var(val))) + testcases.append(cur_testcase) + logger.info('add testcase: \n {}'.format( + '\n '.join('{}: shape={} dtype={} range=({:.2f},{:.2f}) ' + 'mean={:.2f} sd={:.2f}'.format( + k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), + np.std(v)) + for k, v in sorted(cur_testcase.items())))) + + if not args.no_assert: + def expect_shp(var): + ret = var.partial_shape.determined_shape + if ret: + return ret + return testcases[0][expect_name(var)].shape + + verbose = not args.silent + outputs = [AssertEqual(DataProvider(expect_name(i), expect_shp(i), + dtype=i.dtype, + comp_node=i.comp_node), + i, verbose=verbose, maxerr=args.maxerr) + for i in outputs] + return {'outputs': outputs, 'testcases': testcases} + +def optimize_for_inference(args, outputs): + args_map = { + 'enable_io16xc32': 'f16_io_f32_comp', + 'enable_ioc16': 'f16_io_comp', + 'enable_hwcd4': 'use_nhwcd4', + 'enable_nchw4': 'use_nchw4', + 'enable_nchw88': 'use_nchw88', + 'enable_nchw44': 'use_nchw44', + 'enable_nchw44_dot': 'use_nchw44_dot', + 'enable_nchw32': 'use_nchw32', + 'enable_chwn4': 'use_chwn4', + 'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity', + 'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z', + 'enable_nchw64': 'use_nchw64', + 'enable_fuse_preprocess': 'fuse_preprocess', + } + + kwargs = {} + for k, v in args_map.items(): + if getattr(args, k): + assert args.optimize_for_inference, ( + 
'optimize_for_inference should be set when {} is given'.format( + k)) + kwargs[v] = True + + if args.optimize_for_inference: + return mgb.optimize_for_inference(outputs, **kwargs) + + return outputs + +def main(): + parser = argparse.ArgumentParser( + description='Pack computing graph, input values and expected output ' + 'values into one file for checking correctness. README.md gives more ' + 'details on the usage', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('input', help='input file; see README for details') + parser.add_argument('-o', '--output', help='output file', required=True) + parser.add_argument('--init-bn', action='store_true', + help='initialize untrained batch-normalization, to ' + 'avoid NaN or Inf results') + parser.add_argument( + '-d', '--data', default=[], action='append', + help='Given input test data when input file is a network, ' + 'and current network output would be used as groundtruth. ' + 'The format is var0:file0;var1:file1... to specify data files for ' + 'input vars. It can also be #rand(min,max,shape...) for generating ' + 'random input data, for example, #rand(0,255), ' + '#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means ' + 'the remaining part of the original shape. ' + 'If the shape is not specified, the shape of ' + 'corresponding DataProvider in the network will be used. ' + 'If there is only one input var, its name can be omitted. ' + 'Each data file can either be an image which can be loaded by opencv, ' + 'or a pickled numpy.ndarray. ' + 'This option can be given multiple times to add multiple testcases. ' + ' *NOTE* ' + 'If you start the data with the letter @, the rest should be a ' + 'filename, and each line in the file should be a single datum in ' + 'the format described above. ' + ) + parser.add_argument( + '--repeat', type=int, default=1, + help='Specify how many times the input image is repeated. ' + 'Useful when running benchmark for batch size other than one. ' + 'Have no effect on randomly generated input data.') + parser.add_argument('--silent', action='store_true', + help='set verbose to False in AssertEqual opr') + parser.add_argument('--optimize-for-inference', action='store_true', + help='enbale optimization for inference') + parser.add_argument('--no-assert', action='store_true', + help='do not insert AssertEqual opr to check result; ' + 'this option is useful for benchmarking') + parser.add_argument('--maxerr', type=float, default=AssertEqual.maxerr, + help='max error for AssertEqual check during runtime') + parser.add_argument('--resize-input', action='store_true', + help='resize input image to fit input var shape') + parser.add_argument('--input-transform', + help='a python expression to transform the input data. 
' + 'Example: data / np.std(data)') + parser.add_argument('--discard-var-name', action='store_true', + help='discard variable and param names in the ' + 'generated output') + parser.add_argument('--output-strip-info', action='store_true', + help='output code strip information') + parser.add_argument('--enable-io16xc32', action='store_true', + help='transform the mode to float16 io float32 compute') + parser.add_argument('--enable-ioc16', action='store_true', + help='transform the dtype of the model to float16 io ' + 'and compute') + parser.add_argument('--enable-fuse-conv-bias-nonlinearity', + action='store_true', + help='fuse convolution bias and nonlinearity opr to a ' + 'conv_bias opr and compute') + parser.add_argument('--enable-hwcd4', action='store_true', + help='transform the model format from NCHW to NHWCD4 ' + 'for inference; you may need to disable CUDA and set ' + 'MGB_USE_MEGDNN_DBG=2') + parser.add_argument('--enable-nchw4', action='store_true', + help='transform the model format from NCHW to NCHW4 ' + 'for inference') + parser.add_argument('--enable-nchw88', action='store_true', + help='transform the model format from NCHW to NCHW88 ' + 'for inference') + parser.add_argument('--enable-nchw44', action='store_true', + help='transform the model format from NCHW to NCHW44 ' + 'for inference') + parser.add_argument('--enable-nchw44-dot', action='store_true', + help='transform the model format from NCHW to NCHW44_DOT ' + 'for optimizing armv8.2 dot in inference') + parser.add_argument('--enable-chwn4', action='store_true', + help='transform the model format to CHWN4 ' + 'for inference, mainly used for nvidia tensorcore') + parser.add_argument('--enable-nchw32', action='store_true', + help='transform the model format from NCHW4 to NCHW32 ' + 'for inference on nvidia TensoCore') + parser.add_argument('--enable-nchw64', action='store_true', + help='transform the model format from NCHW to NCHW64 ' + 'for inference on Nvidia GPU') + parser.add_argument('--enable-fuse-conv-bias-with-z', action='store_true', + help='fuse conv_bias with z input for inference on ' + 'nvidia GPU (this optimization pass will result in mismatch ' + 'of the precision of output of training and inference)') + parser.add_argument('--enable-fuse-preprocess', action='store_true', + help='fuse astype\pad_channel\dimshuffle and etc opr ' + 'from h2d op') + args = parser.parse_args() + if args.data: + feeds = make_feeds(args) + else: + feeds = io.load(args.input) + + assert isinstance(feeds, dict) and feeds['testcases'], ( + 'testcases can not be empty') + + env = FpropEnv(verbose_fprop=False) + + outputs = feeds['outputs'] + output_mgbvars = list(map(env.get_mgbvar, outputs)) + + output_mgbvars = optimize_for_inference(args, output_mgbvars) + + inputs = sorted(((i.name, i.dtype) for i in + NodeFilter.make_all_deps(*outputs).data_provider())) + if args.discard_var_name: + sereg_kwargs = dict(keep_var_name=0, keep_param_name=False) + else: + sereg_kwargs = dict(keep_var_name=2, keep_param_name=True) + + with open(args.output, 'wb') as fout: + fout.write(b'mgbtest0') + fout.write(struct.pack('I', len(feeds['testcases']))) + stat = mgb.serialize_comp_graph_to_file( + args.output, output_mgbvars, append=True, + output_strip_info=args.output_strip_info, + **sereg_kwargs) + logger.info('graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB'. 
+ format(stat.tot_bytes / 1024, + (stat.tot_bytes - stat.tensor_value_bytes) / 1024)) + + for testcase in feeds['testcases']: + assert isinstance(testcase, dict) + cg = mgb.comp_graph() + cn = mgb.comp_node('cpux') + output_mgbvars = [] + for name, dtype in inputs: + output_mgbvars.append(cg.make_shared(cn, value=testcase.pop(name), + dtype=dtype)) + assert not testcase, 'extra inputs provided in testcase: {}'.format( + testcase.keys()) + + mgb.serialize_comp_graph_to_file( + args.output, + output_mgbvars, + append=True, + output_strip_info=args.output_strip_info, + append_json=True) + +if __name__ == '__main__': + main() diff --git a/lite/load_and_run/dump_with_testcase_mge.py b/lite/load_and_run/dump_with_testcase_mge.py new file mode 100755 index 000000000..2de9af342 --- /dev/null +++ b/lite/load_and_run/dump_with_testcase_mge.py @@ -0,0 +1,535 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import argparse +import os +import re +import struct + +import cv2 +import numpy as np + +import megengine as mge +import megengine.core._imperative_rt as rt +import megengine.core.tensor.megbrain_graph as G +from megengine import tensor +from megengine.core._imperative_rt.core2 import apply +from megengine.core.ops import builtin +from megengine.utils import comp_graph_tools as cgtools + +logger = mge.get_logger(__name__) + + +def auto_reformat_image(args, path, data, dst_shape): + """reformat image to target shape + + :param data: image data as numpy array + :param dst_shape: target shape + """ + dim3_format = False # required input format does not contain batch + hwc_format = False # required input format is NHWC + + if not dst_shape: # input tensor shape is not predefined + if len(data.shape) == 2: + chl = 1 + h = data.shape[0] + w = data.shape[1] + else: + assert len(data.shape) == 3, "Input image must be of dimension 2 or 3" + h, w, chl = data.shape + dst_shape = (1, chl, h, w) + + if len(dst_shape) == 3: + dst_shape = (1,) + dst_shape + dim3_format = True + + assert len(dst_shape) == 4, "bad dst_shape: {}".format(dst_shape) + chl = dst_shape[1] + if chl in [1, 3]: + n, c, h, w = dst_shape + dst_shape = (n, h, w, c) + else: + chl = dst_shape[3] + assert chl in [1, 3], "can not infer input format from shape: {}".format( + dst_shape + ) + hwc_format = True + + # dst_shape has now been normalized to NHWC format + + if args.resize_input: + h, w = dst_shape[1:3] + data = cv2.resize(data, (w, h)) + logger.info("input {} resized to {}".format(path, data.shape)) + + if chl == 1: + data = cv2.cvtColor(data, cv2.COLOR_BGR2GRAY) + data = data[:, :, np.newaxis] + + assert data.ndim == 3 + data = data[np.newaxis] + # data normalized to NHWC format + + if not hwc_format: + data = np.transpose(data, (0, 3, 1, 2)) + + if dim3_format: + data = np.squeeze(data, 0) + + return data + + +def read_input_data(args, dst_shape, dtype, path, repeat): + def check_shape_equal(dst_shape, data_shape): + if len(dst_shape): + assert len(data_shape) == len( + dst_shape + ), "input/data shapes mismatch: {} vs {}".format(dst_shape, data_shape) + + if data_shape[1:] != dst_shape[1:]: + logger.warning( + "dst_shape is {}; data_shape is {}".format(dst_shape, data_shape) + ) + + if 
path.startswith("#"): + assert not args.resize_input + assert not args.input_transform + spec = path + m = re.match(r"^#rand\(([-0-9.]*)\s*,\s*([-0-9.]*)\s*(,[^\)]+)?\)$", spec) + assert m, "bad spec {}".format(spec) + + rng_min = float(m.group(1)) + rng_max = float(m.group(2)) + if m.group(3): + shape_str = m.group(3) + try: + shape = shape_str[1:].split(",") + if shape[-1].strip() == "...": + shape = shape[:-1] + shape.extend(list(dst_shape[len(shape) :])) + data_shape = tuple(map(int, shape)) + except ValueError as e: + raise ValueError("bad spec {}: {}".format(spec, e.args)) + else: + data_shape = dst_shape + + check_shape_equal(dst_shape, data_shape) + return np.random.uniform(rng_min, rng_max, data_shape).astype(dtype) + + # try to load image + data = cv2.imread(path, cv2.IMREAD_COLOR) + if data is None: + assert not args.resize_input + data = np.load(path) + assert isinstance(data, np.ndarray) + else: + # load image succeeds, so we expect input format is image format + data = auto_reformat_image(args, path, data, dst_shape) + + data = np.repeat(data, repeat, axis=0) + if repeat > 1: + logger.info( + "repeat input for {} times, data shape is {}".format(repeat, data.shape) + ) + + check_shape_equal(dst_shape, data.shape) + + if args.input_transform: + data = eval(args.input_transform, {"data": data, "np": np}) + + return data + + +def gen_one_testcase(args, inputs, spec): + paths = spec.split(";") + if len(paths) != len(inputs): + if len(paths) == 1 and paths[0].startswith("#"): + paths = ["{}:{}".format(name, paths[0]) for name in inputs.keys()] + assert len(paths) == len(inputs), "required inputs: {}; data paths: {}".format( + inputs.keys(), paths + ) + if len(paths) == 1 and ":" not in paths[0]: + paths[0] = next(iter(inputs.keys())) + ":" + paths[0] + + ret = {} + for path in paths: + var, path = path.split(":") + if args.repeat: + repeat = args.repeat + else: + repeat = 1 + ret[var] = read_input_data( + args, inputs[var].shape, inputs[var].dtype, path, repeat + ) + return ret + + +def make_feeds(args): + ret = G.load_graph(args.input) + cg_rt, outputs = ret.graph, ret.output_vars_list + inputs = cgtools.get_dep_vars(outputs, "Host2DeviceCopy") + + inputs = {i.name: i for i in inputs} + if not args.no_assert: + + replace_varmap = {} + inp_map = {} + # replace var use InputNode + for name, var in inputs.items(): + inp = G.InputNode( + device="xpux", dtype=var.dtype, shape=var.shape, graph=cg_rt + ) + replace_varmap[var] = inp.outputs[0] + inp_map[name] = inp + + new = cgtools.replace_vars(outputs, replace_varmap) + if isinstance(new, rt.VarNode): + new = list(new) + + output_nodes = [G.OutputNode(var) for var in new] + func = cg_rt.compile([node.outputs[0] for node in output_nodes]) + + def make_dev_tensor(value, dtype=None, device=None): + return tensor(value, dtype=dtype, device=device)._dev_tensor() + + def calculate(*args, **kwargs): + output_val = [] + # set inputs value + for name, var in inputs.items(): + val = kwargs.pop(name, None) + assert val is not None, "miss input name{}".format(name) + dev_tensor = make_dev_tensor(val, dtype=var.dtype, device="xpux") + inp_map[name].set_value(dev_tensor) + + func.execute() + + for res in output_nodes: + output_val.append(res.get_value().numpy()) + return output_val + + def expect_name(var): + return "{}:expect".format(var.name) + + testcases = [] + + np.set_printoptions(precision=2, threshold=4, suppress=True) + + data_list = [] + for item in args.data: + if item.startswith("@"): + with open(item[1:], "r") as f: + 
data_list.extend([line.rstrip() for line in f if line.rstrip() != ""]) + else: + data_list.append(item) + + for inp_spec in data_list: + cur_testcase = gen_one_testcase(args, inputs, inp_spec) + assert len(cur_testcase) == len( + inputs + ), "required inputs: {}; given data: {}".format( + inputs.keys(), cur_testcase.keys() + ) + + if not args.no_assert: + outputs_get = calculate(**cur_testcase) + for var, val in zip(outputs, outputs_get): + cur_testcase[expect_name(var)] = val + logger.info( + "generate test groundtruth: var={} shape={} range=({}, {})" + " mean={} var={}".format( + var, val.shape, val.min(), val.max(), np.mean(val), np.var(val) + ) + ) + testcases.append(cur_testcase) + logger.info( + "add testcase: \n {}".format( + "\n ".join( + "{}: shape={} dtype={} range=({:.2f},{:.2f}) " + "mean={:.2f} sd={:.2f}".format( + k, v.shape, v.dtype, v.min(), v.max(), np.mean(v), np.std(v) + ) + for k, v in sorted(cur_testcase.items()) + ) + ) + ) + + if not args.no_assert: + + def expect_shp(var): + ret = var.shape + if ret: + return ret + return testcases[0][expect_name(var)].shape + + def assert_equal(expect, real, **kwargs): + op = builtin.AssertEqual(**kwargs) + (res,) = G.apply_normal_varnode(op, expect, real) + return res + + verbose = not args.silent + + outputs_new = [] + for i in outputs: + device = rt.CompNode("xpux") + dtype = i.dtype + name = expect_name(i) + shape = expect_shp(i) + # make expect output as one input of model. + expect_get = rt.make_h2d(cg_rt, device, dtype, shape, name) + # insert assert opr to check expect and real. + outputs_new.append( + assert_equal( + expect_get, + i, + verbose=verbose, + maxerr=args.maxerr, + ) + ) + inputs[expect_name(i)] = expect_get + outputs = outputs_new + + return {"outputs": outputs, "testcases": testcases} + + +def optimize_for_inference(args, outputs): + args_list = [ + "enable_io16xc32", + "enable_ioc16", + "enable_hwcd4", + "enable_nchw4", + "enable_nchw88", + "enable_nchw44", + "enable_nchw44_dot", + "enable_nchw32", + "enable_chwn4", + "enable_fuse_conv_bias_nonlinearity", + "enable_fuse_conv_bias_with_z", + "enable_fuse_preprocess", + ] + kwargs = {} + for k in args_list: + if getattr(args, k): + kwargs[k] = True + + if args.optimize_for_inference: + outputs = G.optimize_for_inference(outputs, **kwargs) + + return outputs + + +def main(): + parser = argparse.ArgumentParser( + description="Pack computing graph, input values and expected output " + "values into one file for checking correctness. README.md gives more " + "details on the usage", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("input", help="MegEngine dumped model file") + parser.add_argument("-o", "--output", help="output file", required=True) + parser.add_argument( + "-d", + "--data", + default=[], + action="append", + required=True, + help="Given input test data when input file is a network, " + "and current network output would be used as groundtruth. " + "The format is var0:file0;var1:file1... to specify data files for " + "input vars. It can also be #rand(min,max,shape...) for generating " + "random input data, for example, #rand(0,255), " + "#rand(0,255,1,3,224,224) or #rand(0, 255, 1, ...) where `...` means " + "the remaining part of the original shape. " + "If the shape is not specified, the shape of " + "corresponding input tensors in the network will be used. " + "If there is only one input var, its name can be omitted. 
" + "Each data file can either be an image which can be loaded by opencv, " + "or a pickled numpy.ndarray. " + "This option can be given multiple times to add multiple testcases. " + " *NOTE* " + "If you start the data with the letter @, the rest should be a " + "filename, and each line in the file should be a single datum in " + "the format described above. ", + ) + parser.add_argument( + "--repeat", + type=int, + default=1, + help="Specify how many times the input image is repeated. " + "Useful when running benchmark for batch size other than one. " + "Have no effect on randomly generated input data.", + ) + parser.add_argument( + "--silent", + action="store_true", + help="set verbose to False in asserti_equal opr", + ) + parser.add_argument( + "--optimize-for-inference", + action="store_true", + help="enable optimization for inference", + ) + parser.add_argument( + "--no-assert", + action="store_true", + help="do not insert assert_equal opr to check result; " + "this option is useful for benchmarking", + ) + parser.add_argument( + "--maxerr", + type=float, + default=1e-4, + help="max error for assert_equal check during runtime", + ) + parser.add_argument( + "--resize-input", + action="store_true", + help="resize input image to fit input var shape", + ) + parser.add_argument( + "--input-transform", + help="a python expression to transform the input data. " + "Example: data / np.std(data)", + ) + parser.add_argument( + "--discard-var-name", + action="store_true", + help="discard variable and param names in the " "generated output", + ) + parser.add_argument( + "--output-strip-info", action="store_true", help="output code strip information" + ) + parser.add_argument( + "--enable-io16xc32", + action="store_true", + help="transform the mode to float16 io float32 compute", + ) + parser.add_argument( + "--enable-ioc16", + action="store_true", + help="transform the dtype of the model to float16 io " "and compute", + ) + parser.add_argument( + "--enable-fuse-conv-bias-nonlinearity", + action="store_true", + help="fuse convolution bias and nonlinearity opr to a " + "conv_bias opr and compute", + ) + parser.add_argument( + "--enable-hwcd4", + action="store_true", + help="transform the model format from NCHW to NHWCD4 " + "for inference; you may need to disable CUDA and set " + "MGB_USE_MEGDNN_DBG=2", + ) + parser.add_argument( + "--enable-nchw4", + action="store_true", + help="transform the model format from NCHW to NCHW4 " "for inference", + ) + parser.add_argument( + "--enable-nchw88", + action="store_true", + help="transform the model format from NCHW to NCHW88 " "for inference", + ) + parser.add_argument( + "--enable-nchw44", + action="store_true", + help="transform the model format from NCHW to NCHW44 " "for inference", + ) + parser.add_argument( + "--enable-nchw44-dot", + action="store_true", + help="transform the model format from NCHW to NCHW44_DOT " + "for optimizing armv8.2 dot in inference", + ) + parser.add_argument( + "--enable-nchw32", + action="store_true", + help="transform the model format from NCHW4 to NCHW32 " + "for inference on nvidia TensoCore", + ) + parser.add_argument( + "--enable-chwn4", + action="store_true", + help="transform the model format to CHWN4 " + "for inference, mainly used for nvidia tensorcore", + ) + parser.add_argument( + "--enable-fuse-conv-bias-with-z", + action="store_true", + help="fuse conv_bias with z input for inference on " + "nvidia GPU (this optimization pass will result in mismatch " + "of the precision of output of training and inference)", + ) 
+ parser.add_argument( + "--enable-fuse-preprocess", + action="store_true", + help="fuse astype\pad_channel\dimshuffle and etc opr " + "from h2d opr", + ) + args = parser.parse_args() + + feeds = make_feeds(args) + + assert isinstance(feeds, dict) and feeds["testcases"], "testcases can not be empty" + + output_mgbvars = feeds["outputs"] + output_mgbvars = optimize_for_inference(args, output_mgbvars) + + inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") + inputs = sorted((i.name, i.dtype) for i in inputs) + + if args.discard_var_name: + sereg_kwargs = dict(keep_var_name=0, keep_param_name=False) + else: + sereg_kwargs = dict(keep_var_name=2, keep_param_name=True) + + strip_info_file = args.output + ".json" if args.output_strip_info else None + + with open(args.output, "wb") as fout: + fout.write(b"mgbtest0") + fout.write(struct.pack("I", len(feeds["testcases"]))) + dump_content, stat = G.dump_graph( + output_mgbvars, + append_json=True, + strip_info_file=strip_info_file, + **sereg_kwargs, + ) + fout.write(dump_content) + + logger.info( + "graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB".format( + stat.tot_bytes / 1024, (stat.tot_bytes - stat.tensor_value_bytes) / 1024 + ) + ) + + def make_dev_tensor(value, dtype=None, device=None): + return tensor(value, dtype=dtype, device=device)._dev_tensor() + + for testcase in feeds["testcases"]: + assert isinstance(testcase, dict) + cg = G.Graph() + output_mgbvars = [] + for name, dtype in inputs: + output_mgbvars.append( + cg.make_const( + make_dev_tensor(testcase.pop(name), dtype=dtype, device="cpux") + ) + ) + assert not testcase, "extra inputs provided in testcase: {}".format( + testcase.keys() + ) + with open(args.output, "ab") as fout: + dump_content, _ = G.dump_graph( + output_mgbvars, strip_info_file=strip_info_file, append_json=True + ) + fout.write(dump_content) + + +if __name__ == "__main__": + main() diff --git a/lite/load_and_run/src/helpers/common.h b/lite/load_and_run/src/helpers/common.h new file mode 100644 index 000000000..6fc04bc48 --- /dev/null +++ b/lite/load_and_run/src/helpers/common.h @@ -0,0 +1,74 @@ +/** + * \file lite/load_and_run/src/helpers/common.h + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ + +#pragma once +#include +#include +DECLARE_int32(thread); +namespace lar { +/*! + * \brief: state of model running + */ +enum class RunStage { + + BEFORE_MODEL_LOAD = 0, + + AFTER_MODEL_LOAD = 1, + + BEFORE_OUTSPEC_SET = 2, + + //! using for dump static memory information svg file + AFTER_OUTSPEC_SET = 3, + + //! using for external c opr library + MODEL_RUNNING = 4, + + //! using for output dumper + AFTER_RUNNING_WAIT = 5, + + //! using for external c opr library + AFTER_RUNNING_ITER = 6, + + AFTER_MODEL_RUNNING = 7, +}; +/*! + * \brief: type of different model + */ +enum class ModelType { + LITE_MODEL = 0, + MEGDL_MODEL, + UNKNOWN, +}; +/*! + * \brief: param for running model + */ +struct RuntimeParam { + RunStage stage = RunStage::AFTER_MODEL_LOAD; + size_t warmup_iter; //! warm up number before running model + size_t run_iter; //! iteration number for running model + size_t threads = FLAGS_thread; //! thread number for running model (NOTE:it's + //! different from multithread device ) + size_t testcase_num = 1; //! testcase number for model with testcase +}; +/*! 
+ * \brief:layout type for running model optimization + */ +enum class OptLayoutType { + NCHW4 = 1 << 0, + CHWN4 = 1 << 1, + NCHW44 = 1 << 2, + NCHW88 = 1 << 3, + NCHW32 = 1 << 4, + NCHW64 = 1 << 5, + NHWCD4 = 1 << 6, + NCHW44_DOT = 1 << 7 +}; + +} // namespace lar +// vim: syntax=cpp.doxygen diff --git a/lite/load_and_run/src/helpers/data_parser.cpp b/lite/load_and_run/src/helpers/data_parser.cpp new file mode 100644 index 000000000..0ba71626f --- /dev/null +++ b/lite/load_and_run/src/helpers/data_parser.cpp @@ -0,0 +1,266 @@ +/** + * \file lite/load_and_run/src/helpers/data_parser.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "data_parser.h" +#include +#include "json_loader.h" +#include "npy.h" + +using namespace lar; + +/*! + * \brief feed different data to diffferent parser + * \param path data file path or data string + */ +void DataParser::feed(const std::string& path) { + std::string blob_name = "data", blob_string = path; + size_t sep = path.find(":"); + if (sep != std::string::npos) { + blob_name = path.substr(0, sep); + blob_string = path.substr(sep + 1); + } + + auto endWith = [blob_string](std::string suffix) -> bool { + return blob_string.rfind(suffix) == (blob_string.length() - suffix.length()); + }; + + if (endWith(".ppm") || endWith(".pgm")) { + parse_image(blob_name, blob_string); + } else if (endWith(".json")) { + parse_json(blob_string); + } else if (endWith(".npy")) { + parse_npy(blob_name, blob_string); + } else { + parse_string(blob_name, blob_string); + } +} + +void DataParser::parse_json(const std::string& path) { + mgb::JsonLoader json; + std::shared_ptr root = json.load(path.c_str()); + + mgb_assert(root != nullptr, "parse json %s fail", path.c_str()); + // parse json to data map + const std::string SHAPE = "shape", TYPE = "type", RAW = "raw"; + for (auto& item : root->objects()) { + auto&& value = *item.second; + auto&& shape = value[SHAPE]; + mgb_assert(shape->is_array()); + + auto&& type = value[TYPE]; + mgb_assert(type->is_str()); + + auto&& raw = value[RAW]; + mgb_assert(raw->is_array()); + + megdnn::SmallVector data_shape; + for (auto&& shape_ptr : shape->array()) { + data_shape.append({static_cast(std::round(shape_ptr->number()))}); + } + + // get type + const std::map type_map = { + {"float32", mgb::dtype::Float32()}, {"float", mgb::dtype::Float32()}, + {"int32", mgb::dtype::Int32()}, {"int", mgb::dtype::Int32()}, + {"int8", mgb::dtype::Int8()}, {"uint8", mgb::dtype::Uint8()}}; + + const std::string& type_str = type->str(); + mgb_assert( + type_map.find(type_str) != type_map.end(), + "unknown json data type for --input"); + + mgb::DType datatype = type_map.at(type_str); + mgb::HostTensorND hv; + hv.comp_node(mgb::CompNode::default_cpu(), true) + .dtype(datatype) + .resize(data_shape); + mgb::dt_byte* raw_ptr = hv.raw_ptr(); + size_t elem_size = datatype.size(); + + // get raw + const size_t array_size = raw->len(); + for (size_t idx = 0; idx < array_size; ++idx) { + double tmp = (*raw)[idx]->number(); + + switch (datatype.enumv()) { + case megdnn::DTypeEnum::Int32: { + int32_t ival = std::round(tmp); + memcpy(((char*)raw_ptr) + idx * elem_size, &ival, elem_size); + } break; + case megdnn::DTypeEnum::Uint8: + case 
megdnn::DTypeEnum::Int8: { + int8_t cval = std::round(tmp); + memcpy(((char*)raw_ptr) + idx, &cval, sizeof(int8_t)); + } break; + case megdnn::DTypeEnum::Float32: { + float fval = tmp; + memcpy(((char*)raw_ptr) + idx * elem_size, &fval, elem_size); + } break; + default: + break; + } + } + + inputs.insert(std::make_pair(item.first, std::move(hv))); + } +} + +void DataParser::parse_image(const std::string& name, const std::string& path) { + // load binary ppm/pgm + std::ifstream fin; + fin.open(path, std::ifstream::binary | std::ifstream::in); + mgb_assert(fin.is_open(), "open file %s failed for --input", path.c_str()); + + size_t w = 0, h = 0, channel = 0; + char buf[128] = {0}; + + fin.getline(buf, 128); + if ('5' == buf[1]) { + channel = 1; + } else if ('6' == buf[1]) { + channel = 3; + } else { + mgb_assert(0, "not a formal ppm/pgm"); + } + + while (fin.getline(buf, 128)) { + if (buf[0] == '#') { + continue; + } + break; + } + std::stringstream ss; + ss << std::string(buf); + ss >> w; + ss >> h; + + mgb_assert(w > 0 and h > 0); + + mgb::HostTensorND hv; + hv.comp_node(mgb::CompNode::default_cpu(), true) + .dtype(mgb::dtype::Uint8()) + .resize({1, h, w, channel}); + + fin.read((char*)(hv.raw_ptr()), hv.layout().total_nr_elems()); + fin.close(); + inputs.insert(std::make_pair(name, std::move(hv))); +} + +void DataParser::parse_npy(const std::string& name, const std::string& path) { + std::string type_str; + std::vector stl_shape; + std::vector raw; + npy::LoadArrayFromNumpy(path, type_str, stl_shape, raw); + + megdnn::SmallVector shape; + for (auto val : stl_shape) { + shape.append({static_cast(val)}); + } + + const std::map type_map = { + {"f4", mgb::dtype::Float32()}, {"i4", mgb::dtype::Int32()}, + {"i2", mgb::dtype::Int16()}, {"u2", mgb::dtype::Uint16()}, + {"i1", mgb::dtype::Int8()}, {"u1", mgb::dtype::Uint8()}}; + + megdnn::DType hv_type; + for (auto& item : type_map) { + if (type_str.find(item.first) != std::string::npos) { + hv_type = item.second; + break; + } + } + + mgb::HostTensorND hv; + hv.comp_node(mgb::CompNode::default_cpu(), true).dtype(hv_type).resize(shape); + mgb::dt_byte* raw_ptr = hv.raw_ptr(); + memcpy(raw_ptr, raw.data(), raw.size()); + + inputs.insert(std::make_pair(name, std::move(hv))); +} + +void DataParser::parse_string(const std::string name, const std::string& str) { + // data type + megdnn::DType data_type = mgb::dtype::Int32(); + if (str.find(".") != std::string::npos or str.find(".") != std::string::npos) { + data_type = mgb::dtype::Float32(); + } + // shape + size_t number_cnt = 0; + + std::shared_ptr brace_root = std::make_shared(); + std::shared_ptr cur = brace_root; + for (size_t i = 0; i < str.size(); ++i) { + char c = str[i]; + if (c == '[') { + std::shared_ptr child = std::make_shared(); + child->parent = cur; + cur->chidren.emplace_back(child); + cur = child; + } else if (c == ']') { + cur = cur->parent.lock(); + } else if (c == ',') { + number_cnt++; + } + continue; + } + ++number_cnt; + + mgb_assert(cur == brace_root, "braces not closed for --input"); + megdnn::SmallVector shape; + cur = brace_root; + while (not cur->chidren.empty()) { + shape.append({cur->chidren.size()}); + number_cnt /= cur->chidren.size(); + cur = cur->chidren[0]; + } + mgb_assert(number_cnt > 0); + shape.append({number_cnt}); + + // data + std::string json_arr; + for (size_t i = 0; i < str.size(); ++i) { + char c = str[i]; + if (c != '[' and c != ']') { + json_arr += c; + } + } + json_arr = "[" + json_arr + "]"; + + // reuse json parser to resolve raw data + mgb::JsonLoader 
json; + std::shared_ptr json_root = + json.load(json_arr.data(), json_arr.size()); + mgb_assert(json_root != nullptr, "parse json fail in parse_string"); + + mgb::HostTensorND hv; + hv.comp_node(mgb::CompNode::default_cpu(), true).dtype(data_type).resize(shape); + mgb::dt_byte* raw_ptr = hv.raw_ptr(); + + const size_t array_len = json_root->len(); + const size_t elem_size = data_type.size(); + for (size_t idx = 0; idx < array_len; ++idx) { + double tmp = json_root->array()[idx]->number(); + switch (data_type.enumv()) { + case megdnn::DTypeEnum::Int32: { + int32_t ival = std::round(tmp); + memcpy(((char*)raw_ptr) + idx * elem_size, &ival, elem_size); + } break; + case megdnn::DTypeEnum::Float32: { + float fval = tmp; + memcpy(((char*)raw_ptr) + idx * elem_size, &fval, elem_size); + } break; + default: + break; + } + } + inputs.insert(std::make_pair(name, std::move(hv))); +} diff --git a/lite/load_and_run/src/helpers/data_parser.h b/lite/load_and_run/src/helpers/data_parser.h new file mode 100644 index 000000000..21dac3879 --- /dev/null +++ b/lite/load_and_run/src/helpers/data_parser.h @@ -0,0 +1,48 @@ +/** + * \file lite/load_and_run/src/helpers/data_parser.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include +#include +#include +#include "megbrain/opr/io.h" + +namespace lar { +/*! + * \brief data parser for --input + * support .json|.ppm|.pgm|.npy data and user define data string + * data string format: [0,0,227,227] + */ +struct DataParser { + struct Brace { + std::weak_ptr parent; + std::vector> chidren; + }; + void feed(const std::string& path); + + std::unordered_map inputs; + +private: + //! parser for json data + void parse_json(const std::string& path); + + //! parser for .ppm .pgm image + void parse_image(const std::string& name, const std::string& path); + + //! parser for .npy data + void parse_npy(const std::string& name, const std::string& path); + + //! parser for user define string + void parse_string(const std::string name, const std::string& str); +}; +} // namespace lar diff --git a/lite/load_and_run/src/helpers/json_loader.cpp b/lite/load_and_run/src/helpers/json_loader.cpp new file mode 100644 index 000000000..cd4609f22 --- /dev/null +++ b/lite/load_and_run/src/helpers/json_loader.cpp @@ -0,0 +1,297 @@ +/** + * \file lite/load_and_run/src/helpers/json_loader.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "json_loader.h" + +using namespace mgb; + +template +T* JsonLoader::Value::safe_cast() { + T* ptr = (T*)(this); + if (nullptr == ptr) { + fprintf(stderr, "cast ptr is null\n"); + } + return ptr; +} + +std::unique_ptr& JsonLoader::Value::operator[]( + const std::string& key) { + mgb_assert(Type::OBJECT == m_type); + auto t = safe_cast(); + return t->m_obj.at(key); +} + +std::unique_ptr& JsonLoader::Value::operator[](const size_t index) { + mgb_assert(Type::ARRAY == m_type); + auto t = safe_cast(); + return t->m_obj[index]; +} + +std::map>& JsonLoader::Value:: + objects() { + mgb_assert(Type::OBJECT == m_type); + auto t = safe_cast(); + return t->m_obj; +} + +size_t JsonLoader::Value::len() { + if (Type::ARRAY == m_type) { + auto t = safe_cast(); + return t->m_obj.size(); + } else if (Type::OBJECT == m_type) { + auto t = safe_cast(); + return t->m_obj.size(); + } + return 0; +} + +megdnn::SmallVector>& JsonLoader::Value::array() { + mgb_assert(Type::ARRAY == m_type); + auto t = safe_cast(); + return t->m_obj; +} + +double JsonLoader::Value::number() { + mgb_assert(Type::NUMBER == m_type); + auto t = safe_cast(); + return t->value(); +} + +std::string JsonLoader::Value::str() { + if (Type::STRING == m_type) { + auto t = safe_cast(); + return t->value(); + } + return std::string(); +} + +void JsonLoader::expect(char c) { + mgb_assert(c == (*m_buf)); + m_buf++; +} + +void JsonLoader::skip_whitespace() { + const char* p = m_buf; + while (*p == ' ' || *p == '\t' || *p == '\n' || *p == '\r') { + ++p; + } + m_buf = p; +} + +std::unique_ptr JsonLoader::parse_object() { + expect('{'); + skip_whitespace(); + + std::unique_ptr ret; + JsonLoader::ObjectValue* pObject = new JsonLoader::ObjectValue(); + + if ('}' == *m_buf) { + m_buf = m_buf + 1; + ret.reset((JsonLoader::Value*)(pObject)); + return ret; + } + + while (true) { + std::unique_ptr key = parse_string(); + if (m_state != State::OK) { + return ret; + } + + skip_whitespace(); + if (':' != (*m_buf)) { + m_state = State::MISS_COLON; + return ret; + } + m_buf++; + skip_whitespace(); + + std::unique_ptr pVal = parse_value(); + if (m_state != State::OK) { + return ret; + } + + if (pObject->m_obj.find(pVal->str()) != pObject->m_obj.end()) { + m_state = State::KEY_NOT_UNIQUE; + return ret; + } + + pObject->m_obj.insert(std::make_pair(key->str(), std::move(pVal))); + + skip_whitespace(); + if (',' == (*m_buf)) { + m_buf++; + skip_whitespace(); + } else if ('}' == (*m_buf)) { + m_buf++; + break; + } else { + m_state = State::MISS_BRACE; + break; + } + } + + ret.reset((JsonLoader::Value*)(pObject)); + return ret; +} + +std::unique_ptr JsonLoader::parse_array() { + expect('['); + skip_whitespace(); + + std::unique_ptr ret; + JsonLoader::ArrayValue* pArray = new JsonLoader::ArrayValue(); + + if (']' == *m_buf) { + m_buf = m_buf + 1; + + ret.reset((JsonLoader::Value*)(pArray)); + return ret; + } + + while (true) { + std::unique_ptr pVal = parse_value(); + if (m_state != State::OK) { + mgb_assert(0, "parse value failed during pase array"); + return ret; + } + + pArray->m_obj.emplace_back(pVal.get()); + pVal.release(); + + skip_whitespace(); + if (',' == *m_buf) { + m_buf++; + skip_whitespace(); + } else if (']' == *m_buf) { + m_buf++; + break; + } else { + m_state = State::BAD_ARRAY; + return ret; + } + } + + ret.reset((JsonLoader::Value*)(pArray)); + return ret; +} + +std::unique_ptr JsonLoader::parse_string() { + expect('\"'); + + std::unique_ptr ret; + JsonLoader::StringValue* pStr = new JsonLoader::StringValue(); + + const char* p = m_buf; 
+ while (true) { + if (*p == '\"') { + p++; + break; + } else { + pStr->m_value += (*p); + p++; + } + } + m_buf = p; + ret.reset((JsonLoader::Value*)(pStr)); + return ret; +} + +std::unique_ptr JsonLoader::parse_number() { + const char* p = m_buf; + + auto loop_digit = [this](const char*& p) { + if (not std::isdigit(*p)) { + m_state = State::BAD_DIGIT; + return; + } + while (std::isdigit(*p)) { + p++; + } + return; + }; + + if (*p == '-') + p++; + if (*p == '0') + p++; + else { + loop_digit(std::ref(p)); + } + if (*p == '.') { + p++; + loop_digit(std::ref(p)); + } + + if (*p == 'e' || *p == 'E') { + p++; + if (*p == '+' || *p == '-') + p++; + loop_digit(std::ref(p)); + } + JsonLoader::NumberValue* pNum = new JsonLoader::NumberValue(); + pNum->m_value = strtod(m_buf, nullptr); + + m_buf = p; + + std::unique_ptr ret; + ret.reset((JsonLoader::Value*)(pNum)); + return ret; +} + +std::unique_ptr JsonLoader::parse_value() { + switch (*m_buf) { + case '[': + return parse_array(); + case '{': + return parse_object(); + case '\"': + return parse_string(); + case '\0': + m_state = State::BAD_TYPE; + break; + default: + return parse_number(); + } + return nullptr; +} + +std::unique_ptr JsonLoader::load( + const char* content, const size_t size) { + m_buf = content; + skip_whitespace(); + std::unique_ptr value = parse_value(); + skip_whitespace(); + + if (m_state != State::OK) { + return nullptr; + } + mgb_assert(size == static_cast(m_buf - content)); + + return value; +} + +std::unique_ptr JsonLoader::load(const char* path) { + std::unique_ptr fin( + std::fopen(path, "rb"), [](std::FILE* fp) { std::fclose(fp); }); + + mgb_assert(fin.get(), "failed to open %s: %s", path, strerror(errno)); + std::fseek(fin.get(), 0, SEEK_END); + const size_t size = ftell(fin.get()); + std::fseek(fin.get(), 0, SEEK_SET); + + std::unique_ptr buf(static_cast(malloc(size))); + + auto nr = std::fread(buf.get(), 1, size, fin.get()); + mgb_assert(nr == size); + + return load(buf.get(), size); +} diff --git a/lite/load_and_run/src/helpers/json_loader.h b/lite/load_and_run/src/helpers/json_loader.h new file mode 100644 index 000000000..f05b689ab --- /dev/null +++ b/lite/load_and_run/src/helpers/json_loader.h @@ -0,0 +1,183 @@ +/** + * \file lite/load_and_run/src/helpers/json_loader.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "megbrain/common.h" +#include "megdnn/thin/small_vector.h" + +namespace mgb { +/*! 
+ * \brief JSON format data loader for --input + */ +class JsonLoader { +public: + // base class for different value format + class Value { + protected: + enum struct Type : uint8_t { UNKNOWN, NUMBER, STRING, OBJECT, ARRAY }; + Type m_type; + + public: + template + T* safe_cast(); + + Value() { m_type = Type::UNKNOWN; } + + Value(Type type) : m_type(type) {} + + virtual ~Value() {} + + bool is_array() { return Type::ARRAY == m_type; } + + bool is_object() { return Type::OBJECT == m_type; } + + bool is_number() { return Type::NUMBER == m_type; } + + bool is_str() { return Type::STRING == m_type; } + + std::unique_ptr& operator[](const std::string& key); + + std::unique_ptr& operator[](const size_t index); + + std::map>& objects(); + + size_t len(); + + megdnn::SmallVector>& array(); + + double number(); + + std::string str(); + }; + + void expect(char c); + + void skip_whitespace(); + + std::unique_ptr parse_object(); + + std::unique_ptr parse_array(); + + std::unique_ptr parse_string(); + + std::unique_ptr parse_number(); + + std::unique_ptr parse_value(); + + enum struct State : uint8_t { + OK = 0, + BAD_TYPE, + BAD_DIGIT, + BAD_ARRAY, + MISS_COLON, + MISS_BRACE, + KEY_NOT_UNIQUE + }; + + JsonLoader() { m_state = State::OK; } + + std::unique_ptr load(const char* content, const size_t size); + + std::unique_ptr load(const char* path); + + class NumberValue final : public Value { + friend std::unique_ptr JsonLoader::parse_number(); + double m_value; + + public: + NumberValue() : Value(Type::NUMBER) {} + + double value() { return m_value; } + }; + + class StringValue final : public Value { + std::string m_value; + + public: + StringValue() : Value(Type::STRING) {} + + std::string value() { return m_value; } + + friend std::unique_ptr JsonLoader::parse_string(); + }; + + class ArrayValue final : public Value { + megdnn::SmallVector> m_obj; + + public: + ArrayValue() : Value(Type::ARRAY) {} + + ArrayValue(ArrayValue& arr) : Value(arr) { + m_obj.clear(); + for (auto& item : arr.m_obj) { + m_obj.emplace_back(item.get()); + item.release(); + } + } + + ArrayValue(ArrayValue&& arr) : Value(arr) { + m_obj.clear(); + for (auto& item : arr.m_obj) { + m_obj.emplace_back(item.get()); + item.release(); + } + } + + friend std::unique_ptr JsonLoader::parse_array(); + friend std::unique_ptr& JsonLoader::Value::operator[]( + const size_t index); + friend megdnn::SmallVector>& JsonLoader:: + Value::array(); + friend size_t JsonLoader::Value::len(); + }; + + class ObjectValue final : public Value { + std::map> m_obj; + + public: + ObjectValue() : Value(Type::OBJECT) {} + + ObjectValue(ObjectValue& arr) : Value(arr) { + m_obj.clear(); + for (auto itra = arr.m_obj.begin(); itra != arr.m_obj.end(); ++itra) { + m_obj.emplace(std::make_pair(itra->first, std::move(itra->second))); + } + } + + ObjectValue(ObjectValue&& arr) : Value(arr) { + m_obj.clear(); + for (auto itra = arr.m_obj.begin(); itra != arr.m_obj.end(); ++itra) { + m_obj.emplace(std::make_pair(itra->first, std::move(itra->second))); + } + } + + friend std::unique_ptr JsonLoader::parse_object(); + friend std::unique_ptr& JsonLoader::Value::operator[]( + const std::string&); + friend std::map>& JsonLoader:: + Value::objects(); + friend size_t JsonLoader::Value::len(); + }; + +private: + const char* m_buf; + State m_state; +}; + +} // namespace mgb diff --git a/lite/load_and_run/src/helpers/npy.h b/lite/load_and_run/src/helpers/npy.h new file mode 100644 index 000000000..afdef31c7 --- /dev/null +++ b/lite/load_and_run/src/helpers/npy.h @@ -0,0 +1,615 @@ 
+/* + Copyright 2017 Leon Merten Lohse + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in + all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. +*/ + +#ifndef NPY_H +#define NPY_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace npy { + +/* Compile-time test for byte order. + If your compiler does not define these per default, you may want to define + one of these constants manually. + Defaults to little endian order. */ +#if defined(__BYTE_ORDER) && __BYTE_ORDER == __BIG_ENDIAN || \ + defined(__BIG_ENDIAN__) || defined(__ARMEB__) || defined(__THUMBEB__) || \ + defined(__AARCH64EB__) || defined(_MIBSEB) || defined(__MIBSEB) || \ + defined(__MIBSEB__) +const bool big_endian = true; +#else +const bool big_endian = false; +#endif + +const char magic_string[] = "\x93NUMPY"; +const size_t magic_string_length = 6; + +const char little_endian_char = '<'; +const char big_endian_char = '>'; +const char no_endian_char = '|'; + +constexpr char host_endian_char = (big_endian ? 
big_endian_char : little_endian_char); + +/* npy array length */ +typedef unsigned long int ndarray_len_t; + +inline void write_magic( + std::ostream& ostream, unsigned char v_major = 1, unsigned char v_minor = 0) { + ostream.write(magic_string, magic_string_length); + ostream.put(v_major); + ostream.put(v_minor); +} + +inline void read_magic( + std::istream& istream, unsigned char& v_major, unsigned char& v_minor) { + char buf[magic_string_length + 2]; + istream.read(buf, magic_string_length + 2); + + if (!istream) { + fprintf(stderr, "io error: failed reading file"); + } + + if (0 != std::memcmp(buf, magic_string, magic_string_length)) { + fprintf(stderr, "this file does not have a valid npy format."); + } + + v_major = buf[magic_string_length]; + v_minor = buf[magic_string_length + 1]; +} + +// typestring magic +struct Typestring { +private: + char c_endian; + char c_type; + int len; + +public: + inline std::string str() { + const size_t max_buflen = 16; + char buf[max_buflen]; + std::sprintf(buf, "%c%c%u", c_endian, c_type, len); + return std::string(buf); + } + + Typestring(const std::vector&) + : c_endian{host_endian_char}, c_type{'f'}, len{sizeof(float)} {} + Typestring(const std::vector&) + : c_endian{host_endian_char}, c_type{'f'}, len{sizeof(double)} {} + Typestring(const std::vector&) + : c_endian{host_endian_char}, c_type{'f'}, len{sizeof(long double)} {} + + Typestring(const std::vector&) + : c_endian{no_endian_char}, c_type{'i'}, len{sizeof(char)} {} + Typestring(const std::vector&) + : c_endian{host_endian_char}, c_type{'i'}, len{sizeof(short)} {} + Typestring(const std::vector&) + : c_endian{host_endian_char}, c_type{'i'}, len{sizeof(int)} {} + Typestring(const std::vector&) + : c_endian{host_endian_char}, c_type{'i'}, len{sizeof(long)} {} + Typestring(const std::vector&) + : c_endian{host_endian_char}, c_type{'i'}, len{sizeof(long long)} {} + + Typestring(const std::vector&) + : c_endian{no_endian_char}, c_type{'u'}, len{sizeof(unsigned char)} {} + Typestring(const std::vector&) + : c_endian{host_endian_char}, c_type{'u'}, len{sizeof(unsigned short)} {} + Typestring(const std::vector&) + : c_endian{host_endian_char}, c_type{'u'}, len{sizeof(unsigned int)} {} + Typestring(const std::vector&) + : c_endian{host_endian_char}, c_type{'u'}, len{sizeof(unsigned long)} {} + Typestring(const std::vector&) + : c_endian{host_endian_char}, + c_type{'u'}, + len{sizeof(unsigned long long)} {} + + Typestring(const std::vector>&) + : c_endian{host_endian_char}, + c_type{'c'}, + len{sizeof(std::complex)} {} + Typestring(const std::vector>&) + : c_endian{host_endian_char}, + c_type{'c'}, + len{sizeof(std::complex)} {} + Typestring(const std::vector>&) + : c_endian{host_endian_char}, + c_type{'c'}, + len{sizeof(std::complex)} {} +}; + +inline void parse_typestring(std::string typestring) { + std::regex re("'([<>|])([ifuc])(\\d+)'"); + std::smatch sm; + + std::regex_match(typestring, sm, re); + + if (sm.size() != 4) { + fprintf(stderr, "invalid typestring"); + } +} + +namespace pyparse { + +/** + Removes leading and trailing whitespaces + */ +inline std::string trim(const std::string& str) { + const std::string whitespace = " \t"; + auto begin = str.find_first_not_of(whitespace); + + if (begin == std::string::npos) + return ""; + + auto end = str.find_last_not_of(whitespace); + + return str.substr(begin, end - begin + 1); +} + +inline std::string get_value_from_map(const std::string& mapstr) { + size_t sep_pos = mapstr.find_first_of(":"); + if (sep_pos == std::string::npos) + return ""; + + 
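    // [Editorial comment] e.g. for the entry "'shape': (2, 3)" this returns the
    // trimmed text after the first ':', i.e. "(2, 3)".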
std::string tmp = mapstr.substr(sep_pos + 1); + return trim(tmp); +} + +/** + Parses the string representation of a Python dict + + The keys need to be known and may not appear anywhere else in the data. + */ +inline std::unordered_map parse_dict( + std::string in, std::vector& keys) { + std::unordered_map map; + + if (keys.size() == 0) + return map; + + in = trim(in); + + // unwrap dictionary + if ((in.front() == '{') && (in.back() == '}')) + in = in.substr(1, in.length() - 2); + else { + fprintf(stderr, "Not a Python dictionary."); + } + + std::vector> positions; + + for (auto const& value : keys) { + size_t pos = in.find("'" + value + "'"); + + if (pos == std::string::npos) { + fprintf(stderr, "Missing %s key.", value.c_str()); + } + + std::pair position_pair{pos, value}; + positions.push_back(position_pair); + } + + // sort by position in dict + std::sort(positions.begin(), positions.end()); + + for (size_t i = 0; i < positions.size(); ++i) { + std::string raw_value; + size_t begin{positions[i].first}; + size_t end{std::string::npos}; + + std::string key = positions[i].second; + + if (i + 1 < positions.size()) + end = positions[i + 1].first; + + raw_value = in.substr(begin, end - begin); + + raw_value = trim(raw_value); + + if (raw_value.back() == ',') + raw_value.pop_back(); + + map[key] = get_value_from_map(raw_value); + } + + return map; +} + +/** + Parses the string representation of a Python boolean + */ +inline bool parse_bool(const std::string& in) { + if (in == "True") + return true; + if (in == "False") + return false; + + fprintf(stderr, "Invalid python boolan."); + return false; +} + +/** + Parses the string representation of a Python str + */ +inline std::string parse_str(const std::string& in) { + if ((in.front() == '\'') && (in.back() == '\'')) + return in.substr(1, in.length() - 2); + + fprintf(stderr, "Invalid python string."); + return ""; +} + +/** + Parses the string represenatation of a Python tuple into a vector of its items + */ +inline std::vector parse_tuple(std::string in) { + std::vector v; + const char seperator = ','; + + in = trim(in); + + if ((in.front() == '(') && (in.back() == ')')) + in = in.substr(1, in.length() - 2); + else { + fprintf(stderr, "Invalid Python tuple."); + } + + std::istringstream iss(in); + + for (std::string token; std::getline(iss, token, seperator);) { + v.push_back(token); + } + + return v; +} + +template +inline std::string write_tuple(const std::vector& v) { + if (v.size() == 0) + return ""; + + std::ostringstream ss; + + if (v.size() == 1) { + ss << "(" << v.front() << ",)"; + } else { + const std::string delimiter = ", "; + // v.size() > 1 + ss << "("; + std::copy( + v.begin(), v.end() - 1, + std::ostream_iterator(ss, delimiter.c_str())); + ss << v.back(); + ss << ")"; + } + + return ss.str(); +} + +inline std::string write_boolean(bool b) { + if (b) + return "True"; + else + return "False"; +} + +} // namespace pyparse + +inline void parse_header(std::string header, std::string& descr) { + /* + The first 6 bytes are a magic string: exactly "x93NUMPY". + The next 1 byte is an unsigned byte: the major version number of the file + format, e.g. x01. The next 1 byte is an unsigned byte: the minor version + number of the file format, e.g. x00. Note: the version of the file format + is not tied to the version of the numpy package. The next 2 bytes form a + little-endian unsigned short int: the length of the header data + HEADER_LEN. The next HEADER_LEN bytes form the header data describing the + array's format. 
It is an ASCII string which contains a Python literal + expression of a dictionary. It is terminated by a newline ('n') and + padded with spaces + ('x20') to make the total length of the magic string + 4 + HEADER_LEN be + evenly divisible by 16 for alignment purposes. The dictionary contains + three keys: + + "descr" : dtype.descr + An object that can be passed as an argument to the numpy.dtype() + constructor to create the array's dtype. For repeatability and + readability, this dictionary is formatted using pprint.pformat() so the + keys are in alphabetic order. + */ + + // remove trailing newline + if (header.back() != '\n') + fprintf(stderr, "invalid header"); + header.pop_back(); + + // parse the dictionary + std::vector keys{"descr"}; + auto dict_map = npy::pyparse::parse_dict(header, keys); + + if (dict_map.size() == 0) + fprintf(stderr, "invalid dictionary in header"); + + std::string descr_s = dict_map["descr"]; + parse_typestring(descr_s); + // remove + descr = npy::pyparse::parse_str(descr_s); + return; +} + +inline void parse_header( + std::string header, std::string& descr, bool& fortran_order, + std::vector& shape) { + /* + The first 6 bytes are a magic string: exactly "x93NUMPY". + The next 1 byte is an unsigned byte: the major version number of the file + format, e.g. x01. The next 1 byte is an unsigned byte: the minor version + number of the file format, e.g. x00. Note: the version of the file format + is not tied to the version of the numpy package. The next 2 bytes form a + little-endian unsigned short int: the length of the header data + HEADER_LEN. The next HEADER_LEN bytes form the header data describing the + array's format. It is an ASCII string which contains a Python literal + expression of a dictionary. It is terminated by a newline ('n') and + padded with spaces + ('x20') to make the total length of the magic string + 4 + HEADER_LEN be + evenly divisible by 16 for alignment purposes. The dictionary contains + three keys: + + "descr" : dtype.descr + An object that can be passed as an argument to the numpy.dtype() + constructor to create the array's dtype. "fortran_order" : bool Whether + the array data is Fortran-contiguous or not. Since Fortran-contiguous + arrays are a common form of non-C-contiguity, we allow them to be written + directly to disk for efficiency. "shape" : tuple of int The shape of the + array. For repeatability and readability, this dictionary is formatted + using pprint.pformat() so the keys are in alphabetic order. 
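    An illustrative header dictionary (editor-added example) for a C-ordered,
    little-endian float32 array of shape (3, 4):
        {'descr': '<f4', 'fortran_order': False, 'shape': (3, 4), }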
+ */ + + // remove trailing newline + if (header.back() != '\n') + fprintf(stderr, "invalid header"); + header.pop_back(); + + // parse the dictionary + std::vector keys{"descr", "fortran_order", "shape"}; + auto dict_map = npy::pyparse::parse_dict(header, keys); + + if (dict_map.size() == 0) + fprintf(stderr, "invalid dictionary in header"); + + std::string descr_s = dict_map["descr"]; + std::string fortran_s = dict_map["fortran_order"]; + std::string shape_s = dict_map["shape"]; + + // TODO: extract info from typestring + parse_typestring(descr_s); + // remove + descr = npy::pyparse::parse_str(descr_s); + + // convert literal Python bool to C++ bool + fortran_order = npy::pyparse::parse_bool(fortran_s); + + // parse the shape tuple + auto shape_v = npy::pyparse::parse_tuple(shape_s); + if (shape_v.size() == 0) + fprintf(stderr, "invalid shape tuple in header"); + + for (auto item : shape_v) { + ndarray_len_t dim = static_cast(std::stoul(item)); + shape.push_back(dim); + } +} + +inline std::string write_header_dict( + const std::string& descr, bool fortran_order, + const std::vector& shape) { + std::string s_fortran_order = npy::pyparse::write_boolean(fortran_order); + std::string shape_s = npy::pyparse::write_tuple(shape); + + return "{'descr': '" + descr + "', 'fortran_order': " + s_fortran_order + + ", 'shape': " + shape_s + ", }"; +} + +inline void write_header( + std::ostream& out, const std::string& descr, bool fortran_order, + const std::vector& shape_v) { + std::string header_dict = write_header_dict(descr, fortran_order, shape_v); + + size_t length = magic_string_length + 2 + 2 + header_dict.length() + 1; + + unsigned char version[2] = {1, 0}; + if (length >= 255 * 255) { + length = magic_string_length + 2 + 4 + header_dict.length() + 1; + version[0] = 2; + version[1] = 0; + } + size_t padding_len = 16 - length % 16; + std::string padding(padding_len, ' '); + + // write magic + write_magic(out, version[0], version[1]); + + // write header length + if (version[0] == 1 && version[1] == 0) { + char header_len_le16[2]; + uint16_t header_len = + static_cast(header_dict.length() + padding.length() + 1); + + header_len_le16[0] = (header_len >> 0) & 0xff; + header_len_le16[1] = (header_len >> 8) & 0xff; + out.write(reinterpret_cast(header_len_le16), 2); + } else { + char header_len_le32[4]; + uint32_t header_len = + static_cast(header_dict.length() + padding.length() + 1); + + header_len_le32[0] = (header_len >> 0) & 0xff; + header_len_le32[1] = (header_len >> 8) & 0xff; + header_len_le32[2] = (header_len >> 16) & 0xff; + header_len_le32[3] = (header_len >> 24) & 0xff; + out.write(reinterpret_cast(header_len_le32), 4); + } + + out << header_dict << padding << '\n'; +} + +inline std::string read_header(std::istream& istream) { + // check magic bytes an version number + unsigned char v_major, v_minor; + read_magic(istream, v_major, v_minor); + + uint32_t header_length = 0; + if (v_major == 1 && v_minor == 0) { + char header_len_le16[2]; + istream.read(header_len_le16, 2); + header_length = (header_len_le16[0] << 0) | (header_len_le16[1] << 8); + + if ((magic_string_length + 2 + 2 + header_length) % 16 != 0) { + // TODO: display warning + } + } else if (v_major == 2 && v_minor == 0) { + char header_len_le32[4]; + istream.read(header_len_le32, 4); + + header_length = (header_len_le32[0] << 0) | (header_len_le32[1] << 8) | + (header_len_le32[2] << 16) | (header_len_le32[3] << 24); + + if ((magic_string_length + 2 + 4 + header_length) % 16 != 0) { + // TODO: display warning + } + } else { + 
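        // [Editorial comment] only .npy format versions 1.0 and 2.0 are parsed
        // above; any other version falls through to this error branch.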
fprintf(stderr, "unsupported file format version"); + } + + auto buf_v = std::vector(); + buf_v.reserve(header_length); + istream.read(buf_v.data(), header_length); + std::string header(buf_v.data(), header_length); + + return header; +} + +inline ndarray_len_t comp_size(const std::vector& shape) { + ndarray_len_t size = 1; + for (ndarray_len_t i : shape) + size *= i; + + return size; +} + +template +inline void SaveArrayAsNumpy( + const std::string& filename, bool fortran_order, unsigned int n_dims, + const unsigned long shape[], const std::vector& data) { + Typestring typestring_o(data); + std::string typestring = typestring_o.str(); + + std::ofstream stream(filename, std::ofstream::binary); + if (!stream) { + fprintf(stderr, "io error: failed to open a file."); + } + + std::vector shape_v(shape, shape + n_dims); + write_header(stream, typestring, fortran_order, shape_v); + + auto size = static_cast(comp_size(shape_v)); + + stream.write(reinterpret_cast(data.data()), sizeof(Scalar) * size); +} + +template +inline void LoadArrayFromNumpy( + const std::string& filename, std::vector& shape, + std::vector& data) { + bool fortran_order; + LoadArrayFromNumpy(filename, shape, fortran_order, data); +} + +template +inline void LoadArrayFromNumpy( + const std::string& filename, std::vector& shape, + bool& fortran_order, std::vector& data) { + std::ifstream stream(filename, std::ifstream::binary); + if (!stream) { + fprintf(stderr, "io error: failed to open a file."); + } + + std::string header = read_header(stream); + + // parse header + std::string typestr; + + parse_header(header, typestr, fortran_order, shape); + + // check if the typestring matches the given one + Typestring typestring_o{data}; + std::string expect_typestr = typestring_o.str(); + if (typestr != expect_typestr) { + fprintf(stderr, "formatting error: typestrings not matching"); + } + + // compute the data size based on the shape + auto size = static_cast(comp_size(shape)); + data.resize(size); + + // read the data + stream.read(reinterpret_cast(data.data()), sizeof(Scalar) * size); +} + +inline void LoadArrayFromNumpy( + const std::string& filename, std::string& type_str, + std::vector& shape, std::vector& data) { + std::ifstream stream(filename, std::ifstream::binary); + if (!stream) { + fprintf(stderr, "io error: failed to open a file."); + } + + std::string header = read_header(stream); + bool fortran_order; + // parse header + parse_header(header, type_str, fortran_order, shape); + + // check if the typestring matches the given one + std::string size_str = type_str.substr(type_str.size() - 1); + size_t elem_size = atoi(size_str.c_str()); + + // compute the data size based on the shape + auto byte_size = elem_size * static_cast(comp_size(shape)); + data.resize(byte_size); + + // read the data + stream.read(reinterpret_cast(data.data()), byte_size); +} + +} // namespace npy + +#endif // NPY_H diff --git a/lite/load_and_run/src/helpers/outdumper.cpp b/lite/load_and_run/src/helpers/outdumper.cpp new file mode 100644 index 000000000..9a5d8315b --- /dev/null +++ b/lite/load_and_run/src/helpers/outdumper.cpp @@ -0,0 +1,48 @@ +/** + * \file lite/load_and_run/src/helpers/outdumper.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
+ */ + +#include "outdumper.h" +#include "megbrain/utils/debug.h" + +using namespace lar; + +void OutputDumper::set(mgb::SymbolVarArray& symb_var) { + for (auto&& i : symb_var) { + auto&& var = i.node(); + DumpInfo info; + info.var_info = mgb::cg::dump_var_info({var}); + info.owner_inputs_info = mgb::cg::dump_var_info(var->owner_opr()->input()); + info.id = var->id(); + m_infos.push_back(info); + } +} + +mgb::ComputingGraph::Callback OutputDumper::bind() { + auto& info = m_infos.at(m_bind_id++); + mgb::ComputingGraph::Callback cb = [&info](const mgb::DeviceTensorND& dv) { + info.hv.copy_from(dv); + }; + return cb; +} + +void OutputDumper::write_to_file() { + if (!dump_file.empty()) { + for (auto&& info : m_infos) { + auto value = mgb::debug::dump_tensor( + info.hv, + mgb::ssprintf( + "var=%s owner_opr_inputs= %s", info.var_info.c_str(), + info.owner_inputs_info.c_str())); + mgb::debug::write_to_file( + mgb::ssprintf( + "%s/run%zu-var %zd", dump_file.c_str(), m_run_id, info.id) + .c_str(), + value); + } + } + m_run_id++; +} diff --git a/lite/load_and_run/src/helpers/outdumper.h b/lite/load_and_run/src/helpers/outdumper.h new file mode 100644 index 000000000..d08496b4f --- /dev/null +++ b/lite/load_and_run/src/helpers/outdumper.h @@ -0,0 +1,42 @@ +/** + * \file lite/load_and_run/src/helpers/outdumper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + */ + +#pragma once +#include "megbrain/serialization/serializer.h" + +namespace lar { + +/*! + * \brief dumper for only output used for --bin-out-dump + */ +class OutputDumper { +public: + struct DumpInfo { + mgb::HostTensorND hv = {}; + std::string var_info; + std::string owner_inputs_info; + size_t id; + }; + //! init the dump_file path + OutputDumper(const char* file) { dump_file = file; } + + //! set the dump informations + void set(mgb::SymbolVarArray& symb_var); + + //! callback function for specify output when compile computing graph + mgb::ComputingGraph::Callback bind(); + + //! write dumped output into dump_file + void write_to_file(); + +private: + mgb::SmallVector m_infos; + size_t m_run_id = 0; + size_t m_bind_id = 0; + std::string dump_file; +}; +} // namespace lar \ No newline at end of file diff --git a/lite/load_and_run/src/helpers/text_table.cpp b/lite/load_and_run/src/helpers/text_table.cpp new file mode 100644 index 000000000..86331abd1 --- /dev/null +++ b/lite/load_and_run/src/helpers/text_table.cpp @@ -0,0 +1,119 @@ +/** + * \file lite/load_and_run/src/helpers/text_table.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "text_table.h" + +using namespace mgb; + +namespace { +inline void mid(std::ostream& os, const std::string& str, size_t max_w) { + size_t l = (max_w - str.length()) / 2 + str.length(); + size_t r = max_w - l; + os << std::setw(l) << std::right << str; + if (r > 0) + os << std::setw(r) << ' '; +} +inline size_t char_length(char c) { + return c ? 
1 : 0; +} +} // namespace + +void TextTable::adjuster_last_row() { + if (m_rows.empty()) + return; + auto& row = m_rows.back(); + if (row.params.horizontal == 0 or row.params.vertical == 0) { + row.params.corner = 0; + } + if (row.params.horizontal != 0 && row.params.vertical != 0 && + row.params.corner == 0) { + row.params.corner = row.params.horizontal; + } +} + +void TextTable::show(std::ostream& os) { + if (m_rows.empty()) + return; + auto& last_row = m_rows.front(); + bool first = true; + for (auto& row : m_rows) { + auto& lrow = + (last_row.values.size() * char_length(last_row.params.horizontal)) > + (row.values.size() * char_length(row.params.horizontal)) + ? last_row + : row; + // line before row + if (lrow.params.horizontal) { + if (not first) + os << std::endl; + os << m_prefix; + if (lrow.params.corner) + os << lrow.params.corner; + size_t skip_size = 0; + // table name + if (first) { + os << m_name; + skip_size = m_name.length(); + } + for (size_t i = 0; i < lrow.values.size(); ++i) { + auto max_w = m_cols_max_w.at(i) + m_padding * 2; + if (max_w + char_length(lrow.params.corner) <= skip_size) { + skip_size = skip_size - max_w - char_length(lrow.params.corner); + continue; + } + size_t rest = max_w + char_length(lrow.params.corner) - skip_size; + skip_size = 0; + if (rest > char_length(lrow.params.corner)) { + os << std::string( + rest - char_length(lrow.params.corner), + lrow.params.horizontal); + rest = char_length(lrow.params.corner); + } + if (rest > 0 && lrow.params.corner) + os << lrow.params.corner; + } + } else if (first) { + os << m_prefix << ' ' << m_name; + } + first = false; + os << std::endl << m_prefix; + if (row.params.vertical) + os << row.params.vertical; + // row + for (size_t i = 0; i < row.values.size(); ++i) { + auto& str = row.values.at(i); + auto max_w = m_cols_max_w.at(i) + 2 * m_padding; + if (row.params.align == Align::Mid) { + mid(os, str, max_w); + } else if (row.params.align == Align::Left) { + os << std::setw(max_w) << std::left << str; + } else { + os << std::setw(max_w) << std::right << str; + } + if (row.params.vertical) + os << row.params.vertical; + } + last_row = row; + } + if (last_row.params.horizontal) { + os << std::endl << m_prefix; + if (last_row.params.corner) + os << last_row.params.corner; + for (size_t i = 0; i < last_row.values.size(); ++i) { + auto max_w = m_cols_max_w.at(i); + std::string tmp(max_w + m_padding * 2, last_row.params.horizontal); + os << tmp; + if (last_row.params.corner) + os << last_row.params.corner; + } + } +} \ No newline at end of file diff --git a/lite/load_and_run/src/helpers/text_table.h b/lite/load_and_run/src/helpers/text_table.h new file mode 100644 index 000000000..25a3074e6 --- /dev/null +++ b/lite/load_and_run/src/helpers/text_table.h @@ -0,0 +1,133 @@ +/** + * \file lite/load_and_run/src/helpers/text_table.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
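 *
 * [Editorial usage sketch for the TextTable class declared below; the table
 *  name and cell contents are made up]
 *
 *     mgb::TextTable table("bench");
 *     table.padding(1).align(mgb::TextTable::Align::Mid);
 *     table.add("layer").add("time(ms)").eor();
 *     table.add("conv1").add(0.25).eor();
 *     std::cout << table << std::endl;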
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "megbrain/common.h" + +namespace mgb { + +class TextTable { +public: + enum Level { Summary, Detail }; + enum class Align : int { Left, Right, Mid }; + explicit TextTable(const std::string& table_name) : m_name(table_name) {} + TextTable& horizontal(char c) { + m_row.params.horizontal = c; + return *this; + } + TextTable& vertical(char c) { + m_row.params.vertical = c; + return *this; + } + TextTable& corner(char c) { + m_row.params.corner = c; + return *this; + } + TextTable& align(Align v) { + m_row.params.align = v; + return *this; + } + TextTable& padding(size_t w) { + m_padding = w; + return *this; + } + TextTable& prefix(const std::string& str) { + m_prefix = str; + return *this; + } + + template + TextTable& add(const T& value) { + m_row.values.emplace_back(value); + if (m_cols_max_w.size() < m_row.values.size()) { + m_cols_max_w.emplace_back(m_row.values.back().length()); + } else { + mgb_assert(m_row.values.size() >= 1); + size_t i = m_row.values.size() - 1; + m_cols_max_w[i] = std::max(m_cols_max_w[i], m_row.values.back().length()); + } + return *this; + } + + template < + typename T, + typename std::enable_if::value, bool>::type = 0> + TextTable& add(const T& value) { + std::stringstream ss; + ss << std::setiosflags(std::ios::fixed) << std::setprecision(2); + ss << value; + m_row.values.emplace_back(ss.str()); + if (m_cols_max_w.size() < m_row.values.size()) { + m_cols_max_w.emplace_back(m_row.values.back().length()); + } else { + mgb_assert(m_row.values.size() >= 1); + size_t i = m_row.values.size() - 1; + m_cols_max_w[i] = std::max(m_cols_max_w[i], m_row.values.back().length()); + } + return *this; + } + + template < + typename T, + typename std::enable_if::value, bool>::type = 0> + TextTable& add(const T& value) { + m_row.values.emplace_back(std::to_string(value)); + return *this; + } + + void eor() { + m_rows.emplace_back(m_row); + adjuster_last_row(); + m_row.values.clear(); + } + + void reset() { + m_row = {}; + m_cols_max_w.clear(); + m_padding = 0; + m_rows.clear(); + } + + void show(std::ostream& os); + +private: + void adjuster_last_row(); + std::string m_name; + std::vector m_cols_max_w; + size_t m_padding = 0; + std::string m_prefix = ""; + struct Row { + std::vector values; + struct Params { + Align align = Align::Left; + char horizontal = '-', vertical = '|', corner = '+'; + } params; + }; + std::vector m_rows; + Row m_row; +}; + +inline std::ostream& operator<<(std::ostream& stream, TextTable& table) { + table.show(stream); + return stream; +} + +} // namespace mgb \ No newline at end of file diff --git a/lite/load_and_run/src/main.cpp b/lite/load_and_run/src/main.cpp new file mode 100644 index 000000000..512429757 --- /dev/null +++ b/lite/load_and_run/src/main.cpp @@ -0,0 +1,31 @@ +/** + * \file lite/load_and_run/src/main.cpp + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. 
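 *
 * [Editorial note] An illustrative invocation once built (the model path and
 * flag choice are made up; both flags are defined elsewhere in this patch):
 *
 *     load_and_run ./model.mge --cpu --fast_run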
+ */ + +#include +#include +#include "strategys/strategy.h" + +int main(int argc, char** argv) { + std::string usage = "load_and_run [options...]"; + if (argc < 2) { + printf("usage: %s\n", usage.c_str()); + return -1; + } + gflags::SetUsageMessage(usage); + gflags::SetVersionString("1.0"); + gflags::ParseCommandLineFlags(&argc, &argv, true); + std::string model_path = argv[1]; + auto strategy = lar::StrategyBase::create_strategy(model_path); + strategy->run(); + gflags::ShutDownCommandLineFlags(); + + return 0; +} + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/lite/load_and_run/src/models/model.cpp b/lite/load_and_run/src/models/model.cpp new file mode 100644 index 000000000..22d85d290 --- /dev/null +++ b/lite/load_and_run/src/models/model.cpp @@ -0,0 +1,60 @@ + +/** + * \file lite/load_and_run/src/models/model.cpp + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ +#include "model.h" +#include +#include +#include "model_lite.h" +#include "model_mdl.h" + +using namespace lar; + +ModelType ModelBase::get_model_type(std::string model_path) { + //! read magic number of dump file + FILE* fin = fopen(model_path.c_str(), "rb"); + mgb_assert(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno)); + char buf[16]; + mgb_assert(fread(buf, 1, 16, fin) == 16, "read model failed"); + fclose(fin); + + // get model type + // uint32_t MGB_MAGIC = 0x5342474D + std::string tag(buf); + ModelType type; + if (tag.substr(0, 7) == std::string("mgb0001") || + tag.substr(0, 8) == std::string("mgb0000a") || + tag.substr(0, 4) == std::string("MGBS") || + tag.substr(0, 8) == std::string("mgbtest0")) { + type = ModelType::MEGDL_MODEL; + + } else { + type = ModelType::LITE_MODEL; + } + + return type; +} + +std::shared_ptr ModelBase::create_model(std::string model_path) { + mgb_log_debug("model path %s\n", model_path.c_str()); + + auto model_type = get_model_type(model_path); + + if (ModelType::LITE_MODEL == model_type) { + return std::make_shared(model_path); + } else if (ModelType::MEGDL_MODEL == model_type) { + if (FLAGS_lite) + return std::make_shared(model_path); + else + return std::make_shared(model_path); + } else { + return nullptr; + } +} +DEFINE_bool(lite, false, "using lite model to run mdl model"); +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/lite/load_and_run/src/models/model.h b/lite/load_and_run/src/models/model.h new file mode 100644 index 000000000..240574f2f --- /dev/null +++ b/lite/load_and_run/src/models/model.h @@ -0,0 +1,49 @@ +/** + * \file lite/load_and_run/src/models/model.h + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ + +#pragma once +#include +#include +#include "helpers/common.h" + +DECLARE_bool(lite); + +namespace lar { +/*! + * \brief: base class of model + */ +class ModelBase { +public: + //! get model type by the magic number in dump file + static ModelType get_model_type(std::string model_path); + + //! create model by different model type + static std::shared_ptr create_model(std::string model_path); + + //! type of the model + virtual ModelType type() = 0; + + //! set model load state + + virtual void set_shared_mem(bool state) = 0; + + //! load model interface for load and run strategy + virtual void load_model() = 0; + + //! 
run model interface for load and run strategy + virtual void run_model() = 0; + + //! wait asynchronous function interface for load and run strategy + virtual void wait() = 0; + + virtual ~ModelBase() = default; +}; +} // namespace lar + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/lite/load_and_run/src/models/model_lite.cpp b/lite/load_and_run/src/models/model_lite.cpp new file mode 100644 index 000000000..2cdf38341 --- /dev/null +++ b/lite/load_and_run/src/models/model_lite.cpp @@ -0,0 +1,50 @@ +/** + * \file lite/load_and_run/src/models/model_lite.cpp + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ +#include "model_lite.h" +#include +#include +#include "misc.h" + +DECLARE_bool(share_param_mem); + +using namespace lar; +ModelLite::ModelLite(const std::string& path) : model_path(path) { + LITE_WARN("creat lite model use CPU as default comp node"); +}; +void ModelLite::load_model() { + m_network = std::make_shared(config, IO); + if (share_model_mem) { + //! WARNNING:maybe not right to share param memmory for this + LITE_WARN("enable share model memory"); + + FILE* fin = fopen(model_path.c_str(), "rb"); + LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno)); + fseek(fin, 0, SEEK_END); + size_t size = ftell(fin); + fseek(fin, 0, SEEK_SET); + + void* ptr = malloc(size); + std::shared_ptr buf{ptr, free}; + auto nr = fread(buf.get(), 1, size, fin); + LITE_ASSERT(nr == size, "read model file failed"); + fclose(fin); + + m_network->load_model(buf.get(), size); + } else { + m_network->load_model(model_path); + } +} + +void ModelLite::run_model() { + m_network->forward(); +} + +void ModelLite::wait() { + m_network->wait(); +} diff --git a/lite/load_and_run/src/models/model_lite.h b/lite/load_and_run/src/models/model_lite.h new file mode 100644 index 000000000..66e7aa983 --- /dev/null +++ b/lite/load_and_run/src/models/model_lite.h @@ -0,0 +1,73 @@ +/** + * \file lite/load_and_run/src/models/model_lite.h + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ + +#pragma once + +#include +#include "helpers/common.h" +#include "helpers/data_parser.h" +#include "lite/network.h" +#include "model.h" + +namespace lar { +/*! + * \brief: megengine lite model + */ +class ModelLite : public ModelBase { +public: + using Strategy = LiteAlgoSelectStrategy; + + ModelLite(const std::string& path); + //! model type + ModelType type() override { return ModelType::LITE_MODEL; } + + //! set to load from shared memory + void set_shared_mem(bool state) override { share_model_mem = state; } + + //! load model from dump file + void load_model() override; + + //! run model with given runtime parameter + void run_model() override; + + //! wait the end of asynchronous function execution + void wait() override; + + //! get the network of lite model + std::shared_ptr get_lite_network() { return m_network; } + + //! get the config of lite model + lite::Config& get_config() { return config; } + + //! get the networkIO of lite model + lite::NetworkIO& get_networkIO() { return IO; } + + //! get the data parser + DataParser& get_input_parser() { return parser; } + + //! set the strategy before load model + void set_lite_strategy(Strategy& u_strategy) { m_strategy = u_strategy; } + + //! 
get algo strategy + Strategy& get_lite_strategy() { return m_strategy; } + +private: + bool share_model_mem; + std::string model_path; + + DataParser parser; + lite::Config config; + lite::NetworkIO IO; + + std::shared_ptr m_network; + + Strategy m_strategy; +}; +} // namespace lar +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/lite/load_and_run/src/models/model_mdl.cpp b/lite/load_and_run/src/models/model_mdl.cpp new file mode 100644 index 000000000..63fa6d732 --- /dev/null +++ b/lite/load_and_run/src/models/model_mdl.cpp @@ -0,0 +1,105 @@ +/** + * \file lite/load_and_run/src/models/model_mdl.cpp + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ + +#include "model_mdl.h" +#include +#include + +DECLARE_bool(share_param_mem); + +using namespace lar; + +ModelMdl::ModelMdl(const std::string& path) : model_path(path) { + mgb_log_warn("creat mdl model use XPU as default comp node"); + m_load_config.comp_graph = mgb::ComputingGraph::make(); + m_load_config.comp_graph->options().graph_opt_level = 0; + testcase_num = 0; +} + +void ModelMdl::load_model() { + //! read dump file + if (share_model_mem) { + mgb_log_warn("enable share model memory"); + FILE* fin = fopen(model_path.c_str(), "rb"); + mgb_assert(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno)); + fseek(fin, 0, SEEK_END); + size_t size = ftell(fin); + fseek(fin, 0, SEEK_SET); + + void* ptr = malloc(size); + std::shared_ptr buf{ptr, free}; + auto nr = fread(buf.get(), 1, size, fin); + mgb_assert(nr == size, "read model file failed"); + fclose(fin); + + m_model_file = mgb::serialization::InputFile::make_mem_proxy(buf, size); + } else { + m_model_file = mgb::serialization::InputFile::make_fs(model_path.c_str()); + } + + //! get dump_with_testcase model testcase number + char magic[8]; + m_model_file->read(magic, sizeof(magic)); + if (strncmp(magic, "mgbtest0", 8)) { + m_model_file->rewind(); + } else { + m_model_file->read(&testcase_num, sizeof(testcase_num)); + } + + auto format = + mgb::serialization::GraphLoader::identify_graph_dump_format(*m_model_file); + mgb_assert( + format.valid(), + "invalid format, please make sure model is dumped by GraphDumper"); + + //! 
load computing graph of model + m_loader = mgb::serialization::GraphLoader::make( + std::move(m_model_file), format.val()); + m_load_result = m_loader->load(m_load_config, false); + m_load_config.comp_graph.reset(); + + // get testcase input generated by dump_with_testcase.py + if (testcase_num) { + for (auto&& i : m_load_result.tensor_map) { + test_input_tensors.emplace_back(i.first, i.second.get()); + } + std::sort(test_input_tensors.begin(), test_input_tensors.end()); + } + // initialize output callback + for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) { + mgb::ComputingGraph::Callback cb; + m_callbacks.push_back(cb); + } +} + +void ModelMdl::make_output_spec() { + for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) { + auto item = m_load_result.output_var_list[i]; + m_output_spec.emplace_back(item, std::move(m_callbacks[i])); + } + + m_asyc_exec = m_load_result.graph_compile(m_output_spec); +} + +std::shared_ptr& ModelMdl::reset_loader() { + m_loader = mgb::serialization::GraphLoader::make( + m_loader->reset_file(), m_loader->format()); + return m_loader; +} + +void ModelMdl::run_model() { + mgb_assert( + m_asyc_exec != nullptr, + "empty asychronous function to execute after graph compiled"); + m_asyc_exec->execute(); +} + +void ModelMdl::wait() { + m_asyc_exec->wait(); +} diff --git a/lite/load_and_run/src/models/model_mdl.h b/lite/load_and_run/src/models/model_mdl.h new file mode 100644 index 000000000..2d7923b4d --- /dev/null +++ b/lite/load_and_run/src/models/model_mdl.h @@ -0,0 +1,117 @@ +/** + * \file lite/load_and_run/src/models/model_mdl.h + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ + +#pragma once +#include +#include "megbrain/opr/search_policy/algo_chooser_helper.h" +#include "megbrain/plugin/opr_io_dump.h" +#include "megbrain/serialization/extern_c_opr.h" +#include "megbrain/serialization/serializer.h" +#include "megbrain/utils/debug.h" + +#include "megbrain/plugin/num_range_checker.h" +#include "megbrain/plugin/profiler.h" + +#include "helpers/common.h" +#include "helpers/data_parser.h" +#include "model.h" + +namespace lar { + +class ModelMdl : public ModelBase { +public: + using Strategy = mgb::opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; + //! interface implement of ModelBase + ModelMdl(const std::string& path); + + ModelType type() override { return ModelType::MEGDL_MODEL; } + + void set_shared_mem(bool state) override { share_model_mem = state; } + + void load_model() override; + + void make_output_spec(); + + void run_model() override; + + void wait() override; + + //! get load result for megDL model + mgb::serialization::GraphLoader::LoadResult& get_mdl_load_result() { + return m_load_result; + } + + //! get load config for megDL model + mgb::serialization::GraphLoadConfig& get_mdl_config() { return m_load_config; } + + //! reset the graph loader for dump_with_testcase model + std::shared_ptr& reset_loader(); + + //! algo strategy for runing model + void set_mdl_strategy(Strategy& u_strategy) { m_strategy = u_strategy; } + Strategy& get_mdl_strategy() { return m_strategy; } + + //! get data parser + DataParser& get_input_parser() { return parser; } + uint32_t get_testcase_num() { return testcase_num; } + std::vector>& get_test_input() { + return test_input_tensors; + } + + //! 
get output specified configuration + mgb::ComputingGraph::OutputSpec& get_output_spec() { return m_output_spec; } + std::unique_ptr& get_async_func() { return m_asyc_exec; } + + void set_output_callback(std::vector& cb) { + mgb_assert( + m_callbacks.size() == cb.size(), + "invalid output callback list to set!!"); + for (size_t i = 0; i < cb.size(); i++) { + m_callbacks[i] = cb[i]; + } + } +#if MGB_ENABLE_JSON + std::unique_ptr& get_profiler() { return m_profiler; } + void set_profiler() { + m_profiler = + std::make_unique(m_load_config.comp_graph.get()); + } +#endif + void set_num_range_checker(float range) { + m_num_range_checker = std::make_unique( + m_load_config.comp_graph.get(), range); + } + +private: + bool share_model_mem; + std::string model_path; + std::unique_ptr m_model_file; + mgb::serialization::GraphLoadConfig m_load_config; + + mgb::serialization::GraphLoader::LoadResult m_load_result; + std::shared_ptr m_loader; + std::unique_ptr m_asyc_exec; + + uint32_t testcase_num; + std::vector> test_input_tensors; + + DataParser parser; + Strategy m_strategy = Strategy::HEURISTIC; + std::vector m_callbacks; + mgb::ComputingGraph::OutputSpec m_output_spec; + + std::unique_ptr m_num_range_checker; +#if MGB_ENABLE_JSON + std::unique_ptr m_profiler; +#endif +}; + +} // namespace lar + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/lite/load_and_run/src/options/device_options.cpp b/lite/load_and_run/src/options/device_options.cpp new file mode 100644 index 000000000..3365d8bc2 --- /dev/null +++ b/lite/load_and_run/src/options/device_options.cpp @@ -0,0 +1,200 @@ +/** + * \file lite/load_and_run/src/options/device_options.cpp + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. 
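 *
 * [Editorial note] This translation unit maps the device-selection flags
 * (--cpu, --cpu_default, --multithread, --multithread_default,
 * --multi_thread_core_ids and, in CUDA builds, --cuda) onto both the lite and
 * the MDL model back ends.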
+ */ + +#include +#include +#include "lite/global.h" +#include "megbrain/comp_node_env.h" +#include "misc.h" +#include "device_options.h" +#include "models/model_lite.h" +#include "models/model_mdl.h" + +DECLARE_bool(weight_preprocess); + +using namespace lar; + +/////////////////// XPUDeviceOption ////////////////////// +namespace lar { +template <> +void XPUDeviceOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + if ((enable_cpu) || (enable_cpu_default) || (enable_multithread) || + (enable_multithread_default)) { + LITE_WARN("using cpu device\n"); + model->get_config().device_type = LiteDeviceType::LITE_CPU; + } +#if MGE_WITH_CUDA + if (enable_cuda) { + model->get_config().device_type = LiteDeviceType::LITE_CUDA; + } +#endif + } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { + auto network = model->get_lite_network(); + if (enable_cpu_default) { + LITE_WARN("using cpu default device\n"); + lite::Runtime::set_cpu_inplace_mode(network); + } + if (enable_multithread) { + LITE_WARN("using multithread device\n"); + lite::Runtime::set_cpu_threads_number(network, thread_num); + } + if (enable_multithread_default) { + LITE_WARN("using multithread default device\n"); + lite::Runtime::set_cpu_inplace_mode(network); + lite::Runtime::set_cpu_threads_number(network, thread_num); + } + if (enable_set_core_ids) { + std::string core_str; + for (auto id : core_ids) { + core_str += std::to_string(id) + ","; + } + LITE_WARN("multi thread core ids: %s\n", core_str.c_str()); + lite::ThreadAffinityCallback affinity_callback = [&](size_t thread_id) { + mgb::sys::set_cpu_affinity({core_ids[thread_id]}); + }; + lite::Runtime::set_runtime_thread_affinity(network, affinity_callback); + } + } +} + +template <> +void XPUDeviceOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + if (enable_cpu) { + mgb_log_warn("using cpu device\n"); + model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { + loc.type = mgb::CompNode::DeviceType::CPU; + }; + } +#if MGE_WITH_CUDA + if (enable_cuda) { + mgb_log_warn("using cuda device\n"); + model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { + loc.type = mgb::CompNode::DeviceType::CUDA; + }; + } +#endif + if (enable_cpu_default) { + mgb_log_warn("using cpu default device\n"); + model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) { + loc.type = mgb::CompNode::DeviceType::CPU; + loc.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT; + }; + } + if (enable_multithread) { + mgb_log_warn("using multithread device\n"); + model->get_mdl_config().comp_node_mapper = + [&](mgb::CompNode::Locator& loc) { + loc.type = mgb::CompNode::DeviceType::MULTITHREAD; + loc.device = 0; + loc.stream = thread_num; + }; + } + if (enable_multithread_default) { + mgb_log_warn("using multithread default device\n"); + model->get_mdl_config().comp_node_mapper = + [&](mgb::CompNode::Locator& loc) { + loc.type = mgb::CompNode::DeviceType::MULTITHREAD; + loc.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT; + loc.stream = thread_num; + }; + } + if (enable_set_core_ids) { + std::string core_str; + for (auto id : core_ids) { + core_str += std::to_string(id) + ","; + } + mgb_log_warn("set multi thread core ids:%s\n", core_str.c_str()); + auto affinity_callback = [&](size_t thread_id) { + 
mgb::sys::set_cpu_affinity({core_ids[thread_id]}); + }; + mgb::CompNode::Locator loc; + model->get_mdl_config().comp_node_mapper(loc); + auto comp_node = mgb::CompNode::load(loc); + mgb::CompNodeEnv::from_comp_node(comp_node).cpu_env().set_affinity( + affinity_callback); + } + } +} +} // namespace lar + +XPUDeviceOption::XPUDeviceOption() { + m_option_name = "xpu_device"; + enable_cpu = FLAGS_cpu; +#if MGE_WITH_CUDA + enable_cuda = FLAGS_cuda; +#endif + enable_cpu_default = FLAGS_cpu_default; + + if (FLAGS_multithread >= 0) { + thread_num = FLAGS_multithread; + enable_multithread = true; + } + + if (FLAGS_multithread_default >= 0) { + thread_num = FLAGS_multithread_default; + enable_multithread_default = true; + } + + if (!FLAGS_multi_thread_core_ids.empty()) { + mgb_assert(enable_multithread, "core ids should be set after --multithread"); + std::stringstream id_stream(FLAGS_multi_thread_core_ids); + std::string id; + size_t thread_cnt = 0; + while (getline(id_stream, id, ',')) { + thread_cnt++; + core_ids.push_back(atoi(id.c_str())); + } + mgb_assert( + thread_cnt == thread_num, + "core ids number should be same with thread number set before"); + enable_set_core_ids = true; + } +} + +bool XPUDeviceOption::is_valid() { + bool ret = FLAGS_cpu || FLAGS_cpu_default; +#if MGE_WITH_CUDA + ret = ret || FLAGS_cuda; +#endif + ret = ret || FLAGS_multithread >= 0; + ret = ret || FLAGS_multithread_default >= 0; + ret = ret || !FLAGS_multi_thread_core_ids.empty(); + + return ret; +} + +std::shared_ptr XPUDeviceOption::create_option() { + static std::shared_ptr option(new XPUDeviceOption); + if (XPUDeviceOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void XPUDeviceOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} +///////////////////////// xpu gflags //////////////////////////// +DEFINE_bool(cpu, false, "set CPU device as running device"); +#if MGE_WITH_CUDA +DEFINE_bool(cuda, false, "set CUDA device as running device "); +#endif +DEFINE_bool(cpu_default, false, "set running device as CPU device with inplace mode"); +DEFINE_int32(multithread, -1, "set multithread device as running device"); +DEFINE_int32( + multithread_default, -1, + "set multithread device as running device with inplace mode"); +DEFINE_string(multi_thread_core_ids, "", "set multithread core id"); +REGIST_OPTION_CREATOR(xpu_device, lar::XPUDeviceOption::create_option); \ No newline at end of file diff --git a/lite/load_and_run/src/options/device_options.h b/lite/load_and_run/src/options/device_options.h new file mode 100644 index 000000000..3386d2bac --- /dev/null +++ b/lite/load_and_run/src/options/device_options.h @@ -0,0 +1,49 @@ +/** + * \file lite/load_and_run/src/options/device_options.h + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. 
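 *
 * [Editorial example] Thread pinning as parsed by XPUDeviceOption (values are
 * made up): --multithread 4 --multi_thread_core_ids 0,1,2,3 runs on a
 * four-thread comp node and pins worker i to core i; the constructor asserts
 * that the number of ids equals the thread count.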
+ */ + +#pragma once +#include +#include "models/model.h" +#include "option_base.h" + +DECLARE_bool(cpu); +#if MGE_WITH_CUDA +DECLARE_bool(cuda); +#endif +DECLARE_bool(cpu_default); +DECLARE_int32(multithread); +DECLARE_int32(multithread_default); +DECLARE_string(multi_thread_core_ids); +namespace lar { + +class XPUDeviceOption final : public OptionBase { +public: + static bool is_valid(); + static std::shared_ptr create_option(); + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + std::string option_name() const override { return m_option_name; }; + +private: + XPUDeviceOption(); + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + bool enable_cpu; +#if MGE_WITH_CUDA + bool enable_cuda; +#endif + bool enable_cpu_default; + bool enable_multithread; + bool enable_multithread_default; + bool enable_set_core_ids; + size_t thread_num; + std::vector core_ids; + std::string m_option_name; +}; +} // namespace lar \ No newline at end of file diff --git a/lite/load_and_run/src/options/extern_c_opr_options.cpp b/lite/load_and_run/src/options/extern_c_opr_options.cpp new file mode 100644 index 000000000..d7131cdc9 --- /dev/null +++ b/lite/load_and_run/src/options/extern_c_opr_options.cpp @@ -0,0 +1,216 @@ +/** + * \file lite/load_and_run/src/options/extern_c_opr_options.cpp + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ + +#include "extern_c_opr_options.h" +#include "megbrain/utils/debug.h" +#include "misc.h" +#include "models/model_lite.h" +#include "models/model_mdl.h" + +namespace lar { +template <> +void COprLibOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + MGB_MARK_USED_VAR(model); + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + if (!lib_path.empty()) { + lite::set_loader_lib_path(lib_path); + } + if (c_opr_args.is_run_c_opr_with_param) { + LITE_THROW( + "lite model dont't support run with external c opr " + "parmeter"); + } + } +} +template <> +void COprLibOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + if (!lib_path.empty()) { + load_lib(); + } + if (c_opr_args.is_run_c_opr_with_param) { + mgb_assert( + c_opr_args.is_run_c_opr && + c_opr_args.copr_param_device_ptr_malloc && + c_opr_args.copr_param_device_ptr_free && + c_opr_args.copr_param_device_ptr_h2d, + "--c-opr-lib-with-param need config with --c-opr-lib, also " + "extern c opr loader need implemente " + "copr_param_device_ptr_malloc, copr_param_device_ptr_free " + "and copr_param_device_ptr_h2d symbols"); + } + } else if (runtime_param.stage == RunStage::MODEL_RUNNING) { + if (model->get_testcase_num() && c_opr_args.is_run_c_opr_with_param) { + init_extern_param(model); + set_Copr_IO(model); + } + } else if (runtime_param.stage == RunStage::AFTER_RUNNING_ITER) { + if (model->get_testcase_num() && c_opr_args.is_run_c_opr_with_param) { + c_opr_args.copr_param_device_ptr_free(c_opr_param.get()); + free(c_opr_param->input); + } + } +} +} // namespace lar + +using namespace lar; + +MGBDType COprLibOption::dtype_cpp2c(megdnn::DType dtype) { + switch (dtype.enumv()) { + case megdnn::DTypeEnum::Float32: + return MGB_DTYPE_FLOAT32; + case megdnn::DTypeEnum::Int32: + return MGB_DTYPE_INT32; + case megdnn::DTypeEnum::Int16: + return MGB_DTYPE_INT16; + case megdnn::DTypeEnum::Uint8: + return MGB_DTYPE_UINT8; +#if 
!MEGDNN_DISABLE_FLOAT16 + case megdnn::DTypeEnum::Float16: + return MGB_DTYPE_FLOAT16; +#endif + default: + mgb_throw( + mgb::InternalError, "unsupported dtype for extern C API: %s", + dtype.name()); + } +} + +void COprLibOption::tensor_shape_to_c( + const megdnn::TensorShape& shape, MGBTensorShape& mgb_shape) { + mgb_assert( + shape.ndim <= MGB_TENSOR_MAX_NDIM, "shape ndim too large: %zu", shape.ndim); + mgb_shape.ndim = shape.ndim; + for (size_t i = 0; i < shape.ndim; ++i) { + mgb_shape.shape[i] = shape[i]; + } +} + +void COprLibOption::init_extern_param(std::shared_ptr model_ptr) { + auto model = std::static_pointer_cast(model_ptr); + auto inp_tensors = model->get_test_input(); + + c_opr_param = std::make_shared(); + memset(c_opr_param.get(), 0, sizeof(ExternCOprParam)); + + //! we just test input on npu case, do not test output on + //! npu case, so we just init input shape and type + + c_opr_param->nr_input = inp_tensors.size(); + c_opr_param->input = (ExternDeviceTensor*)malloc( + sizeof(ExternDeviceTensor) * inp_tensors.size()); + memset(c_opr_param->input, 0, sizeof(ExternDeviceTensor) * inp_tensors.size()); + + //! init input ExternDeviceTensor shape and dtype + for (size_t input_idx = 0; input_idx < inp_tensors.size(); input_idx++) { + auto& mgb_tensor_layout = c_opr_param->input[input_idx].layout; + auto host_tensor_nd_p = inp_tensors[input_idx].second; + mgb_tensor_layout.dtype = dtype_cpp2c(host_tensor_nd_p->dtype()); + tensor_shape_to_c( + inp_tensors[input_idx].second->shape(), mgb_tensor_layout.shape); + } + c_opr_param->nr_output = 0; + + //! now call copr_param_device_ptr_malloc to malloc + //! device_ptr + c_opr_args.copr_param_device_ptr_malloc(c_opr_param.get()); +} + +void COprLibOption::load_lib() { + auto handle = dlopen(lib_path.c_str(), RTLD_LAZY); + mgb_assert(handle, "failed to open c opr lib %s: %s", lib_path.c_str(), dlerror()); + + const char* entry = MGB_C_OPR_INIT_FUNC_STR; + auto func = dlsym(handle, entry); + mgb_assert(func, "can not resolve %s: %s", entry, dlerror()); + typedef void (*entry_f_t)(void*); + reinterpret_cast(func)( + reinterpret_cast(&mgb_get_extern_c_opr_api_versioned)); + printf("loaded C opr library: %s\n", lib_path.c_str()); + entry = "copr_param_device_ptr_malloc"; + func = dlsym(handle, entry); + if (func) { + printf("get %s from: %s\n", entry, lib_path.c_str()); + c_opr_args.copr_param_device_ptr_malloc = + reinterpret_cast(func); + } + + entry = "copr_param_device_ptr_free"; + func = dlsym(handle, entry); + if (func) { + printf("get %s from: %s\n", entry, lib_path.c_str()); + c_opr_args.copr_param_device_ptr_free = + reinterpret_cast(func); + } + + entry = "copr_param_device_ptr_h2d"; + func = dlsym(handle, entry); + if (func) { + printf("get %s from: %s\n", entry, lib_path.c_str()); + c_opr_args.copr_param_device_ptr_h2d = + reinterpret_cast(func); + } +} + +void COprLibOption::set_Copr_IO(std::shared_ptr model_ptr) { + auto model = std::static_pointer_cast(model_ptr); + auto inp_tensors = model->get_test_input(); + auto loader = model->reset_loader(); + auto testcase = loader->load(model->get_mdl_config(), false); + mgb_assert(testcase.output_var_list.size() == inp_tensors.size()); + for (size_t i = 0; i < inp_tensors.size(); ++i) { + auto&& opr = testcase.output_var_list[i] + .node() + ->owner_opr() + ->cast_final_safe(); + c_opr_args.copr_param_device_ptr_h2d( + c_opr_param.get(), opr.dev_data()->raw_ptr(), i); + } + + //! 
now config c opr dynamic param + config_extern_c_opr_dynamic_param(model->get_async_func(), c_opr_param); +} + +COprLibOption::COprLibOption() { + m_option_name = "c_opr_lib"; + lib_path = FLAGS_c_opr_lib; + c_opr_args.is_run_c_opr = !lib_path.empty(); + c_opr_args.is_run_c_opr_with_param = FLAGS_c_opr_lib_with_param; +} + +bool COprLibOption::is_valid() { + return !FLAGS_c_opr_lib.empty() || FLAGS_c_opr_lib_with_param; +} + +std::shared_ptr COprLibOption::create_option() { + static std::shared_ptr option(new COprLibOption); + if (COprLibOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void COprLibOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} +DEFINE_string( + c_opr_lib, "", + "Load external operator library. It must implement " + "MGB_C_OPR_INIT_FUNC_STR as the entry point"); +DEFINE_bool( + c_opr_lib_with_param, false, + "Run c opr lib with param, use to benchmark speed and check result, " + "need c opr loader implemente `copr_param_device_ptr_malloc, " + "copr_param_device_ptr_free and copr_param_device_ptr_h2d' symbols"); + +REGIST_OPTION_CREATOR(c_opr_lib, lar::COprLibOption::create_option); diff --git a/lite/load_and_run/src/options/extern_c_opr_options.h b/lite/load_and_run/src/options/extern_c_opr_options.h new file mode 100644 index 000000000..f55d91d60 --- /dev/null +++ b/lite/load_and_run/src/options/extern_c_opr_options.h @@ -0,0 +1,64 @@ +/** + * \file lite/load_and_run/src/options/extern_c_opr_options.h + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ + +#pragma once +#include +#include "megbrain/graph/extern_copr_api.h" +#include "models/model.h" +#include "option_base.h" + +DECLARE_bool(c_opr_lib_with_param); +DECLARE_string(c_opr_lib); + +namespace lar { + +struct COprArgs { + //! 
for run c opr + bool is_run_c_opr = false; + bool is_run_c_opr_with_param = false; + typedef void (*COPR_PARAM_DEVICE_PTR_MEM_T)(ExternCOprParam* param); + typedef void (*COPR_PARAM_DEVICE_PTR_H2D_T)( + ExternCOprParam* param, void* host_ptr, size_t extern_device_tensor_id); + COPR_PARAM_DEVICE_PTR_MEM_T copr_param_device_ptr_malloc = nullptr; + COPR_PARAM_DEVICE_PTR_MEM_T copr_param_device_ptr_free = nullptr; + COPR_PARAM_DEVICE_PTR_H2D_T copr_param_device_ptr_h2d = nullptr; +}; + +class COprLibOption final : public OptionBase { +public: + static bool is_valid(); + + static std::shared_ptr create_option(); + + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + + std::string option_name() const override { return m_option_name; }; + +private: + COprLibOption(); + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + + void load_lib(); + + MGBDType dtype_cpp2c(megdnn::DType dtype); + + void tensor_shape_to_c(const megdnn::TensorShape& shape, MGBTensorShape& mgb_shape); + + void init_extern_param(std::shared_ptr model); + + void set_Copr_IO(std::shared_ptr model); + + std::string m_option_name; + COprArgs c_opr_args; + std::string lib_path; + std::shared_ptr c_opr_param; +}; +} // namespace lar \ No newline at end of file diff --git a/lite/load_and_run/src/options/fastrun_options.cpp b/lite/load_and_run/src/options/fastrun_options.cpp new file mode 100644 index 000000000..764bfeb90 --- /dev/null +++ b/lite/load_and_run/src/options/fastrun_options.cpp @@ -0,0 +1,231 @@ +/** + * \file lite/load_and_run/src/options/fastrun_options.cpp + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ + +#include + +#if defined(_WIN32) +#include +#define F_OK 0 +#define access(a, b) _access(a, b) +#elif __linux__ || __unix__ || __APPLE__ +#include +#endif +#include "fastrun_options.h" +#include "megbrain/gopt/inference.h" +#include "megbrain/utils/infile_persistent_cache.h" +#include "misc.h" +#include "models/model_lite.h" +#include "models/model_mdl.h" + +namespace lar { + +template <> +void FastRunOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + //! set the algo policy before model load + using Strategy = ModelLite::Strategy; + uint32_t strategy = 0; +#if MGB_ENABLE_FASTRUN + if (enable_full_run) { + LITE_WARN("enable full-run strategy for algo profile"); + strategy = static_cast(Strategy::LITE_ALGO_PROFILE) | strategy; + } else if (enable_fast_run) { + LITE_WARN("enable fast-run strategy for algo profile"); + strategy = static_cast(Strategy::LITE_ALGO_PROFILE) | + static_cast(Strategy::LITE_ALGO_OPTIMIZED) | strategy; + } else { + strategy = static_cast(Strategy::LITE_ALGO_HEURISTIC) | strategy; + } +#else + strategy = static_cast(Strategy::LITE_ALGO_HEURISTIC) | strategy; +#endif + if (batch_binary_equal || enable_reproducible) { + LITE_WARN("enable reproducible strategy for algo profile"); + if (batch_binary_equal) + strategy = static_cast(Strategy::LITE_ALGO_REPRODUCIBLE) | + strategy; + } + auto lite_strategy = static_cast(strategy); + model->set_lite_strategy(lite_strategy); + } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { + auto lite_network = model->get_lite_network(); + auto lite_strategy = model->get_lite_strategy(); + //! 
set algo policy for model + lite::Runtime::set_network_algo_policy( + lite_network, lite_strategy, share_batch_size, batch_binary_equal); + if (!m_fast_run_cache.empty()) { + if (!access(m_fast_run_cache.c_str(), F_OK)) { + lite::set_persistent_cache(m_fast_run_cache); + } else { + lite::set_persistent_cache(m_fast_run_cache, true); + } + //! TODO:this is from mdl model settings but not matched settings in + //! lite model + // if (!enable_full_run && !enable_fast_run) + // mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); + } + } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { +#if MGB_ENABLE_FASTRUN + //! dump algo cache + if (!m_fast_run_cache.empty()) { + lite::dump_persistent_cache(m_fast_run_cache); + } +#endif + } +} + +template <> +void FastRunOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + //! set the algo policy before model load + using Strategy = ModelMdl::Strategy; + auto strategy = static_cast(0); +#if MGB_ENABLE_FASTRUN + if (enable_full_run) { + mgb_log_warn("enable full-run strategy for algo profile"); + strategy = Strategy::PROFILE | strategy; + } else if (enable_fast_run) { + mgb_log_warn("enable fast-run strategy for algo profile"); + strategy = Strategy::PROFILE | Strategy::OPTIMIZED | strategy; + } else { + strategy = Strategy::HEURISTIC | strategy; + } +#else + strategy = Strategy::HEURISTIC | strategy; +#endif + if (batch_binary_equal || enable_reproducible) { + mgb_log_warn("enable reproducible strategy for algo profile"); + strategy = Strategy::REPRODUCIBLE | strategy; + } + model->set_mdl_strategy(strategy); + + //! set binary_equal_between_batch and shared_batch_size + if (batch_binary_equal) { + mgb_log_warn("enable batch binary equal"); + model->get_mdl_config() + .comp_graph->options() + .fast_run_config.binary_equal_between_batch = true; + } + if (share_batch_size > 0) { + mgb_log_warn("set shared shared batch"); + model->get_mdl_config() + .comp_graph->options() + .fast_run_config.shared_batch_size = share_batch_size; + } + } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { + auto vars = model->get_mdl_load_result().output_var_list; + auto strategy = model->get_mdl_strategy(); + mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy); + // set algo cache path + if (!m_fast_run_cache.empty()) { + if (!access(m_fast_run_cache.c_str(), F_OK)) { + mgb::PersistentCache::set_impl( + std::make_shared( + m_fast_run_cache.c_str())); + } else { + mgb::PersistentCache::set_impl( + std::make_shared()); + } +#if MGB_ENABLE_FASTRUN + if (!enable_full_run && !enable_fast_run) +#endif + mgb::gopt::enable_opr_use_profiling_cache_inplace(vars); + } + } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { +#if MGB_ENABLE_FASTRUN + //! dump algo cache + if (!m_fast_run_cache.empty()) { + static_cast(mgb::PersistentCache::inst()) + .dump_cache(m_fast_run_cache.c_str()); + } +#endif + } +} + +} // namespace lar + +using namespace lar; + +FastRunOption::FastRunOption() { + m_option_name = "fastrun"; +#if MGB_ENABLE_FASTRUN + enable_fast_run = FLAGS_fast_run; + enable_full_run = FLAGS_full_run; +#endif + batch_binary_equal = FLAGS_binary_equal_between_batch; + enable_reproducible = FLAGS_reproducible; + m_fast_run_cache = FLAGS_fast_run_algo_policy; + share_batch_size = FLAGS_fast_run_shared_batch_size; +#if MGB_ENABLE_FASTRUN + //! 
while fastrun cache file path is not empty and can't be accessed + if (!m_fast_run_cache.empty() && access(m_fast_run_cache.c_str(), F_OK)) { + mgb_assert( + enable_full_run || enable_fast_run, + "--fast-run or --full-run should be enabled"); + } + if (share_batch_size) { + mgb_assert( + enable_full_run || enable_fast_run || !m_fast_run_cache.empty(), + "--fast-run-shared-batch-size should be used with " + "--fast-run|--full-run|--fast-run-algo-policy"); + } +#endif +} + +bool FastRunOption::is_valid() { + bool ret = false; +#if MGB_ENABLE_FASTRUN + ret = ret || FLAGS_fast_run; + ret = ret || FLAGS_full_run; +#endif + ret = ret || FLAGS_binary_equal_between_batch; + ret = ret || FLAGS_fast_run_shared_batch_size > 0; + ret = ret || FLAGS_reproducible; + ret = ret || FLAGS_fast_run_algo_policy.size() > 0; + + return ret; +} + +std::shared_ptr FastRunOption::create_option() { + static std::shared_ptr option(new FastRunOption); + if (FastRunOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void FastRunOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} + +#if MGB_ENABLE_FASTRUN +DEFINE_bool(fast_run, false, "whether to use fast-run in model run"); +DEFINE_bool(full_run, false, "whether to use full-run in model run"); +#endif + +DEFINE_bool( + binary_equal_between_batch, false, + "Each batch of output is promised binary equal if each batch of " + "input is binary equal\n Note that if this option is turned on, " + "`--reproducible` will also be turned on."); +DEFINE_bool( + reproducible, false, + "Enable choose algo which is reproducible. It mainly used for " + "cudnn algos.See " + "https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/" + "index.html#reproducibility" + "for more details."); +DEFINE_uint32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun"); +DEFINE_string(fast_run_algo_policy, "", "fast-run cache path."); + +REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option); \ No newline at end of file diff --git a/lite/load_and_run/src/options/fastrun_options.h b/lite/load_and_run/src/options/fastrun_options.h new file mode 100644 index 000000000..62b897764 --- /dev/null +++ b/lite/load_and_run/src/options/fastrun_options.h @@ -0,0 +1,57 @@ +/** + * \file lite/load_and_run/src/options/fastrun_options.h + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ + +#pragma once + +#include +#include "models/model.h" +#include "option_base.h" + +#if MGB_ENABLE_FASTRUN +DECLARE_bool(fast_run); +DECLARE_bool(full_run); +#endif +DECLARE_bool(reproducible); +DECLARE_bool(binary_equal_between_batch); +DECLARE_uint32(fast_run_shared_batch_size); +DECLARE_string(fast_run_algo_policy); + +namespace lar { +class FastRunOption final : public OptionBase { +public: + //! get condition for construct FastRunOption + static bool is_valid(); + + //! creat option using condition from cmdline args + static std::shared_ptr create_option(); + + //! configure model for different runtime_param + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + + //! get options name for quickly search + std::string option_name() const override { return m_option_name; } + +private: + FastRunOption(); + //! 
config template for different model + template + void config_model_internel(RuntimeParam&, std::shared_ptr) {} + +#if MGB_ENABLE_FASTRUN + bool enable_fast_run; //! fast run strategy flag + bool enable_full_run; //! full run strategy flag +#endif + bool batch_binary_equal; //! fast run stratgey setting + bool enable_reproducible; //! enable reproducible strategy + size_t share_batch_size; //! fast run strategy share batch size setting + std::string m_fast_run_cache; //! fast run cache file path + std::string m_option_name; //! option name +}; +} // namespace lar diff --git a/lite/load_and_run/src/options/io_options.cpp b/lite/load_and_run/src/options/io_options.cpp new file mode 100644 index 000000000..961ca99cc --- /dev/null +++ b/lite/load_and_run/src/options/io_options.cpp @@ -0,0 +1,295 @@ +/** + * \file lite/load_and_run/src/options/io_options.cpp + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ + +#include + +#include "helpers/data_parser.h" +#include "misc.h" +#include "models/model_lite.h" +#include "models/model_mdl.h" + +#include "io_options.h" +namespace lar { +template <> +void InputOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto parser = model->get_input_parser(); + auto io = model->get_networkIO(); + for (size_t idx = 0; idx < data_path.size(); ++idx) { + parser.feed(data_path[idx].c_str()); + } + + auto inputs = parser.inputs; + bool is_host = true; + for (auto& i : inputs) { + io.inputs.push_back({i.first, is_host}); + } + } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { + auto config = model->get_config(); + auto parser = model->get_input_parser(); + auto network = model->get_lite_network(); + + //! datd type map from mgb data type to lite data type + std::map type_map = { + {megdnn::DTypeEnum::Float32, LiteDataType::LITE_FLOAT}, + {megdnn::DTypeEnum::Int32, LiteDataType::LITE_INT}, + {megdnn::DTypeEnum::Int8, LiteDataType::LITE_INT8}, + {megdnn::DTypeEnum::Uint8, LiteDataType::LITE_UINT8}}; + + for (auto& i : parser.inputs) { + //! get tensor information from data parser + auto tensor = i.second; + auto data_type = tensor.dtype(); + auto tensor_shape = tensor.shape(); + mgb::dt_byte* src = tensor.raw_ptr(); + + //! set lite layout + lite::Layout layout; + layout.ndim = tensor_shape.ndim; + for (size_t idx = 0; idx < tensor_shape.ndim; idx++) { + layout.shapes[idx] = tensor_shape[idx]; + } + layout.data_type = type_map[data_type.enumv()]; + + //! 
set network input tensor + std::shared_ptr input_tensor = + network->get_io_tensor(i.first); + input_tensor->reset(src, layout); + } + } +} + +template <> +void InputOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto parser = model->get_input_parser(); + for (size_t idx = 0; idx < data_path.size(); ++idx) { + parser.feed(data_path[idx].c_str()); + } + } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { + auto parser = model->get_input_parser(); + auto network = model->get_mdl_load_result(); + auto tensormap = network.tensor_map; + for (auto& i : parser.inputs) { + mgb_assert( + tensormap.find(i.first) != tensormap.end(), + "can't find tensor named %s", i.first.c_str()); + auto& in = tensormap.find(i.first)->second; + in->copy_from(i.second); + } + } +} + +template <> +void IOdumpOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { + if (enable_io_dump) { + LITE_WARN("enable text io dump"); + lite::Runtime::enable_io_txt_dump(model->get_lite_network(), dump_path); + } + if (enable_bin_io_dump) { + LITE_WARN("enable binary io dump"); + lite::Runtime::enable_io_bin_dump(model->get_lite_network(), dump_path); + } + //! FIXME: complete this when the corresponding API is added in lite + if (enable_io_dump_stdout || enable_io_dump_stderr) { + LITE_THROW("lite model doesn't support io dump to stdout or stderr"); + } + if (enable_bin_out_dump) { + LITE_THROW("lite model doesn't support binary output dump"); + } + if (enable_copy_to_host) { + LITE_WARN("lite model copies outputs to host by default"); + } + } +} + +template <> +void IOdumpOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + if (enable_io_dump) { + mgb_log_warn("enable text io dump"); + auto iodump = std::make_unique( + model->get_mdl_config().comp_graph.get(), dump_path.c_str()); + iodump->print_addr(false); + io_dumper = std::move(iodump); + } + + if (enable_io_dump_stdout) { + mgb_log_warn("enable text io dump to stdout"); + std::shared_ptr std_out(stdout, [](FILE*) {}); + auto iodump = std::make_unique( + model->get_mdl_config().comp_graph.get(), std_out); + iodump->print_addr(false); + io_dumper = std::move(iodump); + } + + if (enable_io_dump_stderr) { + mgb_log_warn("enable text io dump to stderr"); + std::shared_ptr std_err(stderr, [](FILE*) {}); + auto iodump = std::make_unique( + model->get_mdl_config().comp_graph.get(), std_err); + iodump->print_addr(false); + io_dumper = std::move(iodump); + } + + if (enable_bin_io_dump) { + mgb_log_warn("enable binary io dump"); + auto iodump = std::make_unique( + model->get_mdl_config().comp_graph.get(), dump_path); + io_dumper = std::move(iodump); + } + + if (enable_bin_out_dump) { + mgb_log_warn("enable binary output dump"); + out_dumper = std::make_unique(dump_path.c_str()); + } + } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { + if (enable_bin_out_dump) { + auto load_result = model->get_mdl_load_result(); + out_dumper->set(load_result.output_var_list); + + std::vector cb; + for (size_t i = 0; i < load_result.output_var_list.size(); i++) { + cb.push_back(out_dumper->bind()); + } + model->set_output_callback(cb); + } + if (enable_copy_to_host) { + auto load_result = model->get_mdl_load_result(); + + std::vector cb; + for (size_t i = 0; i < load_result.output_var_list.size(); i++) { + mgb::HostTensorND
val; + auto callback = [val](const mgb::DeviceTensorND& dv) mutable { + val.copy_from(dv); + }; + cb.push_back(callback); + } + model->set_output_callback(cb); + } + } else if (runtime_param.stage == RunStage::AFTER_RUNNING_WAIT) { + if (enable_bin_out_dump) { + out_dumper->write_to_file(); + } + } +} + +} // namespace lar + +////////////////////// Input options //////////////////////// +using namespace lar; + +InputOption::InputOption() { + m_option_name = "input"; + size_t start = 0; + auto end = FLAGS_input.find(";", start); + while (end != std::string::npos) { + std::string path = FLAGS_input.substr(start, end - start); + data_path.emplace_back(path); + start = end + 1; + end = FLAGS_input.find(";", start); + } + data_path.emplace_back(FLAGS_input.substr(start)); +} + +std::shared_ptr lar::InputOption::create_option() { + static std::shared_ptr m_option(new InputOption); + if (InputOption::is_valid()) { + return std::static_pointer_cast(m_option); + } else { + return nullptr; + } +} + +void InputOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} + +////////////////////// OprIOdump options //////////////////////// + +IOdumpOption::IOdumpOption() { + m_option_name = "iodump"; + size_t valid_flag = 0; + if (!FLAGS_io_dump.empty()) { + dump_path = FLAGS_io_dump; + enable_io_dump = true; + valid_flag = valid_flag | (1 << 0); + } + if (!FLAGS_bin_io_dump.empty()) { + dump_path = FLAGS_bin_io_dump; + enable_bin_io_dump = true; + valid_flag = valid_flag | (1 << 1); + } + if (!FLAGS_bin_out_dump.empty()) { + dump_path = FLAGS_bin_out_dump; + enable_bin_out_dump = true; + valid_flag = valid_flag | (1 << 2); + } + if (FLAGS_io_dump_stdout) { + enable_io_dump_stdout = FLAGS_io_dump_stdout; + valid_flag = valid_flag | (1 << 3); + } + if (FLAGS_io_dump_stderr) { + enable_io_dump_stderr = FLAGS_io_dump_stderr; + valid_flag = valid_flag | (1 << 4); + } + // warn if more than one dump option is set + if (valid_flag && (valid_flag & (valid_flag - 1))) { + mgb_log_warn( + "ONLY the last io dump option is valid and the others are " + "skipped!!!"); + } + + enable_copy_to_host = FLAGS_copy_to_host; +} + +bool IOdumpOption::is_valid() { + bool ret = !FLAGS_io_dump.empty(); + ret = ret || FLAGS_io_dump_stdout; + ret = ret || FLAGS_io_dump_stderr; + ret = ret || !FLAGS_bin_io_dump.empty(); + ret = ret || !FLAGS_bin_out_dump.empty(); + ret = ret || FLAGS_copy_to_host; + return ret; +} + +std::shared_ptr IOdumpOption::create_option() { + static std::shared_ptr option(new IOdumpOption); + if (IOdumpOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void IOdumpOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} +////////////////////// Input gflags //////////////////////// +DEFINE_string( + input, "", "Set up input data for the model: --input [ file_path | data_string ]"); + +////////////////////// OprIOdump gflags //////////////////////// + +DEFINE_string(io_dump, "", "set the io dump file path in text format"); +DEFINE_bool(io_dump_stdout, false, "dump io opr to stdout in text format"); +DEFINE_bool(io_dump_stderr, false, "dump io opr to stderr in text format"); +DEFINE_string(bin_io_dump, "", "set the io dump file path in binary format"); +DEFINE_string(bin_out_dump, "", "set the output dump file path in binary format"); +DEFINE_bool(copy_to_host, false, "copy device data to host"); + +REGIST_OPTION_CREATOR(input, lar::InputOption::create_option); +REGIST_OPTION_CREATOR(iodump, 
lar::IOdumpOption::create_option); diff --git a/lite/load_and_run/src/options/io_options.h b/lite/load_and_run/src/options/io_options.h new file mode 100644 index 000000000..1f4c9c8d0 --- /dev/null +++ b/lite/load_and_run/src/options/io_options.h @@ -0,0 +1,78 @@ +/** + * \file lite/load_and_run/src/options/io_options.h + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ + +#pragma once +#include +#include "helpers/outdumper.h" +#include "megbrain/plugin/opr_io_dump.h" +#include "models/model.h" +#include "option_base.h" + +DECLARE_string(input); + +DECLARE_string(io_dump); +DECLARE_bool(io_dump_stdout); +DECLARE_bool(io_dump_stderr); +DECLARE_string(bin_io_dump); +DECLARE_string(bin_out_dump); +DECLARE_bool(copy_to_host); + +namespace lar { + +/*! + * \brief: input option for --input set + */ +class InputOption final : public OptionBase { +public: + //! static function for registe options + static bool is_valid() { return !FLAGS_input.empty(); }; + static std::shared_ptr create_option(); + + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + //! interface implement from OptionBase + std::string option_name() const override { return m_option_name; }; + +private: + InputOption(); + + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + + std::string m_option_name; + std::vector data_path; // data string or data file path +}; + +class IOdumpOption : public OptionBase { +public: + static bool is_valid(); + static std::shared_ptr create_option(); + //! config the model, if different has different configure code, then + //! dispatch + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + std::string option_name() const override { return m_option_name; }; + +private: + IOdumpOption(); + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + + bool enable_io_dump; + bool enable_io_dump_stdout; + bool enable_io_dump_stderr; + bool enable_bin_io_dump; + bool enable_bin_out_dump; + bool enable_copy_to_host; + std::string m_option_name; + std::string dump_path; + std::unique_ptr io_dumper; + std::unique_ptr out_dumper; +}; +} // namespace lar diff --git a/lite/load_and_run/src/options/layout_options.cpp b/lite/load_and_run/src/options/layout_options.cpp new file mode 100644 index 000000000..519771398 --- /dev/null +++ b/lite/load_and_run/src/options/layout_options.cpp @@ -0,0 +1,171 @@ +/** + * \file lite/load_and_run/src/options/layout_options.cpp + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. 
+ */ + +#include + +#include "misc.h" +#include "models/model_lite.h" +#include "models/model_mdl.h" + +#include "layout_options.h" +namespace lar { +template <> +void LayoutOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { +#define ENABLE_LAYOUT(layout) \ + LITE_WARN("enable " #layout " optimization"); \ + model->get_config().options.enable_##layout = true; \ + break; + + switch (option_flag) { + case OptLayoutType::NCHW4: + ENABLE_LAYOUT(nchw4) + + case OptLayoutType::CHWN4: + LITE_THROW("lite model unsupport chwn4 layout"); + break; + case OptLayoutType::NCHW44: + ENABLE_LAYOUT(nchw44) + + case OptLayoutType::NCHW88: + ENABLE_LAYOUT(nchw88) + + case OptLayoutType::NCHW32: + ENABLE_LAYOUT(nchw32) + + case OptLayoutType::NCHW64: + ENABLE_LAYOUT(nchw64) + + case OptLayoutType::NHWCD4: + ENABLE_LAYOUT(nhwcd4) + + case OptLayoutType::NCHW44_DOT: + ENABLE_LAYOUT(nchw44_dot) + default: + break; + } +#undef ENABLE_LAYOUT + } +} + +template <> +void lar::LayoutOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + mgb_log_debug("mdl layout config start"); +#define ENABLE_LAYOUT(layout) \ + mgb_log_warn("enable " #layout " optimization"); \ + model->get_mdl_config().comp_graph->options().graph_opt.enable_##layout(); \ + break; + + switch (option_flag) { + case OptLayoutType::NCHW4: + ENABLE_LAYOUT(nchw4) + + case OptLayoutType::CHWN4: + ENABLE_LAYOUT(chwn4) + + case OptLayoutType::NCHW44: + ENABLE_LAYOUT(nchw44) + + case OptLayoutType::NCHW88: + ENABLE_LAYOUT(nchw88) + + case OptLayoutType::NCHW32: + ENABLE_LAYOUT(nchw32) + + case OptLayoutType::NCHW64: + ENABLE_LAYOUT(nchw64) + + case OptLayoutType::NHWCD4: + ENABLE_LAYOUT(nhwcd4) + + case OptLayoutType::NCHW44_DOT: + ENABLE_LAYOUT(nchw44_dot) + + default: + break; + } + mgb_log_debug("mdl layout config end"); + +#undef ENABLE_LAYOUT + } +} +} // namespace lar + +using namespace lar; + +OptLayoutType LayoutOption::option_flag; + +LayoutOption::LayoutOption() { + m_option_name = "layout"; +} + +bool LayoutOption::is_valid() { + size_t valid_flag = 0; + if (FLAGS_enable_nchw4) { + valid_flag = valid_flag | (1 << 0); + } + if (FLAGS_enable_chwn4) { + valid_flag = valid_flag | (1 << 1); + } + if (FLAGS_enable_nchw44) { + valid_flag = valid_flag | (1 << 2); + } + if (FLAGS_enable_nchw88) { + valid_flag = valid_flag | (1 << 3); + } + if (FLAGS_enable_nchw32) { + valid_flag = valid_flag | (1 << 4); + } + if (FLAGS_enable_nchw64) { + valid_flag = valid_flag | (1 << 5); + } + if (FLAGS_enable_nhwcd4) { + valid_flag = valid_flag | (1 << 6); + } + if (FLAGS_enable_nchw44_dot) { + valid_flag = valid_flag | (1 << 7); + } + + bool ret = valid_flag && !(valid_flag & (valid_flag - 1)); + if (ret) { + option_flag = static_cast(valid_flag); + } else { + option_flag = static_cast(0); + } + + return ret; +}; + +std::shared_ptr LayoutOption::create_option() { + static std::shared_ptr option(new LayoutOption); + if (LayoutOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void LayoutOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} + +DEFINE_bool(enable_nchw4, false, "enable nchw4 layout optimization!!"); +DEFINE_bool(enable_chwn4, false, "enable chwn4 layout optimization!!"); +DEFINE_bool(enable_nchw44, false, "enable nchw44 layout optimization!!"); +DEFINE_bool(enable_nchw88, 
false, "enable nchw88 layout optimization!!"); +DEFINE_bool(enable_nchw32, false, "enable nchw32 layout optimization!!"); +DEFINE_bool(enable_nchw64, false, "enable nchw64 layout optimization!!"); +DEFINE_bool(enable_nhwcd4, false, "enable nhwcd4 layout optimization!!"); +DEFINE_bool(enable_nchw44_dot, false, "enable nchw444-dot layout optimization!!"); + +REGIST_OPTION_CREATOR(layout, lar::LayoutOption::create_option); \ No newline at end of file diff --git a/lite/load_and_run/src/options/layout_options.h b/lite/load_and_run/src/options/layout_options.h new file mode 100644 index 000000000..fcf6a1772 --- /dev/null +++ b/lite/load_and_run/src/options/layout_options.h @@ -0,0 +1,56 @@ +/** + * \file lite/load_and_run/src/options/layout_options.h + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ + +#pragma once + +#include +#include "helpers/common.h" +#include "models/model.h" +#include "option_base.h" + +DECLARE_bool(enable_nchw4); +DECLARE_bool(enable_chwn4); +DECLARE_bool(enable_nchw44); +DECLARE_bool(enable_nchw88); +DECLARE_bool(enable_nchw32); +DECLARE_bool(enable_nchw64); +DECLARE_bool(enable_nhwcd4); +DECLARE_bool(enable_nchw44_dot); + +namespace lar { +/*! + * \brief: layout option for optimization + */ +class LayoutOption final : public OptionBase { +public: + //! check the validation of option flag + static bool is_valid(); + + //! creat options when option is used + static std::shared_ptr create_option(); + + //! config the model, dispatch configuration for different model implement + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + + //! get option name + std::string option_name() const override { return m_option_name; }; + +private: + //! Constructor + LayoutOption(); + + //! configuration for different model implement + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + + static OptLayoutType option_flag; + std::string m_option_name; +}; +} // namespace lar \ No newline at end of file diff --git a/lite/load_and_run/src/options/optimize_options.cpp b/lite/load_and_run/src/options/optimize_options.cpp new file mode 100644 index 000000000..c684a3afa --- /dev/null +++ b/lite/load_and_run/src/options/optimize_options.cpp @@ -0,0 +1,600 @@ +/** + * \file lite/load_and_run/src/options/optimize_options.cpp + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. 
+ */ + +#include "megbrain/gopt/inference.h" +#if MGB_ENABLE_TENSOR_RT +#include "megbrain/tensorrt/tensorrt_engine_cache.h" +#endif +#include "lite/global.h" +#include "misc.h" +#include "models/model_lite.h" +#include "models/model_mdl.h" +#include "optimize_options.h" + +///////////////////////// fuse and preprocess optimize options /////////////// +namespace lar { +template <> +void FusePreprocessOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + if (enable_fuse_preprocess) { + LITE_WARN("enable fuse-preprocess optimization"); + model->get_config().options.fuse_preprocess = true; + } + } +} + +template <> +void FusePreprocessOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto&& graph_option = model->get_mdl_config().comp_graph->options(); + if (enable_fuse_preprocess) { + mgb_log_warn("enable fuse-preprocess optimization"); + graph_option.graph_opt.enable_fuse_preprocess(); + } + } +} +} // namespace lar +using namespace lar; + +FusePreprocessOption::FusePreprocessOption() { + m_option_name = "fuse_preprocess"; + enable_fuse_preprocess = FLAGS_enable_fuse_preprocess; +} + +bool FusePreprocessOption::is_valid() { + bool ret = FLAGS_enable_fuse_preprocess; + return ret; +} + +std::shared_ptr FusePreprocessOption::create_option() { + static std::shared_ptr option(new FusePreprocessOption); + if (FusePreprocessOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void FusePreprocessOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} + +///////////////////////// weight preprocess optimize options /////////////// +namespace lar { +template <> +void WeightPreprocessOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + if (weight_preprocess) { + LITE_WARN("enable weight-preprocess optimization"); + model->get_config().options.weight_preprocess = true; + //! FIXME: algo searcher enable weight preprocess for opencl( + //! implement below has some problem); + // #if MGB_OPENCL + // megdnn::opencl::algo_searcher::AlgoSearcherBase:: + // enable_weight_preprocess(); + // #endif + } + } +} + +template <> +void WeightPreprocessOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto&& graph_option = model->get_mdl_config().comp_graph->options(); + if (weight_preprocess) { + mgb_log_warn("enable weight-preprocess optimization"); + graph_option.graph_opt.enable_weight_preprocess(); + //! 
FIXME: this implemment is not right + // #if MGB_OPENCL + // megdnn::opencl::algo_searcher::AlgoSearcherBase:: + // enable_weight_preprocess(); + // #endif + } + } +} +} // namespace lar + +WeightPreprocessOption::WeightPreprocessOption() { + m_option_name = "weight_preprocess"; + weight_preprocess = FLAGS_weight_preprocess; +} + +bool WeightPreprocessOption::is_valid() { + bool ret = FLAGS_weight_preprocess; + return ret; +} + +std::shared_ptr WeightPreprocessOption::create_option() { + static std::shared_ptr option(new WeightPreprocessOption); + if (WeightPreprocessOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void WeightPreprocessOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} + +///// fuse conv bias and nonlinear activation opr optimize options //////// +namespace lar { +template <> +void FuseConvBiasNonlinearOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + LITE_MARK_USED_VAR(model); + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + if (enable_fuse_conv_bias_nonlinearity) { + LITE_THROW("fuse conv+bias+nonlinearity not supported in lite model"); + } + } +} + +template <> +void FuseConvBiasNonlinearOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto&& graph_option = model->get_mdl_config().comp_graph->options(); + if (enable_fuse_conv_bias_nonlinearity) { + mgb_log_warn("enable fuse conv+bias+nonlinearity optimization"); + graph_option.graph_opt.enable_fuse_conv_bias_nonlinearity(); + } + } +} +} // namespace lar + +FuseConvBiasNonlinearOption::FuseConvBiasNonlinearOption() { + m_option_name = "fuse_conv_bias_nonlinear"; + enable_fuse_conv_bias_nonlinearity = FLAGS_enable_fuse_conv_bias_nonlinearity; +} + +bool FuseConvBiasNonlinearOption::is_valid() { + bool ret = FLAGS_enable_fuse_conv_bias_nonlinearity; + return ret; +} + +std::shared_ptr FuseConvBiasNonlinearOption::create_option() { + static std::shared_ptr option( + new FuseConvBiasNonlinearOption); + if (FuseConvBiasNonlinearOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void FuseConvBiasNonlinearOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} + +///////////////////////// fuse and preprocess optimize options /////////////// +namespace lar { +template <> +void FuseConvBiasElemwiseAddOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + LITE_MARK_USED_VAR(model); + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + if (enable_fuse_conv_bias_with_z) { + LITE_THROW( + "fuse conv+bias+z optimization not supported in lite " + "model"); + } + } +} + +template <> +void FuseConvBiasElemwiseAddOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto&& graph_option = model->get_mdl_config().comp_graph->options(); + if (enable_fuse_conv_bias_with_z) { + mgb_log_warn("enable fuse conv+bias+z optimization"); + graph_option.graph_opt.enable_fuse_conv_bias_with_z(); + } + } +} +} // namespace lar + +FuseConvBiasElemwiseAddOption::FuseConvBiasElemwiseAddOption() { + m_option_name = "fuse_conv_bias_z"; + enable_fuse_conv_bias_with_z = FLAGS_enable_fuse_conv_bias_with_z; +} + +bool FuseConvBiasElemwiseAddOption::is_valid() { + bool ret = 
FLAGS_enable_fuse_conv_bias_with_z; + return ret; +} + +std::shared_ptr FuseConvBiasElemwiseAddOption::create_option() { + static std::shared_ptr option( + new FuseConvBiasElemwiseAddOption); + if (FuseConvBiasElemwiseAddOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void FuseConvBiasElemwiseAddOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} + +///////////////////////// graph retrict options ///////////////////////// +namespace lar { +template <> +void GraphRecordOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto&& config_option = model->get_config().options; + if (const_shape) { + LITE_WARN("enable const var shape"); + config_option.const_shape = true; + } + if (fake_first) { + LITE_WARN("enable fake-first optimization"); + config_option.fake_next_exec = true; + } + if (no_sanity_check) { + LITE_WARN("disable var sanity check optimization"); + config_option.var_sanity_check_first_run = false; + } + if (m_record_comp_seq == 1) { + LITE_WARN("set record_comp_seq_level to 1"); + } + if (m_record_comp_seq == 2) { + mgb_assert( + no_sanity_check, + "--no-sanity-check should be set before " + "--record-comp-seq2"); + LITE_WARN("set record_comp_seq_level to 2"); + } + config_option.comp_node_seq_record_level = m_record_comp_seq; + } +} + +template <> +void GraphRecordOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto&& graph_option = model->get_mdl_config().comp_graph->options(); + if (const_shape) { + mgb_log_warn("enable const var shape"); + model->get_mdl_config().const_var_shape = true; + } + if (fake_first) { + mgb_log_warn("enable fake-first optimization"); + graph_option.fake_next_exec = true; + } + if (no_sanity_check) { + mgb_log_warn("disable var sanity check optimization"); + graph_option.var_sanity_check_first_run = false; + } + if (m_record_comp_seq == 1) { + mgb_log_warn("set record_comp_seq_level to 1"); + } + if (m_record_comp_seq == 2) { + mgb_assert( + no_sanity_check && !fake_first, + "--no-sanity-check should be set before " + "--record-comp-seq2 and --fake-first should not be set"); + mgb_log_warn("set record_comp_seq_level to 2"); + } + graph_option.comp_node_seq_record_level = m_record_comp_seq; + } +} +} // namespace lar + +GraphRecordOption::GraphRecordOption() { + m_option_name = "graph_record"; + m_record_comp_seq = 0; + const_shape = FLAGS_const_shape; + fake_first = FLAGS_fake_first; + no_sanity_check = FLAGS_no_sanity_check; + if (FLAGS_record_comp_seq) { + m_record_comp_seq = 1; + } + if (FLAGS_record_comp_seq2) { + m_record_comp_seq = 2; + } +} + +bool GraphRecordOption::is_valid() { + bool ret = FLAGS_const_shape; + ret = ret || FLAGS_fake_first; + ret = ret || FLAGS_no_sanity_check; + ret = ret || FLAGS_record_comp_seq; + ret = ret || FLAGS_record_comp_seq2; + return ret; +} + +std::shared_ptr GraphRecordOption::create_option() { + static std::shared_ptr option(new GraphRecordOption); + if (GraphRecordOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void GraphRecordOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} +///////////////////////// graph retrict options ///////////////////////// +namespace lar { +template <> +void 
MemoryOptimizeOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + LITE_MARK_USED_VAR(model); + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + if (disable_mem_opt) { + LITE_THROW("lite model don't support disable memory optimization"); + } + } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { + if (workspace_limit != SIZE_MAX) { + LITE_WARN("set workspace limit to %ld", workspace_limit); + lite::Runtime::set_network_algo_workspace_limit( + model->get_lite_network(), workspace_limit); + } + } +} + +template <> +void MemoryOptimizeOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto&& graph_option = model->get_mdl_config().comp_graph->options(); + if (disable_mem_opt) { + mgb_log_warn("disable memory optimization"); + graph_option.seq_opt.enable_mem_plan_opt = false; + graph_option.seq_opt.enable_mem_reuse_alloc = false; + } + if (workspace_limit < SIZE_MAX) { + mgb_log_warn("set workspace limit to %ld", workspace_limit); + auto output_spec = model->get_output_spec(); + mgb::SymbolVarArray vars; + for (auto i : output_spec) { + vars.push_back(i.first); + } + mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, workspace_limit); + } + } +} +} // namespace lar + +MemoryOptimizeOption::MemoryOptimizeOption() { + m_option_name = "memory_optimize"; + disable_mem_opt = FLAGS_disable_mem_opt; + workspace_limit = FLAGS_workspace_limit; +} + +bool MemoryOptimizeOption::is_valid() { + bool ret = FLAGS_disable_mem_opt; + ret = ret || FLAGS_workspace_limit < SIZE_MAX; + return ret; +} + +std::shared_ptr MemoryOptimizeOption::create_option() { + static std::shared_ptr option(new MemoryOptimizeOption); + if (MemoryOptimizeOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void MemoryOptimizeOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} + +///////////////////////// other options for optimization ///////////////// +namespace lar { +template <> +void JITOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto&& config_option = model->get_config().options; + if (enable_jit) { + LITE_WARN("enable JIT (level 1)"); + config_option.jit_level = 1; + } + } +} + +template <> +void JITOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto&& graph_option = model->get_mdl_config().comp_graph->options(); + if (enable_jit) { + mgb_log_warn("enable JIT (level 1)"); + graph_option.graph_opt.jit = 1; + } + } +} +} // namespace lar +JITOption::JITOption() { + m_option_name = "JIT"; + enable_jit = FLAGS_enable_jit; +} + +bool JITOption::is_valid() { + bool ret = FLAGS_enable_jit; + return ret; +} + +std::shared_ptr JITOption::create_option() { + static std::shared_ptr option(new JITOption); + if (JITOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void JITOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} +///////////////////////// other options for optimization ///////////////// +#if MGB_ENABLE_TENSOR_RT +namespace lar { +template <> +void TensorRTOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == 
RunStage::BEFORE_MODEL_LOAD) { + if (!tensorrt_cache.empty()) { + LITE_WARN("set tensorrt cache as %s", tensorrt_cache.c_str()); + lite::set_tensor_rt_cache(tensorrt_cache); + } + } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { + if (enable_tensorrt) { + LITE_WARN("enable TensorRT"); + lite::Runtime::use_tensorrt(model->get_lite_network()); + } + } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { + if (!tensorrt_cache.empty()) { + lite::dump_tensor_rt_cache(); + } + } +} + +template <> +void TensorRTOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto&& graph_option = model->get_mdl_config().comp_graph->options(); + if (enable_tensorrt) { + mgb_log_warn("using tensorRT"); + graph_option.graph_opt.tensorrt = true; + } + if (!tensorrt_cache.empty()) { + mgb_log_warn("use tensorrt cache: %s", tensorrt_cache.c_str()); + mgb::TensorRTEngineCache::enable_engine_cache(true); + mgb::TensorRTEngineCache::set_impl( + std::make_shared( + tensorrt_cache.c_str())); + } + } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { + if (!tensorrt_cache.empty()) { + if (mgb::TensorRTEngineCache::enable_engine_cache()) { + mgb::TensorRTEngineCache::inst().dump_cache(); + } + } + } +} +} // namespace lar + +TensorRTOption::TensorRTOption() { + m_option_name = "tensorRT"; + enable_tensorrt = FLAGS_tensorrt; + tensorrt_cache = FLAGS_tensorrt_cache; +} + +bool TensorRTOption::is_valid() { + bool ret = FLAGS_tensorrt; + ret = ret || !FLAGS_tensorrt_cache.empty(); + return ret; +} + +std::shared_ptr TensorRTOption::create_option() { + static std::shared_ptr option(new TensorRTOption); + if (TensorRTOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void TensorRTOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} +#endif +///////////////////////// fuse and preprocess optimize options /////////////// +DEFINE_bool( + enable_fuse_preprocess, false, + "Fusion astype | pad_channel | dimshuffle and etc opr from h2d opr"); +DEFINE_bool( + weight_preprocess, false, + "Execute operators with weight preprocess, which can optimize the " + "operator execution time with algo of winograd, im2col ,etc., but " + "it may consume more memory."); +DEFINE_bool( + enable_fuse_conv_bias_nonlinearity, false, + "whether to fuse conv+bias+nonlinearity"); +DEFINE_bool( + enable_fuse_conv_bias_with_z, false, + "fuse conv,bias (elemwise add),z(elemwise add) into one opr " + "(only support on GPU)"); + +///////////////////////// graph retrict options ///////////////////////// +DEFINE_bool( + const_shape, false, + "set const_var_shape to reduce memory usage, since some static " + "inference data structures can be omitted"); +DEFINE_bool( + fake_first, false, + "Enable fake exec for the first run. In fake exec mode, some " + "initialization job would be done, but no actual computing is " + "performed."); +DEFINE_bool(no_sanity_check, false, "Disable var sanity check on the first run"); +DEFINE_bool( + record_comp_seq, false, + "Record the computing sequence, in level 1 . 
It reduces overhead of API" + "calls of some asynchronous computing devices"); +DEFINE_bool( + record_comp_seq2, false, + "Record the computing sequence, in level 2, the computing graph can be" + "destructed to reduce memory usage"); +DEFINE_bool(disable_mem_opt, false, "disable memory optimization!!"); +DEFINE_uint64(workspace_limit, SIZE_MAX, "set workspace upbound limit"); + +///////////////////////// other options for optimization ///////////////// +DEFINE_bool( + enable_jit, false, + " Execute supported operators with JIT(now only support NVRTC). " + "Can only be used on Nvidia GPUs"); +#if MGB_ENABLE_ANDROID_NN +DEFINE_bool( + android_nn, false, + "Execute supported operators with Android NN. Can only be used " + "with --cpu."); +#endif +#if MGB_ENABLE_TENSOR_RT +DEFINE_bool( + tensorrt, false, + " Execute supported operators with TensorRT. Can only be used on " + "Nvidia GPUs,i.e. comp node is xpu or gpu."); +DEFINE_string( + tensorrt_cache, "", + "Set the TensorRT engine cache path for serialized prebuilt " + "ICudaEngine"); +#endif +REGIST_OPTION_CREATOR(fuse_preprocess, lar::FusePreprocessOption::create_option); +REGIST_OPTION_CREATOR(weight_preprocess, lar::WeightPreprocessOption::create_option); +REGIST_OPTION_CREATOR( + fuse_conv_bias_nonlinear, lar::FuseConvBiasNonlinearOption::create_option); +REGIST_OPTION_CREATOR( + fuse_conv_bias_z, lar::FuseConvBiasElemwiseAddOption::create_option); +REGIST_OPTION_CREATOR(graph_record, lar::GraphRecordOption::create_option); +REGIST_OPTION_CREATOR(memory_optimize, lar::MemoryOptimizeOption::create_option); +REGIST_OPTION_CREATOR(JIT, lar::JITOption::create_option); +#if MGB_ENABLE_TENSOR_RT +REGIST_OPTION_CREATOR(tensorRT, lar::TensorRTOption::create_option); +#endif \ No newline at end of file diff --git a/lite/load_and_run/src/options/optimize_options.h b/lite/load_and_run/src/options/optimize_options.h new file mode 100644 index 000000000..f35574c52 --- /dev/null +++ b/lite/load_and_run/src/options/optimize_options.h @@ -0,0 +1,207 @@ +/** + * \file lite/load_and_run/src/options/optimize_options.h + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. 
+ */ + +#pragma once +#include +#include "helpers/common.h" +#include "models/model.h" +#include "option_base.h" + +DECLARE_bool(enable_fuse_preprocess); +DECLARE_bool(weight_preprocess); +DECLARE_bool(enable_fuse_conv_bias_nonlinearity); +DECLARE_bool(enable_fuse_conv_bias_with_z); + +DECLARE_bool(const_shape); +DECLARE_bool(fake_first); +DECLARE_bool(no_sanity_check); +DECLARE_bool(record_comp_seq); +DECLARE_bool(record_comp_seq2); +DECLARE_bool(disable_mem_opt); +DECLARE_uint64(workspace_limit); + +DECLARE_bool(enable_jit); +#if MGB_ENABLE_TENSOR_RT +DECLARE_bool(tensorrt); +DECLARE_string(tensorrt_cache); +#endif +namespace lar { +///////////////////////// fuse_preprocess optimize options ////////////// +class FusePreprocessOption final : public OptionBase { +public: + static bool is_valid(); + + static std::shared_ptr create_option(); + + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + + std::string option_name() const override { return m_option_name; }; + +private: + FusePreprocessOption(); + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + + std::string m_option_name; + bool enable_fuse_preprocess; +}; + +///////////////////////// weight preprocess optimize options ////////////// +class WeightPreprocessOption final : public OptionBase { +public: + static bool is_valid(); + + static std::shared_ptr create_option(); + + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + + std::string option_name() const override { return m_option_name; }; + +private: + WeightPreprocessOption(); + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + + std::string m_option_name; + bool weight_preprocess; +}; + +/////////////// fuse_conv_bias_nonlinearity optimize options /////////////// +class FuseConvBiasNonlinearOption final : public OptionBase { +public: + static bool is_valid(); + + static std::shared_ptr create_option(); + + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + + std::string option_name() const override { return m_option_name; }; + +private: + FuseConvBiasNonlinearOption(); + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + + std::string m_option_name; + bool enable_fuse_conv_bias_nonlinearity; +}; + +///////////////////////// fuse_conv_bias_with_z optimize options ////////////// +class FuseConvBiasElemwiseAddOption final : public OptionBase { +public: + static bool is_valid(); + + static std::shared_ptr create_option(); + + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + + std::string option_name() const override { return m_option_name; }; + +private: + FuseConvBiasElemwiseAddOption(); + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + std::string m_option_name; + bool enable_fuse_conv_bias_with_z; +}; + +///////////////////////// graph record options /////////////////////////// +class GraphRecordOption final : public OptionBase { +public: + static bool is_valid(); + + static std::shared_ptr create_option(); + + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + + std::string option_name() const override { return m_option_name; }; + +private: + GraphRecordOption(); + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + + std::string m_option_name; + size_t m_record_comp_seq; + bool const_shape; + bool fake_first; + bool no_sanity_check; +}; + +///////////////////////// memory optimize 
options ///////////////////////// +class MemoryOptimizeOption final : public OptionBase { +public: + static bool is_valid(); + + static std::shared_ptr create_option(); + + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + + std::string option_name() const override { return m_option_name; }; + +private: + MemoryOptimizeOption(); + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + + std::string m_option_name; + bool disable_mem_opt; + uint64_t workspace_limit; +}; + +///////////////////////// other options for optimization ///////////////// +class JITOption final : public OptionBase { +public: + static bool is_valid(); + + static std::shared_ptr create_option(); + + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + + std::string option_name() const override { return m_option_name; }; + +private: + JITOption(); + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + + std::string m_option_name; + bool enable_jit; +}; +///////////////////////// TensorRT options for optimization ///////////////// +#if MGB_ENABLE_TENSOR_RT +class TensorRTOption final : public OptionBase { +public: + static bool is_valid(); + + static std::shared_ptr create_option(); + + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + + std::string option_name() const override { return m_option_name; }; + +private: + TensorRTOption(); + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + + std::string m_option_name; + bool enable_tensorrt; + std::string tensorrt_cache; +}; +#endif +} // namespace lar \ No newline at end of file diff --git a/lite/load_and_run/src/options/option_base.h b/lite/load_and_run/src/options/option_base.h new file mode 100644 index 000000000..ccc7363a8 --- /dev/null +++ b/lite/load_and_run/src/options/option_base.h @@ -0,0 +1,87 @@ +/** + * \file lite/load_and_run/src/options/option_base.h + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include "megbrain/common.h" + +#include "helpers/common.h" +#include "models/model.h" + +namespace lar { +/*! + * \brief: base class of options + */ +class OptionBase { +public: + //! configure model in different runtime state + virtual void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) = 0; + //! get depend options + virtual std::vector depend_option() const { return {}; }; + + //! get option name + virtual std::string option_name() const = 0; + + virtual ~OptionBase() = default; +}; + +/*! + * \brief: Singleton option factory for register options before main function + */ +class OptionFactory { +public: + using OptionCreator = std::function()>; + using OptionMap = std::unordered_map; + + //! get Singleton option factory + static OptionFactory& get_Instance() { + static OptionFactory instance; + return instance; + } + + //! registe option creator into option map + void registe_options(std::string name, OptionCreator creator) { + if (option_creator_map.count(name) == 0) { + option_creator_map[name] = creator; + } + } + + //! 
get creator map + OptionMap* get_option_creator_map() { return &option_creator_map; } + +private: + OptionFactory(){}; + OptionMap option_creator_map; +}; + +} // namespace lar + +#define REGIST_OPTION_CREATOR(name_, creator_) \ + struct OptionRegister_##name_ { \ + OptionRegister_##name_() { \ + lar::OptionFactory::get_Instance().registe_options(#name_, creator_); \ + } \ + }; \ + OptionRegister_##name_ name_; + +#define CONFIG_MODEL_FUN \ + if (model->type() == ModelType::LITE_MODEL) { \ + config_model_internel( \ + runtime_param, std::static_pointer_cast(model)); \ + } else if (model->type() == ModelType::MEGDL_MODEL) { \ + config_model_internel( \ + runtime_param, std::static_pointer_cast(model)); \ + } +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} \ No newline at end of file diff --git a/lite/load_and_run/src/options/plugin_options.cpp b/lite/load_and_run/src/options/plugin_options.cpp new file mode 100644 index 000000000..c05eafb96 --- /dev/null +++ b/lite/load_and_run/src/options/plugin_options.cpp @@ -0,0 +1,401 @@ +/** + * \file lite/load_and_run/src/options/plugin_options.cpp + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. + */ + +#include "plugin_options.h" +#include "misc.h" +#include "models/model_lite.h" +#include "models/model_mdl.h" + +///////////////////// Plugin options/////////////////////////// +namespace lar { + +template <> +void PluginOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + LITE_ASSERT(range == 0, "lite model don't support NumRangeChecker plugin"); + LITE_ASSERT( + !enable_check_dispatch, + "lite model don't support CPUDispatchChecker plugin"); + LITE_ASSERT( + var_value_check_str.empty(), + "lite model don't support VarValueChecker plugin"); + } +#if MGB_ENABLE_JSON + else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { + if (!profile_path.empty()) { + if (!enable_profile_host) { + LITE_WARN("enable profiling"); + model->get_lite_network()->enable_profile_performance(profile_path); + } else { + LITE_WARN("enable profiling for host"); + model->get_lite_network()->enable_profile_performance(profile_path); + } + } + } +#endif +} + +template <> +void PluginOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto config = model->get_mdl_config(); + if (range > 0) { + mgb_log_warn("enable number range check"); + model->set_num_range_checker(float(range)); + } + + if (enable_check_dispatch) { + mgb_log_warn("enable cpu dispatch check"); + cpu_dispatch_checker = + std::make_unique(config.comp_graph.get()); + } + + if (!var_value_check_str.empty()) { + mgb_log_warn("enable variable value check"); + size_t init_idx = 0, switch_interval; + auto sep = var_value_check_str.find(':'); + if (sep != std::string::npos) { + switch_interval = std::stoul(var_value_check_str.substr(0, sep)); + init_idx = std::stoul(var_value_check_str.substr(sep + 1)); + } else { + switch_interval = std::stoul(var_value_check_str); + } + var_value_checker = std::make_unique( + config.comp_graph.get(), switch_interval, init_idx); + } + +#if MGB_ENABLE_JSON + + if (!profile_path.empty()) { + if (!enable_profile_host) { + mgb_log_warn("enable profiling"); + } else { + mgb_log_warn("enable profiling for host"); + } + model->set_profiler(); + } +#endif + } + + 
else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { +#if MGB_ENABLE_JSON + if (!profile_path.empty()) { + mgb_log_warn("filename %s", profile_path.c_str()); + if (model->get_profiler()) { + model->get_profiler() + ->to_json_full(model->get_async_func().get()) + ->writeto_fpath(profile_path); + mgb_log_warn("profiling result written to %s", profile_path.c_str()); + } + } +#endif + } +} + +} // namespace lar + +using namespace lar; +PluginOption::PluginOption() { + m_option_name = "plugin"; + range = FLAGS_range; + enable_check_dispatch = FLAGS_check_dispatch; + var_value_check_str = FLAGS_check_var_value; +#if MGB_ENABLE_JSON + enable_profile_host = false; + if (!FLAGS_profile.empty()) { + profile_path = FLAGS_profile; + } + if (!FLAGS_profile_host.empty()) { + enable_profile_host = !FLAGS_profile_host.empty(); + profile_path = FLAGS_profile_host; + } +#endif +} + +bool PluginOption::is_valid() { + bool ret = FLAGS_check_dispatch; + ret = ret || FLAGS_range > 0; + ret = ret || !FLAGS_check_var_value.empty(); +#if MGB_ENABLE_JSON + ret = ret || !FLAGS_profile.empty(); + ret = ret || !FLAGS_profile_host.empty(); +#endif + return ret; +} + +std::shared_ptr PluginOption::create_option() { + static std::shared_ptr option(new PluginOption); + if (PluginOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void PluginOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} + +///////////////////// Debug options/////////////////////////// +namespace lar { +template <> +void DebugOption::format_and_print( + const std::string& tablename, std::shared_ptr model) { + auto table = mgb::TextTable(tablename); + auto network = model->get_lite_network(); + table.padding(1); + table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor(); + + auto to_string = [&](lite::Layout& layout) { + std::string shape("{"); + for (size_t i = 0; i < layout.ndim; i++) { + if (i) + shape.append(","); + shape.append(std::to_string(layout.shapes[i])); + } + shape.append("}"); + return shape; + }; + + auto input_name = network->get_all_input_name(); + for (auto& i : input_name) { + auto layout = network->get_io_tensor(i)->get_layout(); + table.align(mgb::TextTable::Align::Mid) + .add("INPUT") + .add(i) + .add(to_string(layout)) + .eor(); + } + + auto output_name = network->get_all_output_name(); + for (auto& i : output_name) { + auto layout = network->get_io_tensor(i)->get_layout(); + table.align(mgb::TextTable::Align::Mid) + .add("OUTPUT") + .add(i) + .add(to_string(layout)) + .eor(); + } + + std::stringstream ss; + ss << table; + printf("%s\n\n", ss.str().c_str()); +} + +template <> +void DebugOption::format_and_print( + const std::string& tablename, std::shared_ptr model) { + auto table = mgb::TextTable(tablename); + table.padding(1); + table.align(mgb::TextTable::Align::Mid).add("type").add("name").add("shape").eor(); + + for (auto&& i : model->get_mdl_load_result().tensor_map) { + table.align(mgb::TextTable::Align::Mid) + .add("INPUT") + .add(i.first) + .add(i.second->shape().to_string()) + .eor(); + } + + for (auto&& i : model->get_mdl_load_result().output_var_list) { + table.align(mgb::TextTable::Align::Mid) + .add("OUTPUT") + .add(i.node()->name()) + .add(i.shape().to_string()) + .eor(); + } + + std::stringstream ss; + ss << table; + printf("%s\n\n", ss.str().c_str()); +} + +template <> +void DebugOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if 
(runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + LITE_ASSERT( + !disable_assert_throw, "lite model don't support disable assert throw"); +#ifndef __IN_TEE_ENV__ +#if MGB_ENABLE_JSON + LITE_ASSERT( + static_mem_log_dir_path.empty(), + "lite model don't support static memory information export"); +#endif +#endif + if (enable_verbose) { + LITE_WARN("enable verbose"); + lite::set_log_level(LiteLogLevel::DEBUG); + } + +#if __linux__ || __unix__ + if (enable_wait_gdb) { + printf("wait for gdb attach (pid=%d): ", getpid()); + getchar(); + } +#endif + } else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) { + if (enable_display_model_info) { + LITE_WARN("enable display model information"); + format_and_print("Runtime Model Info", model); + } + } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { + if (enable_display_model_info) { + format_and_print("Runtime Model Info", model); + } + } +} + +template <> +void DebugOption::config_model_internel( + RuntimeParam& runtime_param, std::shared_ptr model) { + if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { + auto config = model->get_mdl_config(); + if (enable_verbose) { + mgb_log_warn("enable verbose"); + mgb::set_log_level(mgb::LogLevel::DEBUG); + } + +#if __linux__ || __unix__ + if (enable_wait_gdb) { + printf("wait for gdb attach (pid=%d): ", getpid()); + getchar(); + } +#endif + } else if (runtime_param.stage == RunStage::AFTER_OUTSPEC_SET) { + if (enable_display_model_info) { + mgb_log_warn("enable display model information"); + format_and_print("Runtime Model Info", model); + } + + if (disable_assert_throw) { + mgb_log_warn("disable assert throw"); + auto on_opr = [](mgb::cg::OperatorNodeBase* opr) { + if (opr->same_type()) { + opr->cast_final().disable_throw_on_error(); + } + }; + mgb::cg::DepOprIter iter{on_opr}; + for (auto&& i : model->get_output_spec()) { + iter.add(i.first.node()->owner_opr()); + } + } + } else if (runtime_param.stage == RunStage::AFTER_OUTSPEC_SET) { + //! FIX:it don't work for cpu build (nothing dumped) + //! megbrain/sdk origin code will assert(m_recorded) in + //! 
EventImplHelper::finished(); + +#ifndef __IN_TEE_ENV__ +#if MGB_ENABLE_JSON + if (!static_mem_log_dir_path.empty()) { + mgb_log_warn("enable get static memeory information"); + model->get_async_func()->get_static_memory_alloc_info( + static_mem_log_dir_path); + } +#endif +#endif + } else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { + if (enable_display_model_info) { + format_and_print("Runtime Model Info", model); + } + } +} + +} // namespace lar + +DebugOption::DebugOption() { + m_option_name = "debug"; + enable_display_model_info = FLAGS_model_info; + enable_verbose = FLAGS_verbose; + disable_assert_throw = FLAGS_disable_assert_throw; +#if __linux__ || __unix__ + enable_wait_gdb = FLAGS_wait_gdb; +#endif +#ifndef __IN_TEE_ENV__ +#if MGB_ENABLE_JSON + static_mem_log_dir_path = FLAGS_get_static_mem_info; +#endif +#endif +} + +bool DebugOption::is_valid() { + bool ret = FLAGS_model_info; + ret = ret || FLAGS_verbose; + ret = ret || FLAGS_disable_assert_throw; + +#if __linux__ || __unix__ + ret = ret || FLAGS_wait_gdb; +#endif +#ifndef __IN_TEE_ENV__ +#if MGB_ENABLE_JSON + ret = ret || !FLAGS_get_static_mem_info.empty(); +#endif +#endif + return ret; +} + +std::shared_ptr DebugOption::create_option() { + static std::shared_ptr option(new DebugOption); + if (DebugOption::is_valid()) { + return std::static_pointer_cast(option); + } else { + return nullptr; + } +} + +void DebugOption::config_model( + RuntimeParam& runtime_param, std::shared_ptr model) { + CONFIG_MODEL_FUN; +} +///////////////////// Plugin gflags/////////////////////////// +DEFINE_double( + range, 0, + "check whether absolute value of all numbers in computing graph " + "is in the given range"); + +DEFINE_bool( + check_dispatch, false, + "check whether an operator call dispatch on cpu comp nodes"); + +DEFINE_string( + check_var_value, "", + "--check-var-value [interval]|[interval:init_idx], Enable " + "VarValueChecker plugin. Refer to its doc for more details"); +#if MGB_ENABLE_JSON +DEFINE_string( + profile, "", + "Write profiling result to given file. The output file is in " + "JSON format"); +DEFINE_string(profile_host, "", "focus on host time profiling For some backends"); +#endif + +///////////////////// Debug gflags/////////////////////////// +DEFINE_bool( + model_info, false, + " Format and display model input/output tensor inforamtion"); + +DEFINE_bool(verbose, false, "get more inforamtion for debug"); + +DEFINE_bool(disable_assert_throw, false, "disable assert throw on error check"); +#if __linux__ || __unix__ +DEFINE_bool(wait_gdb, false, "print current process PID and wait for gdb attach"); +#endif +#ifndef __IN_TEE_ENV__ +#if MGB_ENABLE_JSON +DEFINE_string( + get_static_mem_info, "", + "Record the static computing graph's static memory information"); +#endif +#endif +REGIST_OPTION_CREATOR(plugin, lar::PluginOption::create_option); + +REGIST_OPTION_CREATOR(debug, lar::DebugOption::create_option); \ No newline at end of file diff --git a/lite/load_and_run/src/options/plugin_options.h b/lite/load_and_run/src/options/plugin_options.h new file mode 100644 index 000000000..b8822d02f --- /dev/null +++ b/lite/load_and_run/src/options/plugin_options.h @@ -0,0 +1,105 @@ +/** + * \file lite/load_and_run/src/options/plugin_options.h + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. 
+ */ + +#pragma once +#include +#if __linux__ || __unix__ +#include +#endif +#include "megbrain/plugin/cpu_dispatch_checker.h" +#include "megbrain/plugin/var_value_checker.h" + +#include "helpers/common.h" +#include "helpers/text_table.h" +#include "models/model.h" + +#include "option_base.h" + +DECLARE_bool(check_dispatch); +DECLARE_double(range); +DECLARE_string(check_var_value); +#if MGB_ENABLE_JSON +DECLARE_string(profile); +DECLARE_string(profile_host); +#endif + +DECLARE_bool(model_info); +DECLARE_bool(verbose); +DECLARE_bool(disable_assert_throw); +#if __linux__ || __unix__ +DECLARE_bool(wait_gdb); +#endif +#ifndef __IN_TEE_ENV__ +#if MGB_ENABLE_JSON +DECLARE_string(get_static_mem_info); +#endif +#endif + +namespace lar { +class PluginOption final : public OptionBase { +public: + static bool is_valid(); + + static std::shared_ptr create_option(); + + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + + std::string option_name() const override { return m_option_name; }; + +private: + PluginOption(); + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + double range; + bool enable_check_dispatch; +#if MGB_ENABLE_JSON + bool enable_profile_host; + std::string profile_path; +#endif + + std::string var_value_check_str; + + std::string m_option_name; + + std::unique_ptr var_value_checker; + std::unique_ptr cpu_dispatch_checker; +}; + +class DebugOption final : public OptionBase { +public: + static bool is_valid(); + + static std::shared_ptr create_option(); + + void config_model( + RuntimeParam& runtime_param, std::shared_ptr model) override; + + std::string option_name() const override { return m_option_name; }; + +private: + DebugOption(); + template + void format_and_print(const std::string&, std::shared_ptr){}; + template + void config_model_internel(RuntimeParam&, std::shared_ptr){}; + bool enable_display_model_info; + bool enable_verbose; + bool disable_assert_throw; +#if __linux__ || __unix__ + bool enable_wait_gdb; +#endif +#ifndef __IN_TEE_ENV__ +#if MGB_ENABLE_JSON + std::string static_mem_log_dir_path; +#endif +#endif + std::string m_option_name; +}; +} // namespace lar \ No newline at end of file diff --git a/lite/load_and_run/src/options/strategy_options.cpp b/lite/load_and_run/src/options/strategy_options.cpp new file mode 100644 index 000000000..0e08e8851 --- /dev/null +++ b/lite/load_and_run/src/options/strategy_options.cpp @@ -0,0 +1,96 @@ +/** + * \file lite/load_and_run/src/options/strategy_options.cpp + * + * This file is part of MegEngine, a deep learning framework developed by + * Megvii. + * + * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved. 
+ */
+
+#include "strategy_options.h"
+#include "models/model_mdl.h"
+
+using namespace lar;
+
+DECLARE_bool(c_opr_lib_with_param);
+
+StrategyOption::StrategyOption() {
+    m_option_name = "run_strategy";
+    warmup_iter = FLAGS_warmup_iter;
+    run_iter = FLAGS_iter;
+    threads = FLAGS_thread;
+}
+
+std::shared_ptr<OptionBase> StrategyOption::create_option() {
+    static std::shared_ptr<StrategyOption> option(new StrategyOption);
+    return std::static_pointer_cast<OptionBase>(option);
+}
+
+void StrategyOption::config_model(
+        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
+    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
+        model->set_shared_mem(FLAGS_share_param_mem);
+        runtime_param.warmup_iter = warmup_iter;
+        runtime_param.run_iter = run_iter;
+        runtime_param.threads = threads;
+        runtime_param.testcase_num = 1;
+    } else if (runtime_param.stage == RunStage::BEFORE_OUTSPEC_SET) {
+        if (model->type() == ModelType::MEGDL_MODEL) {
+            auto model_ptr = std::static_pointer_cast<ModelMdl>(model);
+            auto num = model_ptr->get_testcase_num();
+            if (num != 0)
+                runtime_param.testcase_num = num;
+
+            model_ptr->make_output_spec();
+        }
+    }
+}
+
+TestcaseOption::TestcaseOption() {
+    m_option_name = "run_testcase";
+}
+
+std::shared_ptr<OptionBase> TestcaseOption::create_option() {
+    static std::shared_ptr<TestcaseOption> option(new TestcaseOption);
+    return std::static_pointer_cast<OptionBase>(option);
+}
+
+void TestcaseOption::config_model(
+        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
+    if (model->type() == ModelType::MEGDL_MODEL) {
+        auto model_ptr = std::static_pointer_cast<ModelMdl>(model);
+        if (model_ptr->get_testcase_num() && !FLAGS_c_opr_lib_with_param) {
+            if (runtime_param.stage == RunStage::MODEL_RUNNING) {
+                auto load_result = model_ptr->get_mdl_load_result();
+                auto input_tensor = model_ptr->get_test_input();
+                auto loader = model_ptr->reset_loader();
+                auto testcase = loader->load(model_ptr->get_mdl_config(), false);
+                mgb_assert(testcase.output_var_list.size() == input_tensor.size());
+                for (size_t i = 0; i < input_tensor.size(); ++i) {
+                    auto&& opr =
+                            testcase.output_var_list[i]
+                                    .node()
+                                    ->owner_opr()
+                                    ->cast_final_safe<mgb::opr::SharedDeviceTensor>();
+                    input_tensor[i].second->copy_from(
+                            mgb::HostTensorND::make_proxy(*opr.dev_data()));
+                }
+            }
+        }
+    }
+}
+
+DEFINE_int32(iter, 10, "number of iterations for running the model");
+
+DEFINE_int32(warmup_iter, 1, "number of warm-up iterations before timing the model");
+
+DEFINE_int32(
+        thread, 1,
+        "number of threads used to run the model when threading is supported "
+        "(NOTE: this is not the multithread device setting; it only applies to "
+        "load_and_run itself)");
+
+DEFINE_bool(share_param_mem, false, "load model from shared memory");
+
+REGIST_OPTION_CREATOR(run_strategy, lar::StrategyOption::create_option);
+
+REGIST_OPTION_CREATOR(run_testcase, lar::TestcaseOption::create_option);
\ No newline at end of file
diff --git a/lite/load_and_run/src/options/strategy_options.h b/lite/load_and_run/src/options/strategy_options.h
new file mode 100644
index 000000000..338bcb8d8
--- /dev/null
+++ b/lite/load_and_run/src/options/strategy_options.h
@@ -0,0 +1,68 @@
+/**
+ * \file lite/load_and_run/src/options/strategy_options.h
+ *
+ * This file is part of MegEngine, a deep learning framework developed by
+ * Megvii.
+ *
+ * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
+ */
+
+#include
+#include "models/model.h"
+#include "option_base.h"
+DECLARE_int32(iter);
+DECLARE_int32(warmup_iter);
+DECLARE_int32(thread);
+DECLARE_bool(share_param_mem);
+
+namespace lar {
+/*!
+ * \brief: strategy option for running model
+ */
+class StrategyOption final : public OptionBase {
+public:
+    //! create options when the option is used
+    static std::shared_ptr<OptionBase> create_option();
+
+    //! config the model, dispatch configuration for different model implementations
+
+    void config_model(
+            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
+
+    //! get option name
+    std::string option_name() const override { return m_option_name; };
+
+private:
+    //! Constructor
+    StrategyOption();
+
+    //! name of the option
+    std::string m_option_name;
+
+    size_t warmup_iter;  //! warm-up iterations before running the model
+    size_t run_iter;     //! iterations for running the model
+    size_t threads;      //! threads for running the model (NOTE: different from
+                         //! the multithread device setting)
+};
+
+class TestcaseOption final : public OptionBase {
+public:
+    //! create options when the option is used
+    static std::shared_ptr<OptionBase> create_option();
+
+    //! config the model, dispatch configuration for different model implementations
+
+    void config_model(
+            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
+
+    //! get option name
+    std::string option_name() const override { return m_option_name; };
+
+private:
+    //! Constructor
+    TestcaseOption();
+
+    //! name of the option
+    std::string m_option_name;
+};
+}  // namespace lar
\ No newline at end of file
diff --git a/lite/load_and_run/src/strategys/strategy.cpp b/lite/load_and_run/src/strategys/strategy.cpp
new file mode 100644
index 000000000..fadbe8d96
--- /dev/null
+++ b/lite/load_and_run/src/strategys/strategy.cpp
@@ -0,0 +1,24 @@
+
+/**
+ * \file lite/load_and_run/src/strategys/strategy.cpp
+ *
+ * This file is part of MegEngine, a deep learning framework developed by
+ * Megvii.
+ *
+ * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
+ */
+
+#include "strategy.h"
+#include
+
+using namespace lar;
+
+std::shared_ptr<StrategyBase> StrategyBase::create_strategy(std::string model_path) {
+    if (FLAGS_fitting) {
+        return std::make_shared<FittingStrategy>(model_path);
+    } else {
+        return std::make_shared<NormalStrategy>(model_path);
+    }
+}
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
\ No newline at end of file
diff --git a/lite/load_and_run/src/strategys/strategy.h b/lite/load_and_run/src/strategys/strategy.h
new file mode 100644
index 000000000..321a7cd7b
--- /dev/null
+++ b/lite/load_and_run/src/strategys/strategy.h
@@ -0,0 +1,63 @@
+/**
+ * \file lite/load_and_run/src/strategys/strategy.h
+ *
+ * This file is part of MegEngine, a deep learning framework developed by
+ * Megvii.
+ *
+ * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
+ */
+
+#pragma once
+#include
+#include
+#include
+#include "helpers/common.h"
+#include "models/model.h"
+#include "options/option_base.h"
+
+DECLARE_bool(fitting);
+
+namespace lar {
+/*!
+ * \brief: load and run strategy base class
+ */
+class StrategyBase {
+public:
+    static std::shared_ptr<StrategyBase> create_strategy(std::string model_path);
+
+    virtual void run() = 0;
+
+    virtual ~StrategyBase() = default;
+
+    RuntimeParam m_runtime_param;
+    std::unordered_map<std::string, std::shared_ptr<OptionBase>> m_options;
+};
+
+/*!
+ * \brief: normal strategy for running
+ */
+class NormalStrategy : public StrategyBase {
+public:
+    NormalStrategy(std::string model_path);
+
+    //! run the model with the runtime parameters
+    void run() override;
+
+private:
+    //! run one model pipeline; used as the per-thread body for multi-thread runs
+    void run_subline();
+
+    std::string m_model_path;
+};
+
+/*!
+ * \brief: Fitting strategy for running
+ */
+class FittingStrategy : public StrategyBase {
+public:
+    FittingStrategy(std::string model_path);
+    void run() override;
+};
+}  // namespace lar
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/lite/load_and_run/src/strategys/strategy_fitting.cpp b/lite/load_and_run/src/strategys/strategy_fitting.cpp
new file mode 100644
index 000000000..ffa39884a
--- /dev/null
+++ b/lite/load_and_run/src/strategys/strategy_fitting.cpp
@@ -0,0 +1,24 @@
+/**
+ * \file lite/load_and_run/src/strategys/strategy_fitting.cpp
+ *
+ * This file is part of MegEngine, a deep learning framework developed by
+ * Megvii.
+ *
+ * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
+ */
+
+#include "strategy.h"
+using namespace lar;
+
+FittingStrategy::FittingStrategy(std::string) {
+    mgb_assert(false, "this version does not support the fitting strategy");
+}
+
+void FittingStrategy::run() {
+    mgb_assert(false, "this version does not support the fitting strategy");
+}
+
+DEFINE_bool(
+        fitting, false,
+        "whether to use the fitting strategy, which automatically profiles "
+        "the model and picks the best option set");
\ No newline at end of file
diff --git a/lite/load_and_run/src/strategys/strategy_normal.cpp b/lite/load_and_run/src/strategys/strategy_normal.cpp
new file mode 100644
index 000000000..923cae7ca
--- /dev/null
+++ b/lite/load_and_run/src/strategys/strategy_normal.cpp
@@ -0,0 +1,167 @@
+/**
+ * \file lite/load_and_run/src/strategys/strategy_normal.cpp
+ *
+ * This file is part of MegEngine, a deep learning framework developed by
+ * Megvii.
+ *
+ * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
+ */
+#include
+#include
+#include "megbrain/common.h"
+#include "megbrain/utils/timer.h"
+#include "megbrain/version.h"
+#include "megdnn/version.h"
+#include "misc.h"
+#include "strategy.h"
+
+using namespace lar;
+
+NormalStrategy::NormalStrategy(std::string model_path) {
+    mgb::set_log_level(mgb::LogLevel::WARN);
+    lite::set_log_level(LiteLogLevel::WARN);
+    m_model_path = model_path;
+    auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map();
+    mgb_log_debug("option map size: %lu", option_creator_map->size());
+    auto construct_option = [&](std::string name) -> void {
+        auto& creator = (*option_creator_map)[name];
+        auto option = creator();
+        if (option) {
+            m_options.insert({name, option});
+        }
+    };
+
+    for (auto& creator : *option_creator_map) {
+        auto name = creator.first;
+        if (m_options.count(name) == 0) {
+            construct_option(name);
+        }
+    }
+}
+
+void NormalStrategy::run_subline() {
+    auto model = ModelBase::create_model(m_model_path);
+    mgb_assert(model != nullptr, "create model failed!!");
+
+    auto stage_config_model = [&]() {
+        for (auto& option : m_options) {
+            option.second->config_model(m_runtime_param, model);
+        }
+    };
+    //! configure before model load
+    m_runtime_param.stage = RunStage::BEFORE_MODEL_LOAD;
+    stage_config_model();
+
+    mgb::RealTimer timer;
+    model->load_model();
+    printf("load model: %.3fms\n", timer.get_msecs_reset());
+
+    //! configure after model load
+    m_runtime_param.stage = RunStage::AFTER_MODEL_LOAD;
+    stage_config_model();
+
+    m_runtime_param.stage = RunStage::BEFORE_OUTSPEC_SET;
+    stage_config_model();
+
+    //! for options that need static memory information
+    m_runtime_param.stage = RunStage::AFTER_OUTSPEC_SET;
+    stage_config_model();
+
+    auto warm_up = [&]() {
+        auto warmup_num = m_runtime_param.warmup_iter;
+        for (size_t i = 0; i < warmup_num; i++) {
+            printf("=== prepare: %.3fms; going to warmup\n\n", timer.get_msecs_reset());
+            model->run_model();
+            model->wait();
+            printf("warm up %lu %.3fms\n", i, timer.get_msecs_reset());
+            m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT;
+            stage_config_model();
+        }
+    };
+
+    auto run_iter = [&](int idx) {
+        double time_sqrsum = 0, time_sum = 0,
+               min_time = std::numeric_limits<double>::max(), max_time = 0;
+        auto run_num = m_runtime_param.run_iter;
+        for (size_t i = 0; i < run_num; i++) {
+            timer.reset();
+            model->run_model();
+            auto exec_time = timer.get_msecs();
+            model->wait();
+            m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT;
+            stage_config_model();
+            auto cur = timer.get_msecs();
+            printf("iter %lu/%lu: %.3fms (exec=%.3fms)\n", i, run_num, cur, exec_time);
+            time_sum += cur;
+            time_sqrsum += cur * cur;
+            fflush(stdout);
+            min_time = std::min(min_time, cur);
+            max_time = std::max(max_time, cur);
+        }
+        printf("\n=== finished test #%d: time=%.3fms avg_time=%.3fms "
+               "sexec=%.3fms min=%.3fms max=%.3fms\n\n",
+               idx, time_sum, time_sum / run_num,
+               std::sqrt(
+                       (time_sqrsum * run_num - time_sum * time_sum) /
+                       (run_num * (run_num - 1))),
+               min_time, max_time);
+        return time_sum;
+    };
+
+    //! run the model, one pass per testcase
+    size_t iter_num = m_runtime_param.testcase_num;
+
+    double tot_time = 0;
+    for (size_t idx = 0; idx < iter_num; idx++) {
+        //! configure while the model is running
+        mgb_log_warn("run testcase: %zu ", idx);
+        m_runtime_param.stage = RunStage::MODEL_RUNNING;
+        stage_config_model();
+
+        if (!idx) {
+            warm_up();
+        }
+        tot_time += run_iter(idx);
+
+        m_runtime_param.stage = RunStage::AFTER_RUNNING_ITER;
+        stage_config_model();
+    }
+
+    printf("=== total time: %.3fms\n", tot_time);
+    //! execute after run
+    m_runtime_param.stage = RunStage::AFTER_MODEL_RUNNING;
+    stage_config_model();
+}
+
+void NormalStrategy::run() {
+    auto v0 = mgb::get_version();
+    auto v1 = megdnn::get_version();
+    printf("megbrain/lite/load_and_run:\nusing MegBrain "
+           "%d.%d.%d(%d) and MegDNN %d.%d.%d\n",
+           v0.major, v0.minor, v0.patch, v0.is_dev, v1.major, v1.minor, v1.patch);
+
+    size_t thread_num = m_runtime_param.threads;
+    auto run_sub = [&]() { run_subline(); };
+    if (thread_num == 1) {
+        run_sub();
+    } else if (thread_num > 1) {
+#if MGB_HAVE_THREAD
+        std::vector<std::thread> threads;
+
+        for (size_t i = 0; i < thread_num; ++i) {
+            threads.emplace_back(run_sub);
+        }
+        for (auto&& i : threads) {
+            i.join();
+        }
+#else
+        mgb_log_error(
+                "%zu threads requested, but load_and_run was compiled "
+                "without thread support.",
+                thread_num);
+#endif
+    } else {
+        mgb_assert(false, "--thread must be a positive number!");
+    }
+}
\ No newline at end of file
-- 
GitLab
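
Note on the registration pattern used by the options above: each option pairs a gflags flag with a private constructor, a static create_option() that returns nullptr when the option is unused, and a REGIST_OPTION_CREATOR invocation that places the creator in OptionFactory's map, from which NormalStrategy builds its option set and calls config_model() at every RunStage. The sketch below shows how a hypothetical new option would plug in; ExampleOption and the "example" flag are not part of this patch, and the exact OptionBase/ModelBase signatures are assumed from the headers above.

// Hypothetical sketch, not part of the patch: a minimal option that only
// reacts to one run stage, following the same pattern as the options above.
#include <cstdio>
#include <memory>
#include <string>

#include <gflags/gflags.h>

#include "models/model.h"
#include "option_base.h"

DEFINE_bool(example, false, "enable the hypothetical example option");

namespace lar {
class ExampleOption final : public OptionBase {
public:
    //! the option is only created when its flag is set
    static bool is_valid() { return FLAGS_example; }

    static std::shared_ptr<OptionBase> create_option() {
        static std::shared_ptr<ExampleOption> option(new ExampleOption);
        if (ExampleOption::is_valid()) {
            return std::static_pointer_cast<OptionBase>(option);
        } else {
            return nullptr;
        }
    }

    //! called once per RunStage from NormalStrategy::run_subline()
    void config_model(
            RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override {
        if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
            printf("example option: about to load a %s model\n",
                   model->type() == ModelType::LITE_MODEL ? "lite" : "mdl");
        }
    }

    std::string option_name() const override { return m_option_name; }

private:
    ExampleOption() : m_option_name("example") {}
    std::string m_option_name;
};
}  // namespace lar

//! drop the creator into OptionFactory's map so NormalStrategy can find it
REGIST_OPTION_CREATOR(example, lar::ExampleOption::create_option);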