提交 5b6ebeb5 编写于 作者: M Megvii Engine Team

fix(mgb): append json file for dump and ready for midout open source

GitOrigin-RevId: 71ae7f1f4aa7300fdf941b391926a5b06f513979
上级 a81abc1d
...@@ -31,3 +31,6 @@ ...@@ -31,3 +31,6 @@
[submodule "third_party/pybind11"] [submodule "third_party/pybind11"]
path = third_party/pybind11 path = third_party/pybind11
url = https://github.com/pybind/pybind11.git url = https://github.com/pybind/pybind11.git
[submodule "third_party/midout"]
path = third_party/midout
url = https://github.com/MegEngine/midout.git
...@@ -30,6 +30,8 @@ set (MGE_EXPORT_TARGETS MegEngine-targets) ...@@ -30,6 +30,8 @@ set (MGE_EXPORT_TARGETS MegEngine-targets)
option(MGE_WITH_JIT "Build MegEngine with JIT." ON) option(MGE_WITH_JIT "Build MegEngine with JIT." ON)
option(MGE_WITH_HALIDE "Build MegEngine with Halide JIT" ON) option(MGE_WITH_HALIDE "Build MegEngine with Halide JIT" ON)
option(MGE_WITH_MIDOUT_PROFILE "Build MegEngine with Midout profile." OFF)
option(MGE_WITH_MINIMUM_SIZE "Swith off MGE_ENABLE_RTTI、MGE_ENABLE_EXCEPTIONS、MGE_ENABLE_LOGGING and switch on MGE_INFERENCE_ONLY so that compile minimum load_and_run. Take effect only when MGE_BIN_REDUCE was set" OFF)
option(MGE_ARMV8_2_FEATURE_FP16 "Enable armv8.2-a+fp16 support" OFF) option(MGE_ARMV8_2_FEATURE_FP16 "Enable armv8.2-a+fp16 support" OFF)
option(MGE_ARMV8_2_FEATURE_DOTPROD "enable armv8.2-a+dotprod support" OFF) option(MGE_ARMV8_2_FEATURE_DOTPROD "enable armv8.2-a+dotprod support" OFF)
option(MGE_DISABLE_FLOAT16 "Disable MegEngine float16 support." OFF) option(MGE_DISABLE_FLOAT16 "Disable MegEngine float16 support." OFF)
...@@ -53,6 +55,26 @@ option(MGE_INFERENCE_ONLY "Build inference only library." OFF) ...@@ -53,6 +55,26 @@ option(MGE_INFERENCE_ONLY "Build inference only library." OFF)
option(MGE_WITH_MKLDNN "Enable Intel MKL_DNN support," ON) option(MGE_WITH_MKLDNN "Enable Intel MKL_DNN support," ON)
option(MGE_WITH_ROCM "Enable ROCM support" OFF) option(MGE_WITH_ROCM "Enable ROCM support" OFF)
if(NOT ${MGE_BIN_REDUCE} STREQUAL "")
message("build with BIN REDUCE")
if(MGE_WITH_MINIMUM_SIZE)
set(MGE_ENABLE_RTTI OFF)
set(MGE_ENABLE_LOGGING OFF)
set(MGE_ENABLE_EXCEPTIONS OFF)
set(MGE_INFERENCE_ONLY ON)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -include ${MGE_BIN_REDUCE}")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -include ${MGE_BIN_REDUCE}")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -flto=full")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -flto=full")
endif()
if(MGE_WITH_MIDOUT_PROFILE)
message("build with MIDOUT PROFILE")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMIDOUT_PROFILING")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DMIDOUT_PROFILING")
endif()
if (APPLE) if (APPLE)
set (BUILD_SHARED_LIBS OFF) set (BUILD_SHARED_LIBS OFF)
message("build static for xcode framework require") message("build static for xcode framework require")
...@@ -235,7 +257,7 @@ if(NOT MGE_ENABLE_RTTI) ...@@ -235,7 +257,7 @@ if(NOT MGE_ENABLE_RTTI)
endif() endif()
if(NOT MGE_ENABLE_EXCEPTIONS) if(NOT MGE_ENABLE_EXCEPTIONS)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exception") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exceptions")
endif() endif()
if(MGE_WITH_TEST) if(MGE_WITH_TEST)
...@@ -297,7 +319,7 @@ if(MGE_WITH_CUDA) ...@@ -297,7 +319,7 @@ if(MGE_WITH_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -fno-rtti") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -fno-rtti")
endif() endif()
if(NOT MGE_ENABLE_EXCEPTIONS) if(NOT MGE_ENABLE_EXCEPTIONS)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -fno-exception") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -fno-exceptions")
endif() endif()
if(NOT MGE_CUDA_GENCODE) if(NOT MGE_CUDA_GENCODE)
......
...@@ -36,6 +36,9 @@ if(NOT ${MGE_ARCH} STREQUAL "naive") ...@@ -36,6 +36,9 @@ if(NOT ${MGE_ARCH} STREQUAL "naive")
endif() endif()
endif() endif()
if(MGE_WITH_MIDOUT_PROFILE)
list(APPEND SOURCES ${PROJECT_SOURCE_DIR}/third_party/midout/src/midout.cpp)
endif()
############################################################################### ###############################################################################
# HIP_COMPILE # HIP_COMPILE
......
...@@ -25,6 +25,9 @@ if(MGE_WITH_CUDA) ...@@ -25,6 +25,9 @@ if(MGE_WITH_CUDA)
list(APPEND SOURCES ${CUSOURCES}) list(APPEND SOURCES ${CUSOURCES})
endif() endif()
if(MGE_WITH_MIDOUT_PROFILE)
list(APPEND SOURCES ${PROJECT_SOURCE_DIR}/third_party/midout/src/midout.cpp)
endif()
if(MGE_WITH_CAMBRICON) if(MGE_WITH_CAMBRICON)
file(GLOB_RECURSE SOURCES_ cambricon/*.cpp) file(GLOB_RECURSE SOURCES_ cambricon/*.cpp)
......
...@@ -119,6 +119,7 @@ function cmake_build() { ...@@ -119,6 +119,7 @@ function cmake_build() {
mkdir -p $BUILD_DIR mkdir -p $BUILD_DIR
mkdir -p $INSTALL_DIR mkdir -p $INSTALL_DIR
cd $BUILD_DIR cd $BUILD_DIR
unset IFS
cmake -G "$MAKEFILE_TYPE Makefiles" \ cmake -G "$MAKEFILE_TYPE Makefiles" \
-DCMAKE_TOOLCHAIN_FILE="$NDK_ROOT/build/cmake/android.toolchain.cmake" \ -DCMAKE_TOOLCHAIN_FILE="$NDK_ROOT/build/cmake/android.toolchain.cmake" \
-DANDROID_NDK="$NDK_ROOT" \ -DANDROID_NDK="$NDK_ROOT" \
......
...@@ -471,8 +471,11 @@ def main(): ...@@ -471,8 +471,11 @@ def main():
assert not testcase, 'extra inputs provided in testcase: {}'.format( assert not testcase, 'extra inputs provided in testcase: {}'.format(
testcase.keys() testcase.keys()
) )
mgb.serialize_comp_graph_to_file(args.output, output_mgbvars, append=True) mgb.serialize_comp_graph_to_file(
args.output,
output_mgbvars,
append=True,
output_strip_info=args.output_strip_info)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Subproject commit 3b8ae875a9e5c95031aca5edcc4233051d774eb5
...@@ -15,6 +15,7 @@ git submodule foreach --recursive git reset --hard ...@@ -15,6 +15,7 @@ git submodule foreach --recursive git reset --hard
git submodule foreach --recursive git clean -fd git submodule foreach --recursive git clean -fd
git submodule update --init midout
git submodule update --init intel-mkl-dnn git submodule update --init intel-mkl-dnn
git submodule update --init Halide git submodule update --init Halide
git submodule update --init protobuf git submodule update --init protobuf
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import sys
import re
if sys.version_info[0] != 3 or sys.version_info[1] < 5:
print('This script requires Python version 3.5')
sys.exit(1)
import argparse
import json
import os
import subprocess
import tempfile
from pathlib import Path
MIDOUT_TRACE_MAGIC = 'midout_trace v1\n'
class HeaderGen:
_dtypes = None
_oprs = None
_fout = None
_elemwise_modes = None
_has_netinfo = False
_midout_files = None
_file_without_hash = False
def __init__(self):
self._dtypes = set()
self._oprs = set()
self._elemwise_modes = set()
self._graph_hashes = set()
self._midout_files = []
_megvii3_root_cache = None
@classmethod
def get_megvii3_root(cls):
if cls._megvii3_root_cache is not None:
return cls._megvii3_root_cache
wd = Path(__file__).resolve().parent
while wd.parent != wd:
workspace_file = wd / 'WORKSPACE'
if workspace_file.is_file():
cls._megvii3_root_cache = str(wd)
return cls._megvii3_root_cache
wd = wd.parent
raise RuntimeError('This script is supposed to run in megvii3.')
def extend_netinfo(self, data):
self._has_netinfo = True
if 'hash' not in data:
self._file_without_hash = True
else:
self._graph_hashes.add(str(data['hash']))
for i in data['dtypes']:
self._dtypes.add(i)
for i in data['opr_types']:
self._oprs.add(i)
for i in data['elemwise_modes']:
self._elemwise_modes.add(i)
def extend_midout(self, fname):
self._midout_files.append(fname)
def generate(self, fout):
self._fout = fout
self._write_def('MGB_BINREDUCE_VERSION', '20190219')
if self._has_netinfo:
self._write_dtype()
self._write_elemwise_modes()
self._write_oprs()
self._write_hash()
self._write_midout()
del self._fout
def strip_opr_name_with_version(self, name):
pos = len(name)
t = re.search(r'V\d+$', name)
if t:
pos = t.start()
return name[:pos]
def _write_oprs(self):
defs = ['}', 'namespace opr {']
already_declare = set()
already_instance = set()
for i in self._oprs:
i = self.strip_opr_name_with_version(i)
if i in already_declare:
continue
else:
already_declare.add(i)
defs.append('class {};'.format(i))
defs.append('}')
defs.append('namespace serialization {')
defs.append("""
template<class Opr, class Callee>
struct OprRegistryCaller {
}; """)
for i in sorted(self._oprs):
i = self.strip_opr_name_with_version(i)
if i in already_instance:
continue
else:
already_instance.add(i)
defs.append("""
template<class Callee>
struct OprRegistryCaller<opr::{}, Callee>: public
OprRegistryCallerDefaultImpl<Callee> {{
}}; """.format(i))
self._write_def('MGB_OPR_REGISTRY_CALLER_SPECIALIZE', defs)
def _write_elemwise_modes(self):
with tempfile.NamedTemporaryFile() as ftmp:
fpath = os.path.realpath(ftmp.name)
subprocess.check_call(
['./brain/megbrain/dnn/scripts/gen_param_defs.py',
'--write-enum-items', 'Elemwise:Mode',
'./brain/megbrain/dnn/scripts/opr_param_defs.py',
fpath],
cwd=self.get_megvii3_root()
)
with open(fpath) as fin:
mode_list = [i.strip() for i in fin]
for i in mode_list:
if i in self._elemwise_modes:
content = '_cb({})'.format(i)
else:
content = ''
self._write_def(
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_{}(_cb)'.format(i), content)
self._write_def('MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb)',
'_MEGDNN_ELEMWISE_MODE_ENABLE_IMPL_##_mode(_cb)')
def _write_dtype(self):
if 'Float16' not in self._dtypes:
# MegBrain/MegDNN used MEGDNN_DISABLE_FLOT16 to turn off float16
# support in the past; however `FLOT16' is really a typo. We plan to
# change MEGDNN_DISABLE_FLOT16 to MEGDNN_DISABLE_FLOAT16 soon.
# To prevent issues in the transition, we decide to define both
# macros (`FLOT16' and `FLOAT16') here.
#
# In the future when the situation is settled and no one would ever
# use legacy MegBrain/MegDNN, the `FLOT16' macro definition can be
# safely deleted.
self._write_def('MEGDNN_DISABLE_FLOT16', 1)
self._write_def('MEGDNN_DISABLE_FLOAT16', 1)
def _write_hash(self):
if self._file_without_hash:
print('WARNING: network info has no graph hash. Using json file '
'generated by MegBrain >= 7.28.0 is recommended')
else:
defs = 'ULL,'.join(self._graph_hashes) + 'ULL'
self._write_def('MGB_BINREDUCE_GRAPH_HASHES', defs)
def _write_def(self, name, val):
if isinstance(val, list):
val = '\n'.join(val)
val = str(val).strip().replace('\n', ' \\\n')
self._fout.write('#define {} {}\n'.format(name, val))
def _write_midout(self):
if not self._midout_files:
return
gen = os.path.join(self.get_megvii3_root(), 'brain', 'midout',
'gen_header.py')
cvt = subprocess.run(
[gen] + self._midout_files,
stdout=subprocess.PIPE, check=True,
).stdout.decode('utf-8')
self._fout.write('// midout \n')
self._fout.write(cvt)
def main():
parser = argparse.ArgumentParser(
description='generate header file for reducing binary size by '
'stripping unused oprs in a particular network; output file would '
'be written to bin_reduce.h',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
'inputs', nargs='+',
help='input files that describe specific traits of the network; '
'can be one of the following:'
' 1. json files generated by '
'megbrain.serialize_comp_graph_to_file() in python; '
' 2. trace files generated by midout library')
parser.add_argument('-o', '--output', help='output file',
default=os.path.join(HeaderGen.get_megvii3_root(),
'utils', 'bin_reduce.h'))
args = parser.parse_args()
gen = HeaderGen()
for i in args.inputs:
print('==== processing {}'.format(i))
with open(i) as fin:
if fin.read(len(MIDOUT_TRACE_MAGIC)) == MIDOUT_TRACE_MAGIC:
gen.extend_midout(i)
else:
fin.seek(0)
gen.extend_netinfo(json.loads(fin.read()))
with open(args.output, 'w') as fout:
gen.generate(fout)
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册