From 69e3e3224042da0b4abe9100921613e6c06dad1f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 18 Dec 2020 10:12:14 +0800 Subject: [PATCH] feat(imperative): auto generated opdef header and python binding GitOrigin-RevId: d2f22ad5fe0b15f45afa1ea31af8874e8b18fef9 --- CMakeLists.txt | 14 +- dnn/scripts/gen_tablegen.py | 150 +++ imperative/CMakeLists.txt | 33 +- .../core/autodiff/builtin_op_utils.py | 33 +- .../megengine/core/ops/_internal/__init__.py | 8 - .../megengine/core/ops/_internal/all_ops.py | 10 - .../megengine/core/ops/_internal/enum36.py | 939 ------------------ .../megengine/core/ops/_internal/helper.py | 94 -- .../megengine/core/ops/_internal/misc_ops.py | 194 ---- .../megengine/core/ops/builtin/__init__.py | 15 +- .../megengine/core/tensor/tensor_wrapper.py | 7 +- .../python/megengine/core/tensor/utils.py | 2 +- .../megengine/distributed/functional.py | 46 +- .../python/megengine/functional/elemwise.py | 20 - .../python/megengine/functional/math.py | 16 +- imperative/python/megengine/functional/nn.py | 29 +- .../python/megengine/functional/quantized.py | 12 +- .../python/megengine/functional/tensor.py | 11 +- imperative/python/megengine/jit/tracing.py | 12 +- imperative/python/megengine/module/conv.py | 23 +- .../python/megengine/module/elemwise.py | 1 - .../python/megengine/module/qat/conv.py | 4 +- .../megengine/module/quantized/elemwise.py | 5 +- imperative/python/megengine/utils/profiler.py | 4 +- imperative/python/src/graph_rt.cpp | 4 +- imperative/python/src/imperative_rt.cpp | 3 - imperative/python/src/ops.cpp | 192 +--- .../python/test/unit/core/test_dtype_quant.py | 4 +- .../python/test/unit/core/test_indexing_op.py | 35 +- imperative/python/tools/gen_ops.py | 320 ------ imperative/python/tools/ops.tpl.py | 40 - imperative/src/impl/op_def.cpp | 10 +- imperative/src/impl/op_trait.cpp | 2 +- imperative/src/impl/op_trait.h | 43 +- imperative/src/impl/ops/autogen.cpp | 46 + imperative/src/impl/ops/batch_norm.cpp | 17 +- imperative/src/impl/ops/broadcast.cpp | 6 +- imperative/src/impl/ops/collective_comm.cpp | 35 +- imperative/src/impl/ops/cond_take.cpp | 5 +- imperative/src/impl/ops/elemwise.cpp | 8 +- imperative/src/impl/ops/io_remote.cpp | 42 +- imperative/src/impl/ops/nms.cpp | 4 +- imperative/src/impl/ops/specializations.cpp | 630 ++++++++++++ imperative/src/impl/ops/tensor_manip.cpp | 6 +- imperative/src/impl/profiler.cpp | 14 +- imperative/src/impl/proxy_graph.cpp | 6 +- .../src/include/megbrain/imperative/op_def.h | 23 +- .../imperative/ops/{cond_take.h => autogen.h} | 23 +- .../megbrain/imperative/ops/batch_norm.h | 70 -- .../megbrain/imperative/ops/broadcast.h | 35 - .../megbrain/imperative/ops/collective_comm.h | 69 -- .../megbrain/imperative/ops/elemwise.h | 42 - .../megbrain/imperative/ops/io_remote.h | 77 -- .../src/include/megbrain/imperative/ops/nms.h | 41 - .../megbrain/imperative/ops/tensor_manip.h | 99 -- imperative/src/test/backward_graph.cpp | 16 +- imperative/src/test/collective_comm.cpp | 11 +- imperative/src/test/cond_take.cpp | 2 +- imperative/src/test/helper.cpp | 2 +- imperative/src/test/io_remote.cpp | 23 +- imperative/tablegen/CMakeLists.txt | 14 + imperative/tablegen/autogen.cpp | 383 +++++++ imperative/tablegen/helper.h | 228 +++++ imperative/test/CMakeLists.txt | 2 +- src/core/include/megbrain/ir/base.td | 257 +++++ src/core/include/megbrain/ir/ops.td | 240 +++++ third_party/prepare.sh | 1 + 67 files changed, 2191 insertions(+), 2621 deletions(-) create mode 100755 dnn/scripts/gen_tablegen.py delete mode 100644 imperative/python/megengine/core/ops/_internal/__init__.py delete mode 100644 imperative/python/megengine/core/ops/_internal/all_ops.py delete mode 100644 imperative/python/megengine/core/ops/_internal/enum36.py delete mode 100644 imperative/python/megengine/core/ops/_internal/helper.py delete mode 100644 imperative/python/megengine/core/ops/_internal/misc_ops.py delete mode 100755 imperative/python/tools/gen_ops.py delete mode 100644 imperative/python/tools/ops.tpl.py create mode 100644 imperative/src/impl/ops/autogen.cpp create mode 100644 imperative/src/impl/ops/specializations.cpp rename imperative/src/include/megbrain/imperative/ops/{cond_take.h => autogen.h} (54%) delete mode 100644 imperative/src/include/megbrain/imperative/ops/batch_norm.h delete mode 100644 imperative/src/include/megbrain/imperative/ops/broadcast.h delete mode 100644 imperative/src/include/megbrain/imperative/ops/collective_comm.h delete mode 100644 imperative/src/include/megbrain/imperative/ops/elemwise.h delete mode 100644 imperative/src/include/megbrain/imperative/ops/io_remote.h delete mode 100644 imperative/src/include/megbrain/imperative/ops/nms.h delete mode 100644 imperative/src/include/megbrain/imperative/ops/tensor_manip.h create mode 100644 imperative/tablegen/CMakeLists.txt create mode 100644 imperative/tablegen/autogen.cpp create mode 100644 imperative/tablegen/helper.h create mode 100644 src/core/include/megbrain/ir/base.td create mode 100644 src/core/include/megbrain/ir/ops.td diff --git a/CMakeLists.txt b/CMakeLists.txt index 7bf64ba35..b9e7a9490 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -230,6 +230,10 @@ endif() # FIXME At present, there are some conflicts between the LLVM that halide # depends on and the LLVM that MLIR depends on. Should be fixed in subsequent # versions. +if(MGE_BUILD_IMPERATIVE_RT) + set(MGE_WITH_HALIDE OFF) + message(WARNING "cannot use HALIDE when building IMPERATIVE_RT") +endif() if(MGE_WITH_JIT_MLIR) if(MGE_WITH_HALIDE) message(FATAL_ERROR "please set MGE_WITH_HALIDE to OFF with MGE_WITH_JIT_MLIR enabled") @@ -310,7 +314,7 @@ if(MGE_INFERENCE_ONLY) set(MGE_BUILD_IMPERATIVE_RT OFF) endif() -if(MGE_WITH_JIT_MLIR) +if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT) include(cmake/llvm-project.cmake) endif() @@ -750,7 +754,7 @@ target_include_directories(mgb_opr_param_defs add_dependencies(mgb_opr_param_defs _mgb_opr_param_defs) install(TARGETS mgb_opr_param_defs EXPORT ${MGE_EXPORT_TARGETS}) -if(MGE_WITH_JIT_MLIR) +if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT) # generate param_defs.td set(MGE_GENFILE_DIR ${PROJECT_BINARY_DIR}/src/genfiles) set(MGE_GEN_IR_DIR ${PROJECT_BINARY_DIR}/src/core/include/megbrain/ir) @@ -800,12 +804,6 @@ if(TARGET _imperative_rt) COMMAND ${CMAKE_COMMAND} -E create_symlink ${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/$ ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/$ - COMMAND ${CMAKE_COMMAND} -E create_symlink - ${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/generated_ops.py - ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/generated_ops.py - COMMAND ${CMAKE_COMMAND} -E create_symlink - ${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/param_defs.py - ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/param_defs.py DEPENDS _imperative_rt VERBATIM ) diff --git a/dnn/scripts/gen_tablegen.py b/dnn/scripts/gen_tablegen.py new file mode 100755 index 000000000..751e75789 --- /dev/null +++ b/dnn/scripts/gen_tablegen.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import argparse +import collections +import textwrap +import os +import hashlib +import struct +import io + +from gen_param_defs import member_defs, ParamDef, IndentWriterBase + + +class ConverterWriter(IndentWriterBase): + _skip_current_param = False + _last_param = None + _current_tparams = None + _packed = None + _const = None + + def __call__(self, fout, defs): + super().__call__(fout) + self._write("// %s", self._get_header()) + self._write("#ifndef MGB_PARAM") + self._write("#define MGB_PARAM") + self._process(defs) + self._write("#endif // MGB_PARAM") + + def _ctype2attr(self, ctype, value): + if ctype == 'uint32_t': + return 'MgbUI32Attr', value + if ctype == 'uint64_t': + return 'MgbUI64Attr', value + if ctype == 'int32_t': + return 'MgbI32Attr', value + if ctype == 'float': + return 'MgbF32Attr', value + if ctype == 'double': + return 'MgbF64Attr', value + if ctype == 'bool': + return 'MgbBoolAttr', value + if ctype == 'DTypeEnum': + self._packed = False + return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value) + raise RuntimeError("unknown ctype") + + def _on_param_begin(self, p): + self._last_param = p + if p.is_legacy: + self._skip_current_param = True + return + self._packed = True + self._current_tparams = [] + self._const = set() + + def _on_param_end(self, p): + if self._skip_current_param: + self._skip_current_param = False + return + if self._packed: + self._write("class {0}ParamBase : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1) + else: + self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1) + self._write("let fields = (ins", indent=1) + self._write(",\n{}".format(self._cur_indent).join(self._current_tparams)) + self._write(");", indent=-1) + self._write("}\n", indent=-1) + if self._packed: + self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name)) + self._current_tparams = None + self._packed = None + self._const = None + + def _wrapped_with_default_value(self, attr, default): + return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default) + + def _on_member_enum(self, e): + p = self._last_param + + # Note: always generate llvm Record def for enum attribute even it was not + # directly used by any operator, or other enum couldn't alias to this enum + td_class = "{}{}".format(p.name, e.name) + fullname = "::megdnn::param::{}".format(p.name) + enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name) + def format(v): + return '\"{}\"'.format(str(v)) + enum_def += ','.join(format(i) for i in e.members) + enum_def += "]>" + self._write("def {} : {};".format(td_class, enum_def)) + + if self._skip_current_param: + return + + # wrapped with default value + default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.default) + wrapped = self._wrapped_with_default_value(td_class, default_val) + + self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) + + def _on_member_enum_alias(self, e): + p = self._last_param + if self._skip_current_param: + return + + # write enum attr def + td_class = "{}{}".format(p.name, e.name) + fullname = "::megdnn::param::{}".format(p.name) + base_td_class = "{}{}".format(e.src_class, e.src_name) + enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class) + self._write("def {} : {};".format(td_class, enum_def)) + + # wrapped with default value + default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.get_default()) + wrapped = self._wrapped_with_default_value(td_class, default_val) + + self._current_tparams.append("{}:${}".format(wrapped, e.name_field)) + + + def _on_member_field(self, f): + if self._skip_current_param: + return + attr, value = self._ctype2attr(f.dtype.cname, str(f.default)) + if str(value) in self._const: + value = '::megdnn::param::{}::{}'.format(self._last_param.name, value) + wrapped = self._wrapped_with_default_value(attr, value) + self._current_tparams.append("{}:${}".format(wrapped, f.name)) + + def _on_const_field(self, f): + self._const.add(str(f.name)) + +def main(): + parser = argparse.ArgumentParser('generate op param tablegen file') + parser.add_argument('input') + parser.add_argument('output') + args = parser.parse_args() + + with open(args.input) as fin: + inputs = fin.read() + exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) + input_hash = hashlib.sha256() + input_hash.update(inputs.encode(encoding='UTF-8')) + input_hash = input_hash.hexdigest() + + writer = ConverterWriter() + with open(args.output, 'w') as fout: + writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) + +if __name__ == "__main__": + main() diff --git a/imperative/CMakeLists.txt b/imperative/CMakeLists.txt index 0e8859ceb..fac4c060d 100644 --- a/imperative/CMakeLists.txt +++ b/imperative/CMakeLists.txt @@ -8,9 +8,7 @@ file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/sr set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") -file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl") file(GLOB_RECURSE PYTHON_SRCS python/${PACKAGE_NAME}/*.py) -list(REMOVE_ITEM PYTHON_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/generated_ops.py ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/param_defs.py) file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h ${PROJECT_SOURCE_DIR}/src/core/include/* ${PROJECT_SOURCE_DIR}/src/opr/include/* @@ -19,33 +17,8 @@ file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h ${PROJECT_SOURCE_DIR}/dnn/include/*) set(MEGENGINE_DIR ${CMAKE_CURRENT_BINARY_DIR}/python/) -set(GEN_OPS_DIR ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal) -file(MAKE_DIRECTORY ${GEN_OPS_DIR}) -set(GEN_OPS_FILE ${GEN_OPS_DIR}/generated_ops.py) -set(GEN_OP_PARAMS_FILE ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal/param_defs.py) -set(GEN_OP_PARAMS_TEMPLATE ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/ops.tpl.py) -##################### generate python opr_param_defs.py ############## - -file(COPY ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR}) -file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS) -file(APPEND ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${CONTENTS}) - -add_custom_command( - OUTPUT ${GEN_OPS_FILE} - COMMAND ${CMAKE_COMMAND} -E touch ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE} - COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/${PACKAGE_NAME} ${MEGENGINE_DIR}/${PACKAGE_NAME} - COMMAND ${CMAKE_COMMAND} -E remove -f ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE} - COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/gen_ops.py ${OPR_DECL_SRCS} -o ${GEN_OPS_FILE} - COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/test ${MEGENGINE_DIR}/${PACKAGE_NAME}/test - COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_param_defs.py -t py --imperative ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${GEN_OP_PARAMS_FILE} - DEPENDS ${OPR_DECL_SRCS} ${PYTHON_SRCS} ${ALL_HEADERS} ${GEN_OP_PARAMS_TEMPLATE} - VERBATIM -) - -add_custom_target(gen_opr_py DEPENDS ${GEN_OPS_FILE}) - -##################### end of opdef generation ######################### +add_subdirectory(tablegen) add_custom_target(_version_ld SOURCES ${MGE_VERSION_SCRIPT}) @@ -73,7 +46,7 @@ else() endif() endif() -target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) +target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}) target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) if(CXX_SUPPORT_WCLASS_MEMACCESS) @@ -87,7 +60,7 @@ if (APPLE OR MSVC OR WIN32) message(VERBOSE "overwriting SUFFIX at macos and windows before config by set_target_properties") pybind11_extension(${MODULE_NAME}) endif() -add_dependencies(${MODULE_NAME} gen_opr_py _version_ld) +add_dependencies(${MODULE_NAME} mgb_opdef _version_ld) if(MGE_WITH_TEST AND MGE_ENABLE_RTTI) add_subdirectory(test) diff --git a/imperative/python/megengine/core/autodiff/builtin_op_utils.py b/imperative/python/megengine/core/autodiff/builtin_op_utils.py index 07177e26e..5071e2b34 100644 --- a/imperative/python/megengine/core/autodiff/builtin_op_utils.py +++ b/imperative/python/megengine/core/autodiff/builtin_op_utils.py @@ -19,7 +19,6 @@ from ..ops.builtin import ( IndexingMultiAxisVec, IndexingSetMultiAxisVec, OpDef, - OprAttr, Reduce, Reshape, SetSubtensor, @@ -31,8 +30,6 @@ from ..tensor.function import Function from ..tensor.tensor import Tensor from ..tensor.tensor_wrapper import TensorWrapper -_reduce_sum_param = Reduce(mode="SUM").to_c().param[0] - @functools.singledispatch def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad): @@ -41,17 +38,18 @@ def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad): @builtin_op_get_backward_fn.register(OpDef) def _(op: OpDef, inputs, outputs, input_requires_grad): - if isinstance(op, OprAttr): - grad_fn = _oprAttr_grad_fn.get(op.type, None) - if grad_fn is None: - if op.type == Reduce.name and op.param[0] == _reduce_sum_param: - grad_fn = reduce_sum_grad_fn - else: - grad_fn = default_grad_fn + if isinstance(op, Reshape): + grad_fn = reshape_grad_fn + elif isinstance(op, Subtensor): + grad_fn = subtensor_grad_fn + elif isinstance(op, IndexingMultiAxisVec): + grad_fn = indexingMultiAxisVec_grad_fn elif isinstance(op, Broadcast) or ( isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD ): grad_fn = elemwise_add_grad_fn + elif isinstance(op, Reduce) and op.mode.name == "SUM": + grad_fn = reduce_sum_grad_fn else: grad_fn = default_grad_fn return grad_fn(op, inputs, outputs, input_requires_grad) @@ -152,9 +150,7 @@ def reshape_grad_fn(op, inputs, outputs, input_requires_grad): # override for Subtensor def subtensor_grad_fn(op, inputs, outputs, input_requires_grad): - grad_op = OprAttr() - grad_op.type = SetSubtensor.name - grad_op.param = op.param + grad_op = SetSubtensor(op.items) input_shape = get_shape(inputs[0]) params = inputs[1:] @@ -175,9 +171,7 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad): # override for IndexingMultiAxisVec def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad): - grad_op = OprAttr() - grad_op.type = IndexingSetMultiAxisVec.name - grad_op.param = op.param + grad_op = IndexingSetMultiAxisVec(op.items) input_shape = get_shape(inputs[0]) params = inputs[1:] @@ -209,10 +203,3 @@ def reduce_sum_grad_fn(op, inputs, outputs, input_requires_grad): return (broadcast_to(dy, input_shape) if input_requires_grad[0] else None,) return backward, [True] - - -_oprAttr_grad_fn = { - Reshape.name: reshape_grad_fn, - Subtensor.name: subtensor_grad_fn, - IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn, -} diff --git a/imperative/python/megengine/core/ops/_internal/__init__.py b/imperative/python/megengine/core/ops/_internal/__init__.py deleted file mode 100644 index 1207b5d98..000000000 --- a/imperative/python/megengine/core/ops/_internal/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/imperative/python/megengine/core/ops/_internal/all_ops.py b/imperative/python/megengine/core/ops/_internal/all_ops.py deleted file mode 100644 index f1627ee97..000000000 --- a/imperative/python/megengine/core/ops/_internal/all_ops.py +++ /dev/null @@ -1,10 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from .generated_ops import * -from .misc_ops import * diff --git a/imperative/python/megengine/core/ops/_internal/enum36.py b/imperative/python/megengine/core/ops/_internal/enum36.py deleted file mode 100644 index bab7a0f83..000000000 --- a/imperative/python/megengine/core/ops/_internal/enum36.py +++ /dev/null @@ -1,939 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import sys -from functools import reduce -from operator import or_ as _or_ -from types import DynamicClassAttribute, MappingProxyType - -# try _collections first to reduce startup cost -try: - from _collections import OrderedDict -except ImportError: - from collections import OrderedDict - - -__all__ = [ - "EnumMeta", - "Enum", - "IntEnum", - "Flag", - "IntFlag", - "auto", - "unique", -] - - -def _is_descriptor(obj): - """Returns True if obj is a descriptor, False otherwise.""" - return ( - hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") - ) - - -def _is_dunder(name): - """Returns True if a __dunder__ name, False otherwise.""" - return ( - name[:2] == name[-2:] == "__" - and name[2:3] != "_" - and name[-3:-2] != "_" - and len(name) > 4 - ) - - -def _is_sunder(name): - """Returns True if a _sunder_ name, False otherwise.""" - return ( - name[0] == name[-1] == "_" - and name[1:2] != "_" - and name[-2:-1] != "_" - and len(name) > 2 - ) - - -def _make_class_unpicklable(cls): - """Make the given class un-picklable.""" - - def _break_on_call_reduce(self, proto): - raise TypeError("%r cannot be pickled" % self) - - cls.__reduce_ex__ = _break_on_call_reduce - cls.__module__ = "" - - -_auto_null = object() - - -class auto: - """ - Instances are replaced with an appropriate value in Enum class suites. - """ - - value = _auto_null - - -class _EnumDict(dict): - """ - Track enum member order and ensure member names are not reused. - - EnumMeta will use the names found in self._member_names as the - enumeration member names. - - """ - - def __init__(self): - super().__init__() - self._member_names = [] - self._last_values = [] - - def __setitem__(self, key, value): - """ - Changes anything not dundered or not a descriptor. - - If an enum member name is used twice, an error is raised; duplicate - values are not checked for. - - Single underscore (sunder) names are reserved. - - """ - if _is_sunder(key): - if key not in ( - "_order_", - "_create_pseudo_member_", - "_generate_next_value_", - "_missing_", - ): - raise ValueError("_names_ are reserved for future Enum use") - if key == "_generate_next_value_": - setattr(self, "_generate_next_value", value) - elif _is_dunder(key): - if key == "__order__": - key = "_order_" - elif key in self._member_names: - # descriptor overwriting an enum? - raise TypeError("Attempted to reuse key: %r" % key) - elif not _is_descriptor(value): - if key in self: - # enum overwriting a descriptor? - raise TypeError("%r already defined as: %r" % (key, self[key])) - if isinstance(value, auto): - if value.value == _auto_null: - value.value = self._generate_next_value( - key, 1, len(self._member_names), self._last_values[:] - ) - value = value.value - self._member_names.append(key) - self._last_values.append(value) - super().__setitem__(key, value) - - -# Dummy value for Enum as EnumMeta explicitly checks for it, but of course -# until EnumMeta finishes running the first time the Enum class doesn't exist. -# This is also why there are checks in EnumMeta like `if Enum is not None` -Enum = None - - -class EnumMeta(type): - """Metaclass for Enum""" - - @classmethod - def __prepare__(metacls, cls, bases): - # create the namespace dict - enum_dict = _EnumDict() - # inherit previous flags and _generate_next_value_ function - member_type, first_enum = metacls._get_mixins_(bases) - if first_enum is not None: - enum_dict["_generate_next_value_"] = getattr( - first_enum, "_generate_next_value_", None - ) - return enum_dict - - def __new__(metacls, cls, bases, classdict): - # an Enum class is final once enumeration items have been defined; it - # cannot be mixed with other types (int, float, etc.) if it has an - # inherited __new__ unless a new __new__ is defined (or the resulting - # class will fail). - member_type, first_enum = metacls._get_mixins_(bases) - __new__, save_new, use_args = metacls._find_new_( - classdict, member_type, first_enum - ) - - # save enum items into separate mapping so they don't get baked into - # the new class - enum_members = {k: classdict[k] for k in classdict._member_names} - for name in classdict._member_names: - del classdict[name] - - # adjust the sunders - _order_ = classdict.pop("_order_", None) - - # check for illegal enum names (any others?) - invalid_names = set(enum_members) & { - "mro", - } - if invalid_names: - raise ValueError( - "Invalid enum member name: {0}".format(",".join(invalid_names)) - ) - - # create a default docstring if one has not been provided - if "__doc__" not in classdict: - classdict["__doc__"] = "An enumeration." - - # create our new Enum type - enum_class = super().__new__(metacls, cls, bases, classdict) - enum_class._member_names_ = [] # names in definition order - enum_class._member_map_ = OrderedDict() # name->value map - enum_class._member_type_ = member_type - - # save attributes from super classes so we know if we can take - # the shortcut of storing members in the class dict - base_attributes = {a for b in enum_class.mro() for a in b.__dict__} - - # Reverse value->name map for hashable values. - enum_class._value2member_map_ = {} - - # If a custom type is mixed into the Enum, and it does not know how - # to pickle itself, pickle.dumps will succeed but pickle.loads will - # fail. Rather than have the error show up later and possibly far - # from the source, sabotage the pickle protocol for this class so - # that pickle.dumps also fails. - # - # However, if the new class implements its own __reduce_ex__, do not - # sabotage -- it's on them to make sure it works correctly. We use - # __reduce_ex__ instead of any of the others as it is preferred by - # pickle over __reduce__, and it handles all pickle protocols. - if "__reduce_ex__" not in classdict: - if member_type is not object: - methods = ( - "__getnewargs_ex__", - "__getnewargs__", - "__reduce_ex__", - "__reduce__", - ) - if not any(m in member_type.__dict__ for m in methods): - _make_class_unpicklable(enum_class) - - # instantiate them, checking for duplicates as we go - # we instantiate first instead of checking for duplicates first in case - # a custom __new__ is doing something funky with the values -- such as - # auto-numbering ;) - for member_name in classdict._member_names: - value = enum_members[member_name] - if not isinstance(value, tuple): - args = (value,) - else: - args = value - if member_type is tuple: # special case for tuple enums - args = (args,) # wrap it one more time - if not use_args: - enum_member = __new__(enum_class) - if not hasattr(enum_member, "_value_"): - enum_member._value_ = value - else: - enum_member = __new__(enum_class, *args) - if not hasattr(enum_member, "_value_"): - if member_type is object: - enum_member._value_ = value - else: - enum_member._value_ = member_type(*args) - value = enum_member._value_ - enum_member._name_ = member_name - enum_member.__objclass__ = enum_class - enum_member.__init__(*args) - # If another member with the same value was already defined, the - # new member becomes an alias to the existing one. - for name, canonical_member in enum_class._member_map_.items(): - if canonical_member._value_ == enum_member._value_: - enum_member = canonical_member - break - else: - # Aliases don't appear in member names (only in __members__). - enum_class._member_names_.append(member_name) - # performance boost for any member that would not shadow - # a DynamicClassAttribute - if member_name not in base_attributes: - setattr(enum_class, member_name, enum_member) - # now add to _member_map_ - enum_class._member_map_[member_name] = enum_member - try: - # This may fail if value is not hashable. We can't add the value - # to the map, and by-value lookups for this value will be - # linear. - enum_class._value2member_map_[value] = enum_member - except TypeError: - pass - - # double check that repr and friends are not the mixin's or various - # things break (such as pickle) - for name in ("__repr__", "__str__", "__format__", "__reduce_ex__"): - class_method = getattr(enum_class, name) - obj_method = getattr(member_type, name, None) - enum_method = getattr(first_enum, name, None) - if obj_method is not None and obj_method is class_method: - setattr(enum_class, name, enum_method) - - # replace any other __new__ with our own (as long as Enum is not None, - # anyway) -- again, this is to support pickle - if Enum is not None: - # if the user defined their own __new__, save it before it gets - # clobbered in case they subclass later - if save_new: - enum_class.__new_member__ = __new__ - enum_class.__new__ = Enum.__new__ - - # py3 support for definition order (helps keep py2/py3 code in sync) - if _order_ is not None: - if isinstance(_order_, str): - _order_ = _order_.replace(",", " ").split() - if _order_ != enum_class._member_names_: - raise TypeError("member order does not match _order_") - - return enum_class - - def __bool__(self): - """ - classes/types should always be True. - """ - return True - - def __call__( - cls, value, names=None, *, module=None, qualname=None, type=None, start=1 - ): - """ - Either returns an existing member, or creates a new enum class. - - This method is used both when an enum class is given a value to match - to an enumeration member (i.e. Color(3)) and for the functional API - (i.e. Color = Enum('Color', names='RED GREEN BLUE')). - - When used for the functional API: - - `value` will be the name of the new class. - - `names` should be either a string of white-space/comma delimited names - (values will start at `start`), or an iterator/mapping of name, value pairs. - - `module` should be set to the module this class is being created in; - if it is not set, an attempt to find that module will be made, but if - it fails the class will not be picklable. - - `qualname` should be set to the actual location this class can be found - at in its module; by default it is set to the global scope. If this is - not correct, unpickling will fail in some circumstances. - - `type`, if set, will be mixed in as the first base class. - - """ - if names is None: # simple value lookup - return cls.__new__(cls, value) - # otherwise, functional API: we're creating a new Enum type - return cls._create_( - value, names, module=module, qualname=qualname, type=type, start=start - ) - - def __contains__(cls, member): - return isinstance(member, cls) and member._name_ in cls._member_map_ - - def __delattr__(cls, attr): - # nicer error message when someone tries to delete an attribute - # (see issue19025). - if attr in cls._member_map_: - raise AttributeError("%s: cannot delete Enum member." % cls.__name__) - super().__delattr__(attr) - - def __dir__(self): - return [ - "__class__", - "__doc__", - "__members__", - "__module__", - ] + self._member_names_ - - def __getattr__(cls, name): - """ - Return the enum member matching `name` - - We use __getattr__ instead of descriptors or inserting into the enum - class' __dict__ in order to support `name` and `value` being both - properties for enum members (which live in the class' __dict__) and - enum members themselves. - - """ - if _is_dunder(name): - raise AttributeError(name) - try: - return cls._member_map_[name] - except KeyError: - raise AttributeError(name) from None - - def __getitem__(cls, name): - return cls._member_map_[name] - - def __iter__(cls): - return (cls._member_map_[name] for name in cls._member_names_) - - def __len__(cls): - return len(cls._member_names_) - - @property - def __members__(cls): - """ - Returns a mapping of member name->value. - - This mapping lists all enum members, including aliases. Note that this - is a read-only view of the internal mapping. - - """ - return MappingProxyType(cls._member_map_) - - def __repr__(cls): - return "" % cls.__name__ - - def __reversed__(cls): - return (cls._member_map_[name] for name in reversed(cls._member_names_)) - - def __setattr__(cls, name, value): - """ - Block attempts to reassign Enum members. - - A simple assignment to the class namespace only changes one of the - several possible ways to get an Enum member from the Enum class, - resulting in an inconsistent Enumeration. - - """ - member_map = cls.__dict__.get("_member_map_", {}) - if name in member_map: - raise AttributeError("Cannot reassign members.") - super().__setattr__(name, value) - - def _create_( - cls, class_name, names=None, *, module=None, qualname=None, type=None, start=1 - ): - """ - Convenience method to create a new Enum class. - - `names` can be: - - * A string containing member names, separated either with spaces or - commas. Values are incremented by 1 from `start`. - * An iterable of member names. Values are incremented by 1 from `start`. - * An iterable of (member name, value) pairs. - * A mapping of member name -> value pairs. - - """ - metacls = cls.__class__ - bases = (cls,) if type is None else (type, cls) - _, first_enum = cls._get_mixins_(bases) - classdict = metacls.__prepare__(class_name, bases) - - # special processing needed for names? - if isinstance(names, str): - names = names.replace(",", " ").split() - if isinstance(names, (tuple, list)) and names and isinstance(names[0], str): - original_names, names = names, [] - last_values = [] - for count, name in enumerate(original_names): - value = first_enum._generate_next_value_( - name, start, count, last_values[:] - ) - last_values.append(value) - names.append((name, value)) - - # Here, names is either an iterable of (name, value) or a mapping. - for item in names: - if isinstance(item, str): - member_name, member_value = item, names[item] - else: - member_name, member_value = item - classdict[member_name] = member_value - enum_class = metacls.__new__(metacls, class_name, bases, classdict) - - # TODO: replace the frame hack if a blessed way to know the calling - # module is ever developed - if module is None: - try: - module = sys._getframe(2).f_globals["__name__"] - except (AttributeError, ValueError) as exc: - pass - if module is None: - _make_class_unpicklable(enum_class) - else: - enum_class.__module__ = module - if qualname is not None: - enum_class.__qualname__ = qualname - - return enum_class - - @staticmethod - def _get_mixins_(bases): - """ - Returns the type for creating enum members, and the first inherited - enum class. - - bases: the tuple of bases that was given to __new__ - - """ - if not bases: - return object, Enum - - # double check that we are not subclassing a class with existing - # enumeration members; while we're at it, see if any other data - # type has been mixed in so we can use the correct __new__ - member_type = first_enum = None - for base in bases: - if base is not Enum and issubclass(base, Enum) and base._member_names_: - raise TypeError("Cannot extend enumerations") - # base is now the last base in bases - if not issubclass(base, Enum): - raise TypeError( - "new enumerations must be created as " - "`ClassName([mixin_type,] enum_type)`" - ) - - # get correct mix-in type (either mix-in type of Enum subclass, or - # first base if last base is Enum) - if not issubclass(bases[0], Enum): - member_type = bases[0] # first data type - first_enum = bases[-1] # enum type - else: - for base in bases[0].__mro__: - # most common: (IntEnum, int, Enum, object) - # possible: (, , - # , , - # ) - if issubclass(base, Enum): - if first_enum is None: - first_enum = base - else: - if member_type is None: - member_type = base - - return member_type, first_enum - - @staticmethod - def _find_new_(classdict, member_type, first_enum): - """ - Returns the __new__ to be used for creating the enum members. - - classdict: the class dictionary given to __new__ - member_type: the data type whose __new__ will be used by default - first_enum: enumeration to check for an overriding __new__ - - """ - # now find the correct __new__, checking to see of one was defined - # by the user; also check earlier enum classes in case a __new__ was - # saved as __new_member__ - __new__ = classdict.get("__new__", None) - - # should __new__ be saved as __new_member__ later? - save_new = __new__ is not None - - if __new__ is None: - # check all possibles for __new_member__ before falling back to - # __new__ - for method in ("__new_member__", "__new__"): - for possible in (member_type, first_enum): - target = getattr(possible, method, None) - if target not in { - None, - None.__new__, - object.__new__, - Enum.__new__, - }: - __new__ = target - break - if __new__ is not None: - break - else: - __new__ = object.__new__ - - # if a non-object.__new__ is used then whatever value/tuple was - # assigned to the enum member name will be passed to __new__ and to the - # new enum member's __init__ - if __new__ is object.__new__: - use_args = False - else: - use_args = True - - return __new__, save_new, use_args - - -class Enum(metaclass=EnumMeta): - """ - Generic enumeration. - - Derive from this class to define new enumerations. - - """ - - def __new__(cls, value): - # all enum instances are actually created during class construction - # without calling this method; this method is called by the metaclass' - # __call__ (i.e. Color(3) ), and by pickle - if type(value) is cls: - # For lookups like Color(Color.RED) - return value - # by-value search for a matching enum member - # see if it's in the reverse mapping (for hashable values) - try: - if value in cls._value2member_map_: - return cls._value2member_map_[value] - except TypeError: - # not there, now do long search -- O(n) behavior - for member in cls._member_map_.values(): - if member._value_ == value: - return member - # still not found -- try _missing_ hook - return cls._missing_(value) - - def _generate_next_value_(name, start, count, last_values): - for last_value in reversed(last_values): - try: - return last_value + 1 - except TypeError: - pass - else: - return start - - @classmethod - def _missing_(cls, value): - raise ValueError("%r is not a valid %s" % (value, cls.__name__)) - - def __repr__(self): - return "<%s.%s: %r>" % (self.__class__.__name__, self._name_, self._value_) - - def __str__(self): - return "%s.%s" % (self.__class__.__name__, self._name_) - - def __dir__(self): - added_behavior = [ - m - for cls in self.__class__.mro() - for m in cls.__dict__ - if m[0] != "_" and m not in self._member_map_ - ] - return ["__class__", "__doc__", "__module__"] + added_behavior - - def __format__(self, format_spec): - # mixed-in Enums should use the mixed-in type's __format__, otherwise - # we can get strange results with the Enum name showing up instead of - # the value - - # pure Enum branch - if self._member_type_ is object: - cls = str - val = str(self) - # mix-in branch - else: - cls = self._member_type_ - val = self._value_ - return cls.__format__(val, format_spec) - - def __hash__(self): - return hash(self._name_) - - def __reduce_ex__(self, proto): - return self.__class__, (self._value_,) - - # DynamicClassAttribute is used to provide access to the `name` and - # `value` properties of enum members while keeping some measure of - # protection from modification, while still allowing for an enumeration - # to have members named `name` and `value`. This works because enumeration - # members are not set directly on the enum class -- __getattr__ is - # used to look them up. - - @DynamicClassAttribute - def name(self): - """The name of the Enum member.""" - return self._name_ - - @DynamicClassAttribute - def value(self): - """The value of the Enum member.""" - return self._value_ - - @classmethod - def _convert(cls, name, module, filter, source=None): - """ - Create a new Enum subclass that replaces a collection of global constants - """ - # convert all constants from source (or module) that pass filter() to - # a new Enum called name, and export the enum and its members back to - # module; - # also, replace the __reduce_ex__ method so unpickling works in - # previous Python versions - module_globals = vars(sys.modules[module]) - if source: - source = vars(source) - else: - source = module_globals - # We use an OrderedDict of sorted source keys so that the - # _value2member_map is populated in the same order every time - # for a consistent reverse mapping of number to name when there - # are multiple names for the same number rather than varying - # between runs due to hash randomization of the module dictionary. - members = [(name, source[name]) for name in source.keys() if filter(name)] - try: - # sort by value - members.sort(key=lambda t: (t[1], t[0])) - except TypeError: - # unless some values aren't comparable, in which case sort by name - members.sort(key=lambda t: t[0]) - cls = cls(name, members, module=module) - cls.__reduce_ex__ = _reduce_ex_by_name - module_globals.update(cls.__members__) - module_globals[name] = cls - return cls - - -class IntEnum(int, Enum): - """Enum where members are also (and must be) ints""" - - -def _reduce_ex_by_name(self, proto): - return self.name - - -class Flag(Enum): - """Support for flags""" - - def _generate_next_value_(name, start, count, last_values): - """ - Generate the next value when not given. - - name: the name of the member - start: the initital start value or None - count: the number of existing members - last_value: the last value assigned or None - """ - if not count: - return start if start is not None else 1 - for last_value in reversed(last_values): - try: - high_bit = _high_bit(last_value) - break - except Exception: - raise TypeError("Invalid Flag value: %r" % last_value) from None - return 2 ** (high_bit + 1) - - @classmethod - def _missing_(cls, value): - original_value = value - if value < 0: - value = ~value - possible_member = cls._create_pseudo_member_(value) - if original_value < 0: - possible_member = ~possible_member - return possible_member - - @classmethod - def _create_pseudo_member_(cls, value): - """ - Create a composite member iff value contains only members. - """ - pseudo_member = cls._value2member_map_.get(value, None) - if pseudo_member is None: - # verify all bits are accounted for - _, extra_flags = _decompose(cls, value) - if extra_flags: - raise ValueError("%r is not a valid %s" % (value, cls.__name__)) - # construct a singleton enum pseudo-member - pseudo_member = object.__new__(cls) - pseudo_member._name_ = None - pseudo_member._value_ = value - # use setdefault in case another thread already created a composite - # with this value - pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) - return pseudo_member - - def __contains__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return other._value_ & self._value_ == other._value_ - - def __repr__(self): - cls = self.__class__ - if self._name_ is not None: - return "<%s.%s: %r>" % (cls.__name__, self._name_, self._value_) - members, uncovered = _decompose(cls, self._value_) - return "<%s.%s: %r>" % ( - cls.__name__, - "|".join([str(m._name_ or m._value_) for m in members]), - self._value_, - ) - - def __str__(self): - cls = self.__class__ - if self._name_ is not None: - return "%s.%s" % (cls.__name__, self._name_) - members, uncovered = _decompose(cls, self._value_) - if len(members) == 1 and members[0]._name_ is None: - return "%s.%r" % (cls.__name__, members[0]._value_) - else: - return "%s.%s" % ( - cls.__name__, - "|".join([str(m._name_ or m._value_) for m in members]), - ) - - def __bool__(self): - return bool(self._value_) - - def __or__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return self.__class__(self._value_ | other._value_) - - def __and__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return self.__class__(self._value_ & other._value_) - - def __xor__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return self.__class__(self._value_ ^ other._value_) - - def __invert__(self): - members, uncovered = _decompose(self.__class__, self._value_) - inverted_members = [ - m - for m in self.__class__ - if m not in members and not m._value_ & self._value_ - ] - inverted = reduce(_or_, inverted_members, self.__class__(0)) - return self.__class__(inverted) - - -class IntFlag(int, Flag): - """Support for integer-based Flags""" - - @classmethod - def _missing_(cls, value): - if not isinstance(value, int): - raise ValueError("%r is not a valid %s" % (value, cls.__name__)) - new_member = cls._create_pseudo_member_(value) - return new_member - - @classmethod - def _create_pseudo_member_(cls, value): - pseudo_member = cls._value2member_map_.get(value, None) - if pseudo_member is None: - need_to_create = [value] - # get unaccounted for bits - _, extra_flags = _decompose(cls, value) - # timer = 10 - while extra_flags: - # timer -= 1 - bit = _high_bit(extra_flags) - flag_value = 2 ** bit - if ( - flag_value not in cls._value2member_map_ - and flag_value not in need_to_create - ): - need_to_create.append(flag_value) - if extra_flags == -flag_value: - extra_flags = 0 - else: - extra_flags ^= flag_value - for value in reversed(need_to_create): - # construct singleton pseudo-members - pseudo_member = int.__new__(cls, value) - pseudo_member._name_ = None - pseudo_member._value_ = value - # use setdefault in case another thread already created a composite - # with this value - pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) - return pseudo_member - - def __or__(self, other): - if not isinstance(other, (self.__class__, int)): - return NotImplemented - result = self.__class__(self._value_ | self.__class__(other)._value_) - return result - - def __and__(self, other): - if not isinstance(other, (self.__class__, int)): - return NotImplemented - return self.__class__(self._value_ & self.__class__(other)._value_) - - def __xor__(self, other): - if not isinstance(other, (self.__class__, int)): - return NotImplemented - return self.__class__(self._value_ ^ self.__class__(other)._value_) - - __ror__ = __or__ - __rand__ = __and__ - __rxor__ = __xor__ - - def __invert__(self): - result = self.__class__(~self._value_) - return result - - -def _high_bit(value): - """returns index of highest bit, or -1 if value is zero or negative""" - return value.bit_length() - 1 - - -def unique(enumeration): - """Class decorator for enumerations ensuring unique member values.""" - duplicates = [] - for name, member in enumeration.__members__.items(): - if name != member.name: - duplicates.append((name, member.name)) - if duplicates: - alias_details = ", ".join( - ["%s -> %s" % (alias, name) for (alias, name) in duplicates] - ) - raise ValueError( - "duplicate values found in %r: %s" % (enumeration, alias_details) - ) - return enumeration - - -def _decompose(flag, value): - """Extract all members from the value.""" - # _decompose is only called if the value is not named - not_covered = value - negative = value < 0 - # issue29167: wrap accesses to _value2member_map_ in a list to avoid race - # conditions between iterating over it and having more psuedo- - # members added to it - if negative: - # only check for named flags - flags_to_check = [ - (m, v) - for v, m in list(flag._value2member_map_.items()) - if m.name is not None - ] - else: - # check for named flags and powers-of-two flags - flags_to_check = [ - (m, v) - for v, m in list(flag._value2member_map_.items()) - if m.name is not None or _power_of_two(v) - ] - members = [] - for member, member_value in flags_to_check: - if member_value and member_value & value == member_value: - members.append(member) - not_covered &= ~member_value - if not members and value in flag._value2member_map_: - members.append(flag._value2member_map_[value]) - members.sort(key=lambda m: m._value_, reverse=True) - if len(members) > 1 and members[0].value == value: - # we have the breakdown, don't need the value member itself - members.pop(0) - return members, not_covered - - -def _power_of_two(value): - if value < 1: - return False - return value == 2 ** _high_bit(value) diff --git a/imperative/python/megengine/core/ops/_internal/helper.py b/imperative/python/megengine/core/ops/_internal/helper.py deleted file mode 100644 index 52af3aa0a..000000000 --- a/imperative/python/megengine/core/ops/_internal/helper.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import warnings - -from ..._imperative_rt.ops import OprAttr -from . import param_defs - - -def make_param(param, ptype, kwargs): - if param is not None: - if isinstance(param, ptype): - return param - - param = [param] - assert len(param) == len( - ptype.__slots__ - ), "{} needs {} params, but {} are provided".format( - ptype, len(ptype.__slots__), len(param) - ) - return ptype(*param) - - ckw = {} - for i in ptype.__slots__: - val = kwargs.pop(i, ckw) - if val is not ckw: - ckw[i] = val - return ptype(**ckw) - - -class PodOpVisitor: - __name2subclass = {} - __c = None - - name = None - param_names = [] - config = None - - def __init__(self, config, **params): - self.config = config - assert set(params) == set(self.param_names) - self.__dict__.update(params) - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) # python 3.5 does not have this - name = cls.name - if name in cls.__name2subclass: - if not issubclass(cls, cls.__name2subclass[name]): - warnings.warn("Multiple subclasses for bultin op: %s" % name) - cls.__name2subclass[name] = cls - - def to_c(self): - if self.__c: - return self.__c - op = OprAttr() - op.type = self.name - if self.config is not None: - op.config = self.config - # first 4 bytes is TAG, has to remove them currently - op.param = b"".join(self.__dict__[k].serialize()[4:] for k in self.param_names) - self.__c = op - return op - - def __eq__(self, rhs): - return self.to_c() == rhs.to_c() - - def __repr__(self): - name = self.__class__.__name__ - - if self.__c: - return "{}()".format(name) - - kwargs = {} - for i in self.param_names: - p = self.__dict__[i] - if isinstance(p, param_defs._ParamDefBase): - for k in p.__slots__: - v = getattr(p, k) - if isinstance(v, param_defs._EnumBase): - v = v.name - kwargs[k] = repr(v) - else: - kwargs[i] = repr(p) - if self.config: - if len(self.config.comp_node_arr) == 1: - kwargs["device"] = "'%s'" % self.config.comp_node - return "{}({})".format( - name, ", ".join("{}={}".format(k, v) for k, v in kwargs.items()) - ) diff --git a/imperative/python/megengine/core/ops/_internal/misc_ops.py b/imperative/python/megengine/core/ops/_internal/misc_ops.py deleted file mode 100644 index a257efc76..000000000 --- a/imperative/python/megengine/core/ops/_internal/misc_ops.py +++ /dev/null @@ -1,194 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import collections -import ctypes - -from ..._imperative_rt import OperatorNodeConfig as Config -from . import param_defs -from .helper import PodOpVisitor, make_param - -__all__ = ["ConvolutionBackwardData", "Dimshuffle", "Reshape", "AxisAddRemove"] - - -class TensorShape: - MAX_NDIM = 7 - - -class ConvolutionBackwardData(PodOpVisitor): - param_names = ( - "param", - "execution_polity", - ) - name = "ConvolutionBackwardDataV1" - - def __init__( - self, - *, - param=None, - execution_polity=None, - name=None, - comp_node=None, - config=None, - dtype=None, - **kwargs - ): - config = config or Config() - if name: - config.name = name - if comp_node: - config.comp_node = comp_node - if dtype: - config.dtype = dtype - self.config = config - - self.param = make_param(param, param_defs.Convolution, kwargs) - self.execution_polity = make_param( - execution_polity, param_defs.ExecutionPolicy, kwargs - ) - assert not kwargs, "extra kwargs: {}".format(kwargs) - - -class Dimshuffle(PodOpVisitor): - name = "Dimshuffle" - param_names = ("pattern",) - - class Pattern(ctypes.Structure): - Pattern_Array = ctypes.c_int32 * TensorShape.MAX_NDIM - _fields_ = [ - ("length", ctypes.c_uint32), - ("pattern", Pattern_Array), - ("ndim", ctypes.c_uint32), - ] - - def serialize(self): - return bytes(ctypes.c_uint32(0)) + bytes(self) - - def __init__(self, pattern, ndim=0): - assert isinstance(pattern, collections.abc.Iterable) - assert len(pattern) <= TensorShape.MAX_NDIM - pattern_array = Dimshuffle.Pattern.Pattern_Array() - for idx, v in enumerate(pattern): - pattern_array[idx] = ctypes.c_int32(-1 if v == "x" else int(v)) - self.pattern = Dimshuffle.Pattern(len(pattern), pattern_array, ndim) - - -class Reshape(PodOpVisitor): - name = "ReshapeV1" - param_names = ("unspec_axis",) - - def __init__(self, unspec_axis=None): - if unspec_axis is None: - self.unspec_axis = param_defs.OptionalAxisV1() - else: - self.unspec_axis = param_defs.OptionalAxisV1(unspec_axis) - - -class AxisNum(ctypes.Structure): - _fields_ = [ - ("m_num", ctypes.c_int), - ] - - -class AxisDesc(ctypes.Structure): - class Method(ctypes.c_int): - ADD_1 = 0 - REMOVE = 1 - - _fields_ = [ - ("method", Method), - ("axis", AxisNum), - ] - - @classmethod - def make_add(cls, axis): - return cls(cls.Method.ADD_1, AxisNum(axis)) - - @classmethod - def make_remove(cls, axis): - return cls(cls.Method.REMOVE, AxisNum(axis)) - - -class AxisAddRemove(PodOpVisitor): - name = "AxisAddRemove" - param_names = ("param",) - - AxisDesc = AxisDesc - - class Param(ctypes.Structure): - MAX_DESC_SIZE = TensorShape.MAX_NDIM * 2 - - _fields_ = [("nr_desc", ctypes.c_uint32), ("desc", AxisDesc * MAX_DESC_SIZE)] - - def __init__(self, *args): - super().__init__() - self.nr_desc = len(args) - for i, a in enumerate(args): - self.desc[i] = a - - def serialize(self): - return bytes(ctypes.c_uint32(0)) + bytes(self) - - def __init__(self, param): - assert isinstance(param, self.Param) - self.param = param - - -del AxisDesc - - -class IndexingOpBase(PodOpVisitor): - param_names = ("index_desc",) - - class IndexDescMaskDump(ctypes.Structure): - class Item(ctypes.Structure): - _fields_ = [ - ("axis", ctypes.c_int8), - ("begin", ctypes.c_bool), - ("end", ctypes.c_bool), - ("step", ctypes.c_bool), - ("idx", ctypes.c_bool), - ] - - Item_Array = Item * TensorShape.MAX_NDIM - - _fields_ = [("nr_item", ctypes.c_uint8), ("items", Item_Array)] - - def serialize(self): - return bytes(ctypes.c_uint32(0)) + bytes(self) - - def __init__(self, items): - nr_item = len(items) - assert nr_item <= TensorShape.MAX_NDIM - item_array = IndexingOpBase.IndexDescMaskDump.Item_Array() - for idx, item in enumerate(items): - assert isinstance(item, (tuple, list)) and len(item) == 5 - item_array[idx] = IndexingOpBase.IndexDescMaskDump.Item(*item) - self.index_desc = IndexingOpBase.IndexDescMaskDump(nr_item, item_array) - - -def _gen_indexing_defs(*names): - for name in names: - globals()[name] = type(name, (IndexingOpBase,), dict(name=name)) - __all__.append(name) - - -_gen_indexing_defs( - "Subtensor", - "SetSubtensor", - "IncrSubtensor", - "IndexingMultiAxisVec", - "IndexingSetMultiAxisVec", - "IndexingIncrMultiAxisVec", - "MeshIndexing", - "IncrMeshIndexing", - "SetMeshIndexing", - "BatchedMeshIndexing", - "BatchedIncrMeshIndexing", - "BatchedSetMeshIndexing", -) diff --git a/imperative/python/megengine/core/ops/builtin/__init__.py b/imperative/python/megengine/core/ops/builtin/__init__.py index 997c9d2f1..3d67846d2 100644 --- a/imperative/python/megengine/core/ops/builtin/__init__.py +++ b/imperative/python/megengine/core/ops/builtin/__init__.py @@ -11,25 +11,12 @@ from typing import Union from ..._imperative_rt import OpDef, ops from ...tensor.core import OpBase, TensorBase, TensorWrapperBase, apply -from .._internal import all_ops -from .._internal.helper import PodOpVisitor # register OpDef as a "virtual subclass" of OpBase, so any of registered # apply(OpBase, ...) rules could work well on OpDef OpBase.register(OpDef) -# forward to apply(OpDef, ...) -@apply.register() -def _(op: PodOpVisitor, *args: Union[TensorBase, TensorWrapperBase]): - return apply(op.to_c(), *args) - - -__all__ = ["OpDef", "PodOpVisitor"] - -for k, v in all_ops.__dict__.items(): - if isinstance(v, type) and issubclass(v, PodOpVisitor): - globals()[k] = v - __all__.append(k) +__all__ = ["OpDef"] for k, v in ops.__dict__.items(): if isinstance(v, type) and issubclass(v, OpDef): diff --git a/imperative/python/megengine/core/tensor/tensor_wrapper.py b/imperative/python/megengine/core/tensor/tensor_wrapper.py index 760253d9d..357999a0a 100644 --- a/imperative/python/megengine/core/tensor/tensor_wrapper.py +++ b/imperative/python/megengine/core/tensor/tensor_wrapper.py @@ -90,7 +90,7 @@ def _reshape(x, shape): if unspec_axis is None: op = builtin.Reshape() else: - op = builtin.Reshape(unspec_axis=unspec_axis) + op = builtin.Reshape(axis=unspec_axis) (x,) = apply(op, x, shape) return x @@ -144,8 +144,6 @@ def _logical_binary_elwise(mode, rev=False): def _remove_axis(inp: Tensor, axis) -> Tensor: - Param = builtin.AxisAddRemove.Param - def get_axes(): if axis is None: return [i for i, s in enumerate(inp.shape) if s == 1] @@ -159,8 +157,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: axis = sorted(i + inp.ndim if i < 0 else i for i in axis) axis = [a - i for i, a in enumerate(axis)] - param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis)) - op = builtin.AxisAddRemove(param=param) + op = builtin.RemoveAxis(axis=axis) (result,) = apply(op, inp) if len(axis) == inp.ndim: setscalar(result) diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index 49055d5f8..287ad1e2d 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -134,7 +134,7 @@ def astype(x, dtype): dtype = np.dtype(dtype) if not is_equal(x.dtype, dtype): isscalar = x.__wrapped__._data._isscalar - (x,) = apply(builtin.TypeCvt(param=dtype), x) + (x,) = apply(builtin.TypeCvt(dtype=dtype), x) x.__wrapped__._data._isscalar = isscalar return x diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py index 0c0b8d2b6..1816251bf 100644 --- a/imperative/python/megengine/distributed/functional.py +++ b/imperative/python/megengine/distributed/functional.py @@ -8,7 +8,6 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from typing import Optional, Tuple -from ..core._imperative_rt.ops import CollectiveCommMode from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn from ..core.autodiff.grad import ( Tracer, @@ -110,17 +109,20 @@ def collective_comm(inp, mode, group, device): assert isinstance(group, Group) if group is None: return inp - op = CollectiveComm() - op.key = group.key - op.nr_devices = group.size - op.rank = group.rank - op.is_root = op.rank == 0 - op.local_grad = False - op.addr, op.port = get_mm_server_addr() - op.mode = mode - op.dtype = inp.dtype - op.backend = get_backend() - op.comp_node = device + addr, port = get_mm_server_addr() + op = CollectiveComm( + key=group.key, + nr_devices=group.size, + rank=group.rank, + is_root=(group.rank == 0), + local_grad=False, + addr=addr, + port=port, + mode=mode, + dtype=inp.dtype, + backend=get_backend(), + comp_node=device, + ) return apply(op, inp)[0] @@ -134,7 +136,7 @@ def reduce_sum( :param group: communication group. :param device: execution device. """ - mode = CollectiveCommMode.REDUCE_SUM + mode = CollectiveComm.Mode.REDUCE_SUM return collective_comm(inp, mode, group, device) @@ -148,7 +150,7 @@ def broadcast( :param group: communication group. :param device: execution device. """ - mode = CollectiveCommMode.BROADCAST + mode = CollectiveComm.Mode.BROADCAST return collective_comm(inp, mode, group, device) @@ -162,7 +164,7 @@ def all_gather( :param group: communication group. :param device: execution device. """ - mode = CollectiveCommMode.ALL_GATHER + mode = CollectiveComm.Mode.ALL_GATHER return collective_comm(inp, mode, group, device) @@ -176,7 +178,7 @@ def reduce_scatter_sum( :param group: communication group. :param device: execution device. """ - mode = CollectiveCommMode.REDUCE_SCATTER_SUM + mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM return collective_comm(inp, mode, group, device) @@ -190,7 +192,7 @@ def all_reduce_sum( :param group: communication group. :param device: execution device. """ - mode = CollectiveCommMode.ALL_REDUCE_SUM + mode = CollectiveComm.Mode.ALL_REDUCE_SUM return collective_comm(inp, mode, group, device) @@ -204,7 +206,7 @@ def all_reduce_max( :param group: communication group. :param device: execution device. """ - mode = CollectiveCommMode.ALL_REDUCE_MAX + mode = CollectiveComm.Mode.ALL_REDUCE_MAX return collective_comm(inp, mode, group, device) @@ -218,7 +220,7 @@ def all_reduce_min( :param group: communication group. :param device: execution device. """ - mode = CollectiveCommMode.ALL_REDUCE_MIN + mode = CollectiveComm.Mode.ALL_REDUCE_MIN return collective_comm(inp, mode, group, device) @@ -232,7 +234,7 @@ def gather( :param group: communication group. :param device: execution device. """ - mode = CollectiveCommMode.GATHER + mode = CollectiveComm.Mode.GATHER return collective_comm(inp, mode, group, device) @@ -246,7 +248,7 @@ def scatter( :param group: communication group. :param device: execution device. """ - mode = CollectiveCommMode.SCATTER + mode = CollectiveComm.Mode.SCATTER return collective_comm(inp, mode, group, device) @@ -260,7 +262,7 @@ def all_to_all( :param group: communication group. :param device: execution device. """ - mode = CollectiveCommMode.ALL_TO_ALL + mode = CollectiveComm.Mode.ALL_TO_ALL return collective_comm(inp, mode, group, device) diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index 9ebce0a95..6ea379769 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -73,27 +73,7 @@ __all__ = [ ] -class _ElemwiseMode(Elemwise.Mode): - @classmethod - def __normalize(cls, val): - if isinstance(val, str): - if not hasattr(cls, "__member_upper_dict__"): - cls.__member_upper_dict__ = { - k.upper(): v for k, v in cls.__members__.items() - } - val = cls.__member_upper_dict__.get(val.upper(), val) - return val - - @classmethod - def convert(cls, val): - val = cls.__normalize(val) - if isinstance(val, cls): - return val - return cls(val) - - def _elwise(*args, mode): - mode = _ElemwiseMode.convert(mode) op = builtin.Elemwise(mode) tensor_args = list( filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index e89862a1f..7825fed39 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -13,7 +13,6 @@ import numbers from typing import Optional, Sequence, Tuple, Union from ..core.ops import builtin -from ..core.ops._internal import param_defs as P from ..core.ops.special import Const from ..core.tensor import utils from ..core.tensor.core import TensorBase, TensorWrapperBase, apply @@ -601,9 +600,9 @@ def argsort(inp: Tensor, descending: bool = False) -> Tensor: """ assert len(inp.shape) <= 2, "Input should be 1d or 2d" if descending: - order = P.Argsort.Order.DESCENDING + order = "DESCENDING" else: - order = P.Argsort.Order.ASCENDING + order = "ASCENDING" op = builtin.Argsort(order=order) if len(inp.shape) == 1: @@ -643,9 +642,9 @@ def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]: """ assert len(inp.shape) <= 2, "Input should be 1d or 2d" if descending: - order = P.Argsort.Order.DESCENDING + order = "DESCENDING" else: - order = P.Argsort.Order.ASCENDING + order = "ASCENDING" op = builtin.Argsort(order=order) if len(inp.shape) == 1: @@ -695,13 +694,12 @@ def topk( if descending: inp = -inp - Mode = P.TopK.Mode if kth_only: - mode = Mode.KTH_ONLY + mode = "KTH_ONLY" elif no_sort: - mode = Mode.VALUE_IDX_NOSORT + mode = "VALUE_IDX_NOSORT" else: - mode = Mode.VALUE_IDX_SORTED + mode = "VALUE_IDX_SORTED" op = builtin.TopK(mode=mode) if not isinstance(k, (TensorBase, TensorWrapperBase)): diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index cd80979cd..e936b1cab 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -12,7 +12,6 @@ from typing import Optional, Sequence, Tuple, Union from ..core._imperative_rt import CompNode from ..core._trace_option import use_symbolic_shape from ..core.ops import builtin -from ..core.ops._internal import param_defs as P from ..core.ops.builtin import BatchNorm from ..core.ops.special import Const from ..core.tensor import megbrain_graph, utils @@ -121,11 +120,11 @@ def conv2d( ``in_channels`` and ``out_channels`` must be divisible by ``groups``, and the shape of weight should be `(groups, out_channel // groups, in_channels // groups, height, width)`. - :type conv_mode: string or :class:`P.Convolution.Mode` + :type conv_mode: string or :class:`Convolution.Mode` :param conv_mode: supports "CROSS_CORRELATION". Default: "CROSS_CORRELATION" :type compute_mode: string or - :class:`P.Convolution.ComputeMode` + :class:`Convolution.ComputeMode` :param compute_mode: when set to "DEFAULT", no special requirements will be placed on the precision of intermediate results. When set to "FLOAT32", "Float32" would be used for accumulator and intermediate result, but only @@ -139,8 +138,8 @@ def conv2d( pad_h, pad_w = expand_hw(padding) dilate_h, dilate_w = expand_hw(dilation) - Sparse = P.Convolution.Sparse - sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP + Sparse = builtin.Convolution.Sparse + sparse_type = "DENSE" if groups == 1 else "GROUP" op = builtin.Convolution( stride_h=stride_h, stride_w=stride_w, @@ -187,11 +186,11 @@ def conv_transpose2d( ``in_channels`` and ``out_channels`` must be divisible by groups, and the shape of weight should be `(groups, out_channel // groups, in_channels // groups, height, width)`. Default: 1 - :type conv_mode: string or :class:`P.Convolution.Mode` + :type conv_mode: string or :class:`Convolution.Mode` :param conv_mode: supports "CROSS_CORRELATION". Default: "CROSS_CORRELATION" :type compute_mode: string or - :class:`P.Convolution.ComputeMode` + :class:`Convolution.ComputeMode` :param compute_mode: when set to "DEFAULT", no special requirements will be placed on the precision of intermediate results. When set to "FLOAT32", "Float32" would be used for accumulator and intermediate result, but only @@ -240,8 +239,6 @@ def local_conv2d( pad_h, pad_w = expand_hw(padding) dilate_h, dilate_w = expand_hw(dilation) - Sparse = P.Convolution.Sparse - op = builtin.GroupLocal( stride_h=stride_h, stride_w=stride_w, @@ -251,7 +248,7 @@ def local_conv2d( dilate_w=dilate_w, mode=conv_mode, compute_mode="DEFAULT", - sparse=Sparse.DENSE, + sparse="DENSE", ) inp, weight = utils.convert_inputs(inp, weight) (output,) = apply(op, inp, weight) @@ -696,19 +693,14 @@ def batch_norm( if not training: op = builtin.BatchNorm( - BatchNorm.ParamDim.DIM_1C11, BatchNorm.FwdMode.INFERENCE, eps, 1.0, 1.0, 0.0 + fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="DIM_1C11" ) ret = apply(op, inp, weight, bias, running_mean, running_var)[-1] return ret else: op = builtin.BatchNorm( - BatchNorm.ParamDim.DIM_1C11, - BatchNorm.FwdMode.TRAINING, - eps, - 1.0 - momentum, - 1.0, - 0.0, + avg_factor=1 - momentum, epsilon=eps, param_dim="DIM_1C11" ) if has_mean or has_var: running_mean = make_full_if_none(running_mean, 0) @@ -1638,8 +1630,7 @@ def conv1d( pad_h = padding dilate_h = dilation - Sparse = P.Convolution.Sparse - sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP + sparse_type = "DENSE" if groups == 1 else "GROUP" op = builtin.Convolution( stride_h=stride_h, stride_w=1, diff --git a/imperative/python/megengine/functional/quantized.py b/imperative/python/megengine/functional/quantized.py index 0ae082b2d..b18f52d2d 100644 --- a/imperative/python/megengine/functional/quantized.py +++ b/imperative/python/megengine/functional/quantized.py @@ -41,12 +41,12 @@ def conv_bias_activation( ``in_channels`` and ``out_channels`` must be divisible by ``groups``, and the shape of weight should be `(groups, out_channel // groups, in_channels // groups, height, width)`. - :type conv_mode: string or :class:`P.Convolution.Mode`. + :type conv_mode: string or :class:`Convolution.Mode`. :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: 'CROSS_CORRELATION' :param dtype: support for ``np.dtype``, Default: np.int8 :type compute_mode: string or - :class:`P.Convolution.ComputeMode`. + :class:`Convolution.ComputeMode`. :param compute_mode: when set to "DEFAULT", no special requirements will be placed on the precision of intermediate results. When set to "FLOAT32", "Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype. @@ -56,7 +56,7 @@ def conv_bias_activation( sh, sw = _pair_nonzero(stride) dh, dw = _pair_nonzero(dilation) sparse_type = "DENSE" if groups == 1 else "GROUP" - op = builtin.ConvBiasForward( + op = builtin.ConvBias( stride_h=sh, stride_w=sw, pad_h=ph, @@ -101,12 +101,12 @@ def batch_conv_bias_activation( ``in_channels`` and ``out_channels`` must be divisible by ``groups``, and the shape of weight should be `(groups, out_channel // groups, in_channels // groups, height, width)`. - :type conv_mode: string or :class:`P.Convolution.Mode`. + :type conv_mode: string or :class:`Convolution.Mode`. :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: 'CROSS_CORRELATION' :param dtype: support for ``np.dtype``, Default: np.int8 :type compute_mode: string or - :class:`P.Convolution.ComputeMode`. + :class:`Convolution.ComputeMode`. :param compute_mode: when set to "DEFAULT", no special requirements will be placed on the precision of intermediate results. When set to "FLOAT32", "Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype. @@ -116,7 +116,7 @@ def batch_conv_bias_activation( sh, sw = _pair_nonzero(stride) dh, dw = _pair_nonzero(dilation) sparse_type = "DENSE" if groups == 1 else "GROUP" - op = builtin.BatchConvBiasForward( + op = builtin.BatchConvBias( stride_h=sh, stride_w=sw, pad_h=ph, diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 05c87a8ad..ad180bee2 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -16,7 +16,6 @@ import numpy as np from ..core._imperative_rt import CompNode from ..core._wrap import device as as_device from ..core.ops import builtin -from ..core.ops._internal import param_defs as P from ..core.ops.special import Const from ..core.tensor.core import TensorBase, TensorWrapperBase, apply from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis @@ -722,7 +721,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: [1 0]] """ - return inp.transpose(pattern) + return inp.transpose(list(-1 if _ == "x" else _ for _ in pattern)) def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: @@ -756,10 +755,6 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: return inp.reshape(target_shape) -AxisAddRemove = builtin.AxisAddRemove -AxisDesc = AxisAddRemove.AxisDesc - - def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: r""" Reshapes the tensor by flattening the sub-tensor from dimension ``start_axis`` to dimension ``end_axis``. @@ -826,7 +821,6 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: (1, 2) """ - Param = builtin.AxisAddRemove.Param def get_axes(): try: @@ -839,8 +833,7 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: ndim = inp.ndim + len(axis) axis = sorted(i + ndim if i < 0 else i for i in axis) - param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_add, axis)) - op = builtin.AxisAddRemove(param=param) + op = builtin.AddAxis(axis=axis) (result,) = apply(op, inp) return result diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index e61ca4849..31f2f4720 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -21,9 +21,10 @@ import numpy as np from ..core._imperative_rt import GraphProfiler from ..core._imperative_rt.ops import ( CollectiveComm, - OprAttr, + GaussianRNG, RemoteRecv, RemoteSend, + UniformRNG, VirtualDep, ) from ..core._trace_option import set_symbolic_shape @@ -182,14 +183,7 @@ class trace: record = self._seq[self._pc] op_, ihandles, ohandles = record if op != op_: - # FIXME: will be removed once better rng implementation is done - if isinstance(op, OprAttr) and ( - op.type in ("UniformRNG", "GaussianRNG") and op.type == op_.type - ): - if op.param[8:] != op_.param[8:]: - raise TraceMismatchError("op different from last time") - else: - raise TraceMismatchError("op different from last time") + raise TraceMismatchError("op different from last time") if len(ihandles) != len(args): raise TraceMismatchError("op input size different from last time") diff --git a/imperative/python/megengine/module/conv.py b/imperative/python/megengine/module/conv.py index ab46b0dd6..2bd957100 100644 --- a/imperative/python/megengine/module/conv.py +++ b/imperative/python/megengine/module/conv.py @@ -10,7 +10,6 @@ from typing import Tuple, Union import numpy as np -from ..core.ops._internal import param_defs as P from ..functional import conv1d, conv2d, conv_transpose2d, local_conv2d, relu from ..functional.types import _pair, _pair_nonzero from ..tensor import Parameter @@ -156,8 +155,6 @@ class Conv1d(_ConvNd): (2, 1, 2) """ - _conv_mode_type = P.Convolution.Mode - _compute_mode_type = P.Convolution.ComputeMode def __init__( self, @@ -176,8 +173,8 @@ class Conv1d(_ConvNd): stride = stride padding = padding dilation = dilation - self.conv_mode = self._conv_mode_type.convert(conv_mode) - self.compute_mode = self._compute_mode_type.convert(compute_mode) + self.conv_mode = conv_mode + self.compute_mode = compute_mode super().__init__( in_channels, out_channels, @@ -302,9 +299,6 @@ class Conv2d(_ConvNd): """ - _conv_mode_type = P.Convolution.Mode - _compute_mode_type = P.Convolution.ComputeMode - def __init__( self, in_channels: int, @@ -322,8 +316,8 @@ class Conv2d(_ConvNd): stride = _pair_nonzero(stride) padding = _pair(padding) dilation = _pair_nonzero(dilation) - self.conv_mode = self._conv_mode_type.convert(conv_mode) - self.compute_mode = self._compute_mode_type.convert(compute_mode) + self.conv_mode = conv_mode + self.compute_mode = compute_mode super().__init__( in_channels, out_channels, @@ -414,9 +408,6 @@ class ConvTranspose2d(_ConvNd): effective when input and output are of float16 dtype. """ - _conv_mode_type = P.Convolution.Mode - _compute_mode_type = P.Convolution.ComputeMode - def __init__( self, in_channels: int, @@ -434,8 +425,8 @@ class ConvTranspose2d(_ConvNd): stride = _pair_nonzero(stride) padding = _pair(padding) dilation = _pair_nonzero(dilation) - self.conv_mode = self._conv_mode_type.convert(conv_mode) - self.compute_mode = self._compute_mode_type.convert(compute_mode) + self.conv_mode = conv_mode + self.compute_mode = compute_mode super().__init__( in_channels, out_channels, @@ -509,8 +500,6 @@ class LocalConv2d(Conv2d): in_channels // groups, *kernel_size, out_channels // groups)`. """ - _conv_mode_type = P.Convolution.Mode - def __init__( self, in_channels: int, diff --git a/imperative/python/megengine/module/elemwise.py b/imperative/python/megengine/module/elemwise.py index dfc697251..a22b113fb 100644 --- a/imperative/python/megengine/module/elemwise.py +++ b/imperative/python/megengine/module/elemwise.py @@ -5,7 +5,6 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from ..core.ops._internal import param_defs as P from ..functional.elemwise import _elwise from ..tensor import Tensor from .module import Module diff --git a/imperative/python/megengine/module/qat/conv.py b/imperative/python/megengine/module/qat/conv.py index a44db3847..6c6465896 100644 --- a/imperative/python/megengine/module/qat/conv.py +++ b/imperative/python/megengine/module/qat/conv.py @@ -41,8 +41,8 @@ class Conv2d(Float.Conv2d, QATModule): float_module.dilation, float_module.groups, float_module.bias is not None, - float_module.conv_mode.name, - float_module.compute_mode.name, + float_module.conv_mode, + float_module.compute_mode, ) qat_module.weight = float_module.weight qat_module.bias = float_module.bias diff --git a/imperative/python/megengine/module/quantized/elemwise.py b/imperative/python/megengine/module/quantized/elemwise.py index 3b16f8cf3..a5a555571 100644 --- a/imperative/python/megengine/module/quantized/elemwise.py +++ b/imperative/python/megengine/module/quantized/elemwise.py @@ -5,7 +5,6 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from ...core.ops._internal import param_defs as P from ...functional.elemwise import _elemwise_multi_type from ...tensor import Tensor from ..qat import elemwise as QAT @@ -15,11 +14,9 @@ from .module import QuantizedModule class Elemwise(QuantizedModule): r"""Quantized version of :class:`~.qat.elemwise.Elemwise`.""" - _elemwise_multi_type_mode = P.ElemwiseMultiType.Mode - def __init__(self, method, dtype=None): super().__init__() - self.method = self._elemwise_multi_type_mode.convert("Q" + method) + self.method = "Q" + method self.output_dtype = dtype def forward(self, *inps): diff --git a/imperative/python/megengine/utils/profiler.py b/imperative/python/megengine/utils/profiler.py index 0fe88de9b..3850630c4 100644 --- a/imperative/python/megengine/utils/profiler.py +++ b/imperative/python/megengine/utils/profiler.py @@ -15,7 +15,7 @@ from typing import Iterable, List, Optional from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry from ..core._imperative_rt import ProfilerImpl as _Profiler from ..core._imperative_rt.imperative import sync -from ..core._imperative_rt.ops import CollectiveCommMode +from ..core._imperative_rt.ops import CollectiveComm def _make_dict(**kwargs): @@ -194,7 +194,7 @@ class Profiler: _type_map = { OperatorNodeConfig: lambda x: _print_opnode_config(x), bytes: lambda x: base64.encodebytes(x).decode("ascii"), - CollectiveCommMode: lambda x: str(x), + CollectiveComm.Mode: lambda x: str(x), } _dumper_map = { diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 6b133ec74..73f58373b 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -421,9 +421,7 @@ void init_graph_rt(py::module m) { common.def("invoke_op", [](const OpDef& def, const std::vector inputs, cg::ComputingGraph* graph) { cg::VarNodeArray vinputs(inputs.begin(), inputs.end()); - auto opr = OpDef::apply_on_var_node(def, vinputs); - auto outputs = opr->usable_output(); - return to_tuple(outputs); + return to_tuple(OpDef::apply_on_var_node(def, vinputs)); }, py::arg(), py::arg(), py::arg("graph") = py::none()); diff --git a/imperative/python/src/imperative_rt.cpp b/imperative/python/src/imperative_rt.cpp index 93da734bf..812fa7243 100644 --- a/imperative/python/src/imperative_rt.cpp +++ b/imperative/python/src/imperative_rt.cpp @@ -109,9 +109,6 @@ void init_imperative_rt(py::module m) { py::class_>(m, "OpDef") .def("ctype", [](const OpDef& opdef) { - if (auto attr = opdef.try_cast_final()) { - return attr->type.c_str(); - } return opdef.dyn_typeinfo()->name; }) .def("__eq__", [](const OpDef& lhs, const OpDef& rhs) { diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index 27846a9db..bdd29f4c3 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -14,41 +14,29 @@ #include "megbrain/imperative.h" #include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/opr_attr.h" -#include "megbrain/imperative/ops/tensor_manip.h" -#include "megbrain/imperative/ops/collective_comm.h" -#include "megbrain/imperative/ops/io_remote.h" -#include "megbrain/imperative/ops/cond_take.h" -#include "megbrain/imperative/ops/nms.h" -#include "megbrain/imperative/ops/elemwise.h" -#include "megbrain/imperative/ops/batch_norm.h" -#include "megbrain/imperative/ops/broadcast.h" #include "megbrain/imperative/ops/utility.h" +#include "megbrain/imperative/ops/autogen.h" namespace py = pybind11; +namespace { +auto normalize_enum(const std::string& in) { + std::string ret; + for (auto&& c : in) { + ret += toupper(c); + } + return ret; +} +} // anonymous namespace + void init_ops(py::module m) { using namespace mgb::imperative; - py::class_, OpDef>(m, "OprAttr") - .def(py::init<>()) - .def_readwrite("type", &OprAttr::type) - .def_readwrite("param", &OprAttr::param) - .def_readwrite("config", &OprAttr::config) - .def_property("param", - [](const OprAttr& attr) -> py::bytes { - return std::string(attr.param.begin(), attr.param.end()); - }, - [] (OprAttr& attr, py::bytes data) { - auto s = py::cast(data); - attr.param.clear(); - attr.param.insert(attr.param.end(), s.begin(), s.end()); - }); - py::class_, OpDef>(m, "BackwardGraph") .def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc, const mgb::SmallVector& inputs) { auto f = [pyf](OpDef& op, const mgb::SmallVector& inputs) { - return py::cast>(pyf(op.copy(), inputs)); + return py::cast>(pyf(op.shared_from_this(), inputs)); }; auto c = [pyc](const TensorPtr& tensor) { return pyc(tensor->dev_tensor()); @@ -56,162 +44,8 @@ void init_ops(py::module m) { return self.graph().interpret(f, c, inputs); }); - py::class_, OpDef>(m, "GetVarShape") - .def(py::init()); - -#define V(m) .value(#m, CollectiveComm::Mode::m) - py::enum_(m, "CollectiveCommMode") - V(REDUCE_SUM) - V(BROADCAST) - V(ALL_GATHER) - V(REDUCE_SCATTER_SUM) - V(ALL_REDUCE_SUM) - V(ALL_REDUCE_MAX) - V(ALL_REDUCE_MIN) - V(ALL_REDUCE_PROD) - V(GATHER) - V(SCATTER) - V(ALL_TO_ALL); -#undef V - - py::class_, OpDef>(m, "CollectiveComm") - .def(py::init<>()) - .def_readwrite("key", &CollectiveComm::key) - .def_readwrite("nr_devices", &CollectiveComm::nr_devices) - .def_readwrite("rank", &CollectiveComm::rank) - .def_readwrite("is_root", &CollectiveComm::is_root) - .def_readwrite("local_grad", &CollectiveComm::local_grad) - .def_readwrite("addr", &CollectiveComm::addr) - .def_readwrite("port", &CollectiveComm::port) - .def_readwrite("mode", &CollectiveComm::mode) - .def_readwrite("dtype", &CollectiveComm::dtype) - .def_readwrite("backend", &CollectiveComm::backend) - .def_readwrite("comp_node", &CollectiveComm::comp_node); - - py::class_, OpDef>(m, "RemoteSend") - .def(py::init<>()) - .def_readwrite("key", &RemoteSend::key) - .def_readwrite("addr", &RemoteSend::addr) - .def_readwrite("port", &RemoteSend::port) - .def_readwrite("rank_to", &RemoteSend::rank_to); - - py::class_, OpDef>(m, "RemoteRecv") - .def(py::init<>()) - .def_readwrite("key", &RemoteRecv::key) - .def_readwrite("addr", &RemoteRecv::addr) - .def_readwrite("port", &RemoteRecv::port) - .def_readwrite("rank_from", &RemoteRecv::rank_from) - .def_readwrite("shape", &RemoteRecv::shape) - .def_readwrite("cn", &RemoteRecv::cn) - .def_readwrite("dtype", &RemoteRecv::dtype); - - py::class_, OpDef>(m, "ParamPackSplit") - .def(py::init<>()) - .def_readwrite("offsets", &ParamPackSplit::offsets) - .def_readwrite("shapes", &ParamPackSplit::shapes); - - py::class_, OpDef>(m, "ParamPackConcat") - .def(py::init<>()) - .def_readwrite("offsets", &ParamPackConcat::offsets); - py::class_, OpDef>(m, "VirtualDep") .def(py::init<>()); - py::class_, OpDef>(m, "CondTake") - .def(py::init<>()); - - py::class_, OpDef>(m, "NMSKeep") - .def(py::init()) - .def_readwrite("iou_thresh", &NMSKeep::iou_thresh) - .def_readwrite("max_output", &NMSKeep::max_output); - - py::class_, OpDef> elemwise(m, "Elemwise"); - elemwise.def(py::init()) - .def_readwrite("mode", &Elemwise::mode); - -#define V(m) .value(#m, Elemwise::Mode::m) - py::enum_(elemwise, "Mode") - V(RELU) - V(ABS) - V(ACOS) - V(ASIN) - V(CEIL) - V(COS) - V(EXP) - V(EXPM1) - V(FLOOR) - V(LOG) - V(LOG1P) - V(NEGATE) - V(SIGMOID) - V(SIN) - V(TANH) - V(ABS_GRAD) - V(ADD) - V(FLOOR_DIV) - V(MAX) - V(MIN) - V(MOD) - V(MUL) - V(POW) - V(SIGMOID_GRAD) - V(SUB) - V(SWITCH_GT0) - V(TANH_GRAD) - V(TRUE_DIV) - V(LOG_SUM_EXP) - V(LT) - V(LEQ) - V(EQ) - V(SHL) - V(SHR) - V(COND_LEQ_MOV) - V(FUSE_MUL_ADD3) - V(FUSE_MUL_ADD4) - V(FUSE_ADD_RELU) - V(FUSE_ADD_SIGMOID) - V(FUSE_ADD_TANH) - V(FAST_TANH) - V(FAST_TANH_GRAD) - V(ROUND) - V(RMULH) - V(ATAN2) - V(ERF) - V(ERFINV) - V(ERFC) - V(ERFCINV) - V(H_SWISH) - V(H_SWISH_GRAD) - V(FUSE_ADD_H_SWISH) - V(NOT) - V(AND) - V(OR) - V(XOR); -#undef V - - py::class_, OpDef> batchnorm(m, "BatchNorm"); - batchnorm.def(py::init()) - .def_readwrite("param_dim", &BatchNorm::param_dim) - .def_readwrite("fwd_mode", &BatchNorm::fwd_mode) - .def_readwrite("epsilon", &BatchNorm::epsilon) - .def_readwrite("avg_factor", &BatchNorm::avg_factor) - .def_readwrite("scale", &BatchNorm::scale) - .def_readwrite("bias", &BatchNorm::bias); - -#define V(m) .value(#m, BatchNorm::Param::ParamDim::m) - py::enum_(batchnorm, "ParamDim") - V(DIM_11HW) - V(DIM_1CHW) - V(DIM_1C11); -#undef V - -#define V(m) .value(#m, BatchNorm::Param::FwdMode::m) - py::enum_(batchnorm, "FwdMode") - V(TRAINING) - V(INFERENCE); -#undef V - - py::class_, OpDef>(m, "Broadcast") - .def(py::init<>()); - + #include "opdef.py.inl" } diff --git a/imperative/python/test/unit/core/test_dtype_quant.py b/imperative/python/test/unit/core/test_dtype_quant.py index e06cb297e..36fdfef72 100644 --- a/imperative/python/test/unit/core/test_dtype_quant.py +++ b/imperative/python/test/unit/core/test_dtype_quant.py @@ -113,7 +113,7 @@ def test_quint8_typecvt(): data = np.random.random(shape).astype(np.float32) * 5 - 1 def typecvt(x, dt=None): - (y,) = apply(ops.TypeCvt(param=dt), x) + (y,) = apply(ops.TypeCvt(dtype=dt), x) return y # convert to quint8 @@ -194,7 +194,7 @@ def test_quint4_typecvt(): data = np.random.random(shape).astype(np.float32) * 5 - 1 def typecvt(x, dt=None): - (y,) = apply(ops.TypeCvt(param=dt), x) + (y,) = apply(ops.TypeCvt(dtype=dt), x) return y # convert to quint4 diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py index a336a28a4..6a6880531 100644 --- a/imperative/python/test/unit/core/test_indexing_op.py +++ b/imperative/python/test/unit/core/test_indexing_op.py @@ -11,10 +11,9 @@ import collections import numpy as np import pytest -import megengine.core.ops.builtin import megengine.core.tensor.raw_tensor from megengine.core._trace_option import use_symbolic_shape -from megengine.core.ops._internal import all_ops +from megengine.core.ops import builtin from megengine.core.tensor import Tensor from megengine.core.tensor.core import apply from megengine.core.tensor.raw_tensor import RawTensor, as_raw_tensor @@ -105,7 +104,7 @@ def canonize_inputs(inputs, *, config): need_cvt = False for i in old_inputs: if isinstance(i, RawTensor): - get_comp_node = lambda cn=i.device.to_c(): cn + get_comp_node = lambda cn=i.device: cn else: need_cvt = True inputs.append(i) @@ -193,91 +192,91 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): def transpose(*args, **kwargs): - op = all_ops.Dimshuffle(**kwargs).to_c() + op = builtin.Dimshuffle(**kwargs) return invoke_op(op, args) def broadcast(input, tshape): - op = all_ops.Broadcast().to_c() + op = builtin.Broadcast() return invoke_op(op, (input, tshape), canonize_reshape) def subtensor(input, tuple_val): input, tensors, items = unpack_getitem(input, tuple_val) - op = all_ops.Subtensor(items).to_c() + op = builtin.Subtensor(items) return invoke_op(op, (input, *tensors)) def set_subtensor(input, value, tuple_val): input, tensors, items = unpack_getitem(input, tuple_val) - op = all_ops.SetSubtensor(items).to_c() + op = builtin.SetSubtensor(items) return invoke_op(op, (input, value, *tensors)) def incr_subtensor(input, value, tuple_val): input, tensors, items = unpack_getitem(input, tuple_val) - op = all_ops.IncrSubtensor(items).to_c() + op = builtin.IncrSubtensor(items) return invoke_op(op, (input, value, *tensors)) def advance_indexing(input, tuple_val): input, tensors, items = unpack_getitem(input, tuple_val) - op = all_ops.IndexingMultiAxisVec(items).to_c() + op = builtin.IndexingMultiAxisVec(items) return invoke_op(op, (input, *tensors)) def set_advance_indexing(input, value, tuple_val): input, tensors, items = unpack_getitem(input, tuple_val) - op = all_ops.IndexingSetMultiAxisVec(items).to_c() + op = builtin.IndexingSetMultiAxisVec(items) return invoke_op(op, (input, value, *tensors)) def incr_advance_indexing(input, value, tuple_val): input, tensors, items = unpack_getitem(input, tuple_val) - op = all_ops.IndexingIncrMultiAxisVec(items).to_c() + op = builtin.IndexingIncrMultiAxisVec(items) return invoke_op(op, (input, value, *tensors)) def mesh_indexing(input, tuple_val): input, tensors, items = unpack_getitem(input, tuple_val) - op = all_ops.MeshIndexing(items).to_c() + op = builtin.MeshIndexing(items) return invoke_op(op, (input, *tensors)) def set_mesh_indexing(input, value, tuple_val): input, tensors, items = unpack_getitem(input, tuple_val) - op = all_ops.SetMeshIndexing(items).to_c() + op = builtin.SetMeshIndexing(items) return invoke_op(op, (input, value, *tensors)) def incr_mesh_indexing(input, value, tuple_val): input, tensors, items = unpack_getitem(input, tuple_val) - op = all_ops.IncrMeshIndexing(items).to_c() + op = builtin.IncrMeshIndexing(items) return invoke_op(op, (input, value, *tensors)) def batched_mesh_indexing(input, tuple_val): input, tensors, items = unpack_getitem(input, tuple_val) - op = all_ops.BatchedMeshIndexing(items).to_c() + op = builtin.BatchedMeshIndexing(items) return invoke_op(op, (input, *tensors)) def batched_set_mesh_indexing(input, value, tuple_val): input, tensors, items = unpack_getitem(input, tuple_val) - op = all_ops.BatchedSetMeshIndexing(items).to_c() + op = builtin.BatchedSetMeshIndexing(items) return invoke_op(op, (input, value, *tensors)) def batched_incr_mesh_indexing(input, value, tuple_val): input, tensors, items = unpack_getitem(input, tuple_val) - op = all_ops.BatchedIncrMeshIndexing(items).to_c() + op = builtin.BatchedIncrMeshIndexing(items) return invoke_op(op, (input, value, *tensors)) def test_transpose(): x = np.arange(10).reshape(2, 5).astype("int32") xx = as_raw_tensor(x) - (yy,) = transpose(xx, pattern="1x0") + (yy,) = transpose(xx, pattern=[1, -1, 0]) np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy()) diff --git a/imperative/python/tools/gen_ops.py b/imperative/python/tools/gen_ops.py deleted file mode 100755 index 299b3c19e..000000000 --- a/imperative/python/tools/gen_ops.py +++ /dev/null @@ -1,320 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from io import StringIO -import re -import argparse -import subprocess -import os -import textwrap -import inspect - -def camel2underscore( - name, *, - first_cap_re=re.compile('([A-Z])([A-Z][a-z]+)'), - all_cap_re = re.compile('([a-z])([A-Z]+)')): - if name.isupper(): - return name.lower() - s1 = first_cap_re.sub(r'\1_\2', name) - return all_cap_re.sub(r'\1_\2', s1).lower() - - -def caller_lineno(level=1): - f = inspect.stack()[level+1] - return '%s:%d' % (f.filename, f.lineno) - - -class Doc: - """wrap an identifier and doc""" - _id = None - - def __init__(self, id_, doc, typestr=None, default=None): - self._id = id_ - self.doc = doc - self.typestr = typestr - self.default = default - - def __str__(self): - return self._id - - -class Context: - fout = None - - def __init__(self): - self.fout = StringIO() - self.indent = 0 - self.skipped = [] - self.generated_signature = set() - self.generated_opr = dict() - - def write(self, text, *fmt, indent=0): - text = textwrap.dedent(text) - text = textwrap.indent(text, ' '*4*(self.indent + indent)) - text = text % fmt - if not text.endswith('\n'): - text += '\n' - self.fout.write(text) - - def _gen_signature(self, params, *, have_config=True, - has_out_dtype=False): - sig = ['self', '*'] - - for i, _ in params: - sig.append('{}=None'.format(i)) - - if have_config: - sig.extend(['name=None', 'comp_node=None', 'config=None']) - if has_out_dtype: - sig.append('dtype=None') - - if params: - sig.append('**kwargs') - - if sig[-1] == '*': - sig.pop() - return ', '.join(sig) - - def _write_canonize_inputs(self, inputs, convert_inputs, - convert_inputs_args=None, - has_out_dtype=False): - self._write_gen_config(has_out_dtype) - inputs = list(map(str, inputs)) - if convert_inputs_args is None: - if inputs[0][0] == '*': - arg = inputs[0][1:] - else: - arg = '[{}]'.format(', '.join(inputs)) - else: - arg = convert_inputs_args - self.write('inputs = helper.%s(%s, config=config)', - convert_inputs, arg) - - def _write_gen_config(self, has_out_dtype=False): - self.write('''\ - config = config or Config() - if name: - config.name = name - if comp_node: - config.comp_node = comp_node - ''') - if has_out_dtype: - self.write('''\ - if dtype: - config.dtype = dtype - ''') - self.write('self.config = config') - - def _write_make_params(self, params): - for pname, ptype in params: - self.write('self.%s = helper.make_param(%s, param_defs.%s, kwargs)', - pname, pname, ptype) - self.write('assert not kwargs, "extra kwargs: {}".format(kwargs)') - - def _write_doc(self, inputs, params, desc): - self.write('"""') - if isinstance(desc, Doc): - assert desc._id is None - self.write(desc.doc) - elif desc: - for i in textwrap.wrap(desc, 75): - self.write(i) - - self.write('') - for i in inputs: - name = str(i) - typestr = ':class:`.Tensor`' - if name[0] == '*': - name = name[1:] - typestr = 'list of ' + typestr - if isinstance(i, Doc): - self.write(':param %s: %s', name, i.doc) - if i.typestr is not None: - typestr = i.typestr - if typestr: - if not isinstance(i, Doc): - self.write(':param %s: ', name) - self.write(':type %s: %s', name, typestr) - - for pname, ptype in params: - self.write(':param %s: ', pname) - self.write(':type %s: :class:`~megbrain.opr_param_defs.%s`', - pname, ptype) - - self.write(':param comp_node: see doc for *config*') - self.write(':param name: see doc for *config*') - self.write( - ':param config: give a :class:`.OperatorNodeConfig` object to set ' - 'operator name and comp node. This can also be achieved by passing ' - '*comp_node* and *name* separately.') - - self.write('"""') - - def _write_return(self, name, outputs): - self.write('opdef = helper.PodOpVisitor("%s", config, params)', name) - self.write('outputs = helper.create_op(opdef, inputs)') - if outputs: - self.write('outputs = [outputs[i] for i in %s]', - list(map(int, outputs))) - self.write('return helper.convert_outputs(outputs)') - - def decl_opr(self, name, *, inputs, params, desc=None, pyname=None, - canonize_input_vars=None, - canonize_input_vars_args=None, body=None, - outputs=None, version=0, has_out_dtype=False): - """ - :param inputs: name of variable inputs; a name starting with `*' means - a list of vars - :type inputs: list of str - :param params: (param name, param type) pairs; it can be a single - string representing the param type, and param name defaults to - 'param' - :type params: list of pair of str, or str - :param pyname: python function name - :param body: extra statements to be placed before calling _create_opr - :param outputs: the indices of output vars to be selected from raw opr - result - """ - - class OprItem: - def __init__(self, inputs, desc, params, version, has_out_dtype): - self.inputs = inputs - self.desc = desc - self.params = params - self.version = version - self.has_out_dtype = has_out_dtype - - if body: - self.skipped.append(name) - return - - signature = (name, params if isinstance(params, str) else frozenset(params), has_out_dtype, version) - if signature in self.generated_signature: - self.skipped.append(name) - return - else: - self.generated_signature.add(signature) - - body = body or [] - if isinstance(params, str): - params = [('param', params)] - assert params - - if name in self.generated_opr: - org_opr = self.generated_opr[name] - if version > org_opr.version: - def compare_doc(a, b): - if isinstance(a, str): - return a == b - else: - assert isinstance(a, Doc) - return a.doc == b.doc - - assert compare_doc(desc, org_opr.desc) - assert len(inputs) == len(org_opr.inputs) - for i, j in zip(inputs, org_opr.inputs): - assert compare_doc(i, j) - - self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype) - else: - self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype) - - def write_generated_oprs(self): - - for opr, opr_item in self.generated_opr.items(): - - name = opr - params = opr_item.params - version = opr_item.version - has_out_dtype = opr_item.has_out_dtype - - self.write('# %s', caller_lineno()) - self.write('class %s(PodOpVisitor):', name) - self.indent += 1 - - param_names, _ = zip(*params) - self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names))) - self.write('name = "%s"', '{}V{}'.format(name, version) if version else name) - self.write('\n') - - self.write('def __init__(%s):', - self._gen_signature(params, - has_out_dtype=has_out_dtype)) - self.indent += 1 - - self._write_gen_config(has_out_dtype=has_out_dtype) - self.write('\n') - - self._write_make_params(params) - - self.write('\n') - self.indent -= 2 - - - def decl_raw_opr(self, name, *, inputs, inputs_cvt=[], body=None, - desc=None, local_defs=[], have_config=True, params=None, has_out_dtype=False): - self.skipped.append(name) - - def get_str(self): - return self.fout.getvalue() - - def all_list(self): - buf = StringIO() - print( - '[', - *(' "%s",' % i for i in self.generated_opr), - ']', - sep='\n', - file=buf - ) - return buf.getvalue() - - -def main(): - parser = argparse.ArgumentParser( - description='generate operator function def code from decl file') - parser.add_argument('inputs', nargs='+') - parser.add_argument('--output', '-o') - args = parser.parse_args() - - gen = Context() - exec_globals = { - 'decl_opr': gen.decl_opr, - 'decl_raw_opr': gen.decl_raw_opr, - 'Doc': Doc, - 'camel2underscore': camel2underscore, - } - for i in args.inputs: - print('generate ops from {}'.format(i)) - with open(i) as fin: - exec(compile(fin.read(), i, 'exec'), exec_globals) - - gen.write_generated_oprs() - try: - git_commit = subprocess.check_output( - ['git', 'rev-parse', 'HEAD'], universal_newlines=True, - cwd=os.path.dirname(os.path.realpath(__file__))).strip() - except: - git_commit = 'NOT_A_GIT_REPO' - - def relpath(*args): - d = os.path.dirname(__file__) - return os.path.join(d, *args) - - with open(relpath('ops.tpl.py')) as fin: - with open(args.output, 'w') as fout: - fout.write(fin.read() - .replace('{%all%}', gen.all_list()) - .replace('{%body%}', gen.get_str()) - .replace('{%git_commit%}', git_commit)) - - print('Skipped:') - print(*gen.skipped, sep='\n') - -if __name__ == '__main__': - main() diff --git a/imperative/python/tools/ops.tpl.py b/imperative/python/tools/ops.tpl.py deleted file mode 100644 index f91004b1f..000000000 --- a/imperative/python/tools/ops.tpl.py +++ /dev/null @@ -1,40 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -"""This python module contains functions to apply the operators defined by -megbrain. - -.. note:: - Most of the functions are automatically generated, and their signature have - the form contain a ``param`` argument (or more than one arguments such as - :func:`convolution` that has ``param`` and ``execution_polity``) and also - accept keyword arguments. In such case, it can be called by either - providing a param object of appropriate type, or by passing the arguments - needed by the constructor of param object to the keyword arguments. - Furthermore, for a param that needs an enumeration member, the enum name - can be used to refer to the enum object. - - For example, the following statements are equivalent:: - - elemwise([a, b], mode='max') - elemwise([a, b], mode=opr_param_defs.Elemwise.Mode.MAX) - elemwise([a, b], param=opr_param_defs.Elemwise('max')) -""" - -__git_commit__ = "{%git_commit%}" - -import collections - -from . import helper -from .helper import PodOpVisitor -from . import param_defs -from ..._imperative_rt import OperatorNodeConfig as Config - -__all__ = {%all%} - -{%body%} diff --git a/imperative/src/impl/op_def.cpp b/imperative/src/impl/op_def.cpp index 9f052ac0e..495d2da22 100644 --- a/imperative/src/impl/op_def.cpp +++ b/imperative/src/impl/op_def.cpp @@ -36,7 +36,7 @@ SmallVector OpDef::apply_on_physical_tensor( return def.trait()->apply_on_physical_tensor(def, inputs); } -cg::OperatorNodeBase* OpDef::apply_on_var_node( +VarNodeArray OpDef::apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { return def.trait()->apply_on_var_node(def, inputs); @@ -56,6 +56,14 @@ BackwardGraphResult OpDef::make_backward_graph( return def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad); } +size_t OpDef::hash() const { + return trait()->hash(*this); +} + +bool OpDef::is_same_st(const Hashable& rhs) const { + return trait()->is_same_st(*this, static_cast(rhs)); +} + const OpTrait* OpDef::trait() const { if (!m_trait) { m_trait = OpTrait::find_by_typeinfo(dyn_typeinfo()); diff --git a/imperative/src/impl/op_trait.cpp b/imperative/src/impl/op_trait.cpp index a6a81ce17..56bffb620 100644 --- a/imperative/src/impl/op_trait.cpp +++ b/imperative/src/impl/op_trait.cpp @@ -23,7 +23,7 @@ namespace detail { struct StaticData { std::list registries; - std::unordered_map name2reg; + std::unordered_map name2reg; std::unordered_map type2reg; }; diff --git a/imperative/src/impl/op_trait.h b/imperative/src/impl/op_trait.h index a539c084a..f2d36a198 100644 --- a/imperative/src/impl/op_trait.h +++ b/imperative/src/impl/op_trait.h @@ -30,6 +30,32 @@ struct OpMeth: public thin_function { return this->Base::operator ()(args...); } }; +template +struct ToVarNodeArray: std::false_type {}; +template<> +struct ToVarNodeArray: std::true_type { + VarNodeArray operator()(const SymbolVar& inp) { + return {inp.node()}; + } +}; +template<> +struct ToVarNodeArray: std::true_type { + VarNodeArray operator()(const SymbolVarArray& inputs) { + return cg::to_var_node_array(inputs); + } +}; +template +struct ToVarNodeArray>: std::true_type { + VarNodeArray operator()(const std::array& inp) { + return cg::to_var_node_array({inp.begin(), inp.end()}); + } +}; +template<> +struct ToVarNodeArray: std::true_type { + VarNodeArray operator()(const cg::OperatorNodeBase* opr) { + return opr->usable_output(); + } +}; } // detail using OpDefMaker = detail::OpMeth< @@ -42,6 +68,8 @@ using InferOutputAttrsFallible = detail::OpMeth< decltype(OpDef::infer_output_attrs_fallible)>; using GradMaker = detail::OpMeth< decltype(OpDef::make_backward_graph)>; +using HashFunc = detail::OpMeth; +using IsSame = detail::OpMeth; struct OpTrait { const char* name; @@ -50,6 +78,8 @@ struct OpTrait { ApplyOnVarNode apply_on_var_node; InferOutputAttrsFallible infer_output_attrs_fallible; GradMaker make_backward_graph; + HashFunc hash; + IsSame is_same_st; OpTrait(const char* name); static OpTrait* find_by_name(const char* name); static OpTrait* find_by_typeinfo(Typeinfo* type); @@ -61,7 +91,9 @@ struct OpTrait { cb(apply_on_physical_tensor) \ cb(apply_on_var_node) \ cb(infer_output_attrs_fallible) \ - cb(make_backward_graph) + cb(make_backward_graph) \ + cb(hash) \ + cb(is_same_st) struct OpTraitRegistry { OpTrait* trait; @@ -97,6 +129,15 @@ struct OpTraitRegistry { void do_insert(Typeinfo* type); static OpTraitRegistry do_insert(const char* name); + + template, + typename = std::enable_if_t> + OpTraitRegistry& apply_on_var_node(T (*f)(const OpDef&, const VarNodeArray&)) { + return apply_on_var_node([=](const OpDef& opdef, const VarNodeArray& inputs) { + return To()(f(opdef, inputs)); + }); + } }; } // namespace imperative diff --git a/imperative/src/impl/ops/autogen.cpp b/imperative/src/impl/ops/autogen.cpp new file mode 100644 index 000000000..75e00ff1c --- /dev/null +++ b/imperative/src/impl/ops/autogen.cpp @@ -0,0 +1,46 @@ +/** + * \file imperative/src/impl/ops/autogen.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/imperative/ops/autogen.h" + +#include "../op_trait.h" + +using namespace megdnn; + +// FIXME: remove this when mgb::hash support tuple_hash +namespace mgb { +namespace { +template +auto tail(T t, std::index_sequence) { + return std::make_tuple(std::get(t)...); +} +} // anonymous namespace +template +class HashTrait> { + constexpr static size_t length = sizeof...(Args); +public: + static size_t eval(const std::tuple &t) { + const T& val = std::get<0>(t); + if constexpr (!length) { + return mgb::hash(val); + } else { + return mgb::hash_pair_combine(mgb::hash(val), + mgb::hash(tail(t, std::make_index_sequence{}))); + } + } +}; +} // namespace mgb + +namespace mgb::imperative { + +#include "./opdef.cpp.inl" + +} // namespace mgb::imperative \ No newline at end of file diff --git a/imperative/src/impl/ops/batch_norm.cpp b/imperative/src/impl/ops/batch_norm.cpp index 913dc3b07..2ca8c7602 100644 --- a/imperative/src/impl/ops/batch_norm.cpp +++ b/imperative/src/impl/ops/batch_norm.cpp @@ -9,7 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "megbrain/imperative/ops/batch_norm.h" +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/dnn/batch_norm.h" #include "../op_trait.h" namespace mgb { @@ -19,9 +20,7 @@ namespace { std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { auto* node = &node_->cast_final_safe(); - auto&& param = node->param(); - return BatchNorm::make(param.param_dim, param.fwd_mode, param.epsilon, - param.avg_factor, param.scale, param.bias); + return BatchNorm::make(node->param()); } cg::OperatorNodeBase* apply_on_var_node( @@ -33,13 +32,11 @@ cg::OperatorNodeBase* apply_on_var_node( "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp); if (nr_inp == 3) { return opr::BatchNorm::make( - inputs[0], inputs[1], inputs[2], - {bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0] + inputs[0], inputs[1], inputs[2], bn_opr.param())[0] .node()->owner_opr(); } else { return opr::BatchNorm::make( - inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], - {bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0] + inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param())[0] .node()->owner_opr(); } } @@ -52,7 +49,7 @@ std::tuple, bool> infer_output_attrs_fallible( mgb_assert(nr_inp == 3 ||nr_inp == 5, "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp); // need running mean/variance - bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::Param::FwdMode::TRAINING; + bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::FwdMode::TRAINING; size_t nr_out = need_stat? 5 : 3; SmallVector out_shapes(nr_out); auto&& i0 = inputs[0]; @@ -76,8 +73,6 @@ OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm) .fallback(); } // anonymous namespace -MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNorm); - } // namespace imperative } // namespace mgb diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp index 1a2bb4627..48140e17e 100644 --- a/imperative/src/impl/ops/broadcast.cpp +++ b/imperative/src/impl/ops/broadcast.cpp @@ -9,7 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "megbrain/imperative/ops/broadcast.h" +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/tensor_manip.h" + #include "../op_trait.h" namespace mgb { @@ -87,8 +89,6 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) .fallback(); } // anonymous namespace -MGB_DYN_TYPE_OBJ_FINAL_IMPL(Broadcast); - } // namespace imperative } // namespace mgb diff --git a/imperative/src/impl/ops/collective_comm.cpp b/imperative/src/impl/ops/collective_comm.cpp index 68a97bbb5..df83feff2 100644 --- a/imperative/src/impl/ops/collective_comm.cpp +++ b/imperative/src/impl/ops/collective_comm.cpp @@ -18,7 +18,7 @@ #include "megbrain/utils/hash.h" #endif // MGB_ENABLE_OPR_MM -#include "megbrain/imperative/ops/collective_comm.h" +#include "megbrain/imperative/ops/autogen.h" namespace mgb { namespace imperative { @@ -61,8 +61,8 @@ std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node) { auto [addr, port] = split_address(group_client->get_addr()); auto comp_node = node->config().get_single_comp_node().to_string_logical(); return std::make_shared( - comm.key(), comm.nr_devices(), comm.rank(), comm.is_root(), - comm.local_grad(), addr, std::stoi(port), comm.param().mode, + comm.param().mode, comm.key(), comm.nr_devices(), comm.rank(), + comm.is_root(), comm.local_grad(), addr, std::stoi(port), comm.dtype(), comm.backend(), comp_node); } @@ -73,35 +73,6 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm) } // anonymous namespace #endif // MGB_ENABLE_OPR_MM -bool CollectiveComm::is_same_st(const Hashable& another) const{ - auto* comm_opr = another.try_cast_final(); - if(!comm_opr){ - return false; - } - return as_tuple() == comm_opr->as_tuple(); -} - -size_t CollectiveComm::hash() const{ - XXHash xxhash{}; - auto append = [&xxhash](auto field){ - auto hash_val = HashTrait::eval(field); - xxhash.update(reinterpret_cast(&hash_val), sizeof(hash_val)); - }; - append(key); - append(nr_devices); - append(rank); - append(is_root); - append(local_grad); - append(addr); - append(port); - append(mode); - append(backend); - append(comp_node); - return xxhash.digest(); -} - -MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm); - } // namespace imperative } // namespace mgb diff --git a/imperative/src/impl/ops/cond_take.cpp b/imperative/src/impl/ops/cond_take.cpp index cc49e5c23..3fa3643a3 100644 --- a/imperative/src/impl/ops/cond_take.cpp +++ b/imperative/src/impl/ops/cond_take.cpp @@ -9,8 +9,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "megbrain/imperative/ops/cond_take.h" -#include "megbrain/imperative/ops/opr_attr.h" +#include "megbrain/imperative/ops/autogen.h" #include "megbrain/opr/misc.h" #include "../dnn_op_helper.h" #include "../op_trait.h" @@ -19,8 +18,6 @@ using namespace megdnn; namespace mgb::imperative { -MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake); - namespace { class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy { diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp index 8b789b490..40c6aeab7 100644 --- a/imperative/src/impl/ops/elemwise.cpp +++ b/imperative/src/impl/ops/elemwise.cpp @@ -9,7 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "megbrain/imperative/ops/elemwise.h" +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/basic_arith.h" + #include "../op_trait.h" namespace mgb { @@ -33,7 +35,7 @@ std::tuple, bool> infer_output_attrs_fallible( const OpDef& def, const SmallVector& inputs) { auto&& op_def = def.cast_final_safe(); - auto trait = Elemwise::ModeTrait::from_mode(op_def.mode); + auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode); mgb_assert(inputs.size() == trait.arity, "%s expects %u inputs; got %zu actually", trait.name, trait.arity, inputs.size()); @@ -70,8 +72,6 @@ OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise) .fallback(); } // anonymous namespace -MGB_DYN_TYPE_OBJ_FINAL_IMPL(Elemwise); - } // namespace imperative } // namespace mgb diff --git a/imperative/src/impl/ops/io_remote.cpp b/imperative/src/impl/ops/io_remote.cpp index 096082922..86eb59bb5 100644 --- a/imperative/src/impl/ops/io_remote.cpp +++ b/imperative/src/impl/ops/io_remote.cpp @@ -18,7 +18,7 @@ #include "megbrain/opr/mm_handler.h" #endif // MGB_ENABLE_OPR_MM -#include "megbrain/imperative/ops/io_remote.h" +#include "megbrain/imperative/ops/autogen.h" namespace mgb { namespace imperative { @@ -60,45 +60,5 @@ OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) } // anonymous namespace #endif // MGB_ENABLE_OPR_MM -bool RemoteSend::is_same_st(const Hashable& another) const{ - return as_tuple() == another.cast_final().as_tuple(); -} - -size_t RemoteSend::hash() const{ - XXHash xxhash; - auto append = [&xxhash](auto field){ - auto hash_val = HashTrait::eval(field); - xxhash.update(reinterpret_cast(&hash_val), sizeof(hash_val)); - }; - append(key); - append(addr); - append(port); - append(rank_to); - return xxhash.digest(); -} - -bool RemoteRecv::is_same_st(const Hashable& another) const{ - return as_tuple() == another.cast_final().as_tuple(); -} - -size_t RemoteRecv::hash() const{ - XXHash xxhash; - auto append = [&xxhash](auto field){ - auto hash_val = HashTrait::eval(field); - xxhash.update(reinterpret_cast(&hash_val), sizeof(hash_val)); - }; - append(key); - append(addr); - append(port); - append(rank_from); - append(cn.to_string()); - append(dtype.handle()); - append(shape.to_string()); - return xxhash.digest(); -} - -MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend); -MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv); - } // namespace imperative } // namespace mgb diff --git a/imperative/src/impl/ops/nms.cpp b/imperative/src/impl/ops/nms.cpp index 45550a68b..ff78f1220 100644 --- a/imperative/src/impl/ops/nms.cpp +++ b/imperative/src/impl/ops/nms.cpp @@ -11,7 +11,7 @@ #include "../op_trait.h" -#include "megbrain/imperative/ops/nms.h" +#include "megbrain/imperative/ops/autogen.h" #include "megbrain/opr/standalone/nms_opr.h" namespace mgb { @@ -37,8 +37,6 @@ OP_TRAIT_REG(NMSKeep, NMSKeep, NMSKeepOpr) .fallback(); } // anonymous namespace -MGB_DYN_TYPE_OBJ_FINAL_IMPL(NMSKeep); - } // namespace imperative } // namespace mgb diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp new file mode 100644 index 000000000..3d88826f3 --- /dev/null +++ b/imperative/src/impl/ops/specializations.cpp @@ -0,0 +1,630 @@ +/** + * \file imperative/src/impl/ops/autogen.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +// FIXME: split this file into separate files for each specialized op + +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/dnn/convolution.h" +#include "megbrain/opr/dnn/adaptive_pooling.h" +#include "megbrain/opr/dnn/fake_quant.h" +#include "megbrain/opr/dnn/pooling.h" +#include "megbrain/opr/dnn/local.h" +#include "megbrain/opr/dnn/roi_align.h" +#include "megbrain/opr/dnn/roi_pooling.h" +#include "megbrain/opr/basic_arith.h" +#include "megbrain/opr/blas.h" +#include "megbrain/opr/imgproc.h" +#include "megbrain/opr/indexing.h" +#include "megbrain/opr/io.h" +#include "megbrain/opr/misc.h" +#include "megbrain/opr/nn_int.h" +#include "megbrain/opr/rand.h" +#include "megbrain/opr/tensor_gen.h" +#include "megbrain/opr/tensor_manip.h" +#include "megbrain/opr/utility.h" + +#include "../op_trait.h" + +namespace mgb::imperative { + +namespace { namespace convolution { +std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { + auto* node = &node_->cast_final_safe(); + return Convolution::make(node->param(), node->execution_policy()); +} + +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& conv = static_cast(def); + return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy()); +} + +OP_TRAIT_REG(Convolution, Convolution, opr::Convolution) + .make_from_op_node(make_from_op_node) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // convolution + +namespace { namespace convolution_backward_data { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& conv = static_cast(def); + cg::OperatorNodeConfig config; + if (inputs.size() == 2) { + return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); + } else { + mgb_assert(inputs.size() == 3); + return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); + } +} + +OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // convolution_backward_data + +namespace { namespace dimshuffle { +std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { + auto* node = &node_->cast_final_safe(); + std::vector pattern(node->param().pattern_len); + for (size_t i = 0; i < node->param().pattern_len; ++ i) { + pattern[i] = node->param().pattern[i]; + } + return Dimshuffle::make(pattern); +} + +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& ds = static_cast(def); + return opr::Dimshuffle::make(inputs[0], ds.pattern); +} + +OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle) + .make_from_op_node(make_from_op_node) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // dimshuffle + +namespace { namespace add_axis { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& add_axis = static_cast(def); + using Desc = opr::AxisAddRemove::AxisDesc; + std::vector param; + for (auto&& i : add_axis.axis) { + param.push_back(Desc::make_add(i)); + } + return opr::AxisAddRemove::make(inputs[0], param); +} + +OP_TRAIT_REG(AddAxis, AddAxis) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // add_axis + +namespace { namespace remove_axis { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& remove_axis = static_cast(def); + using Desc = opr::AxisAddRemove::AxisDesc; + std::vector param; + for (auto&& i : remove_axis.axis) { + param.push_back(Desc::make_remove(i)); + } + return opr::AxisAddRemove::make(inputs[0], param); +} + +OP_TRAIT_REG(RemoveAxis, RemoveAxis) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // remove_axis + +namespace { namespace top_k { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& topk = static_cast(def); + return opr::TopK::make(inputs[0], inputs[1], topk.param())[0] + .node()->owner_opr(); +} + +OP_TRAIT_REG(TopK, TopK) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // top_k + +namespace { namespace reduce { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& reduce = static_cast(def); + if (inputs.size() > 1) { + return opr::Reduce::make(inputs[0], reduce.param(), inputs[1]); + } else { + return opr::Reduce::make(inputs[0], reduce.param()); + } +} + +OP_TRAIT_REG(Reduce, Reduce) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // reduce + +namespace { namespace adaptive_pooling { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& pool = static_cast(def); + return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param()); +} + +OP_TRAIT_REG(AdaptivePooling, AdaptivePooling) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // adaptive_pooling + +namespace { namespace conv_bias { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& conv = static_cast(def); + cg::OperatorNodeConfig config{conv.dtype}; + if (inputs.size() == 2) { + return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); + } else if (inputs.size() == 3) { + return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); + } else if (inputs.size() == 4) { + return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config); + } + mgb_assert(0); +} + +OP_TRAIT_REG(ConvBias, ConvBias) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // conv_bias + +namespace { namespace batch_conv_bias { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& conv = static_cast(def); + cg::OperatorNodeConfig config{conv.dtype}; + if (inputs.size() == 2) { + return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); + } else if (inputs.size() == 3) { + return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); + } else if (inputs.size() == 4) { + return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config); + } + mgb_assert(0); +} + +OP_TRAIT_REG(BatchConvBias, BatchConvBias) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // batch_conv_bias + +namespace { namespace pooling { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& pool = static_cast(def); + return opr::Pooling::make(inputs[0], pool.param()); +} +OP_TRAIT_REG(Pooling, Pooling) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // pooling + +namespace { namespace matrix_mul { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& matmul = static_cast(def); + mgb_assert(inputs.size() == 2); + return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param()); +} +OP_TRAIT_REG(MatrixMul, MatrixMul) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // matrix_mul + +namespace { namespace batched_matrix_mul { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& matmul = static_cast(def); + mgb_assert(inputs.size() == 2); + return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param()); +} +OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // batched_matrix_mul + +namespace { namespace dot { +auto apply_on_var_node( + const OpDef&, + const VarNodeArray& inputs) { + mgb_assert(inputs.size() == 2); + return opr::Dot::make(inputs[0], inputs[1]); +} +OP_TRAIT_REG(Dot, Dot) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // dot + +namespace { namespace argsort { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& argsort = static_cast(def); + return opr::Argsort::make(inputs[0], argsort.param()); +} +OP_TRAIT_REG(Argsort, Argsort) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // argsort + +namespace { namespace argmax { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& argmax = static_cast(def); + return opr::Argmax::make(inputs[0], argmax.param()); +} +OP_TRAIT_REG(Argmax, Argmax) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // argmax + +namespace { namespace argmin { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& argmin = static_cast(def); + return opr::Argmin::make(inputs[0], argmin.param()); +} +OP_TRAIT_REG(Argmin, Argmin) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // argmin + +namespace { namespace warp_perspective { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& warp = static_cast(def); + if (inputs.size() == 3) { + return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param()); + } else { + mgb_assert(inputs.size() == 4); + return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], inputs[3], warp.param()); + } +} +OP_TRAIT_REG(WarpPerspective, WarpPerspective) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // warp_perspective + +namespace { namespace group_local { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& local = static_cast(def); + mgb_assert(inputs.size() == 2); + return opr::GroupLocal::make(inputs[0], inputs[1], local.param()); +} +OP_TRAIT_REG(GroupLocal, GroupLocal) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // group_local + +namespace { namespace indexing_one_hot { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 2); + return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param()); +} +OP_TRAIT_REG(IndexingOneHot, IndexingOneHot) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // indexing_one_hot + +namespace { namespace indexing_set_one_hot { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 3); + return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param()); +} +OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // indexing_set_one_hot + +namespace { namespace typecvt { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 1); + return opr::TypeCvt::make(inputs[0], op.dtype); +} +OP_TRAIT_REG(TypeCvt, TypeCvt) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // typecvt + +namespace { namespace concat { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + cg::OperatorNodeConfig config{op.comp_node}; + return opr::Concat::make(inputs, op.axis, config); +} +OP_TRAIT_REG(Concat, Concat) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // concat + +namespace { namespace copy { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 1); + cg::OperatorNodeConfig config{op.comp_node}; + return opr::Copy::make(inputs[0], config); +} +OP_TRAIT_REG(Copy, Copy) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // copy + +namespace { namespace identity { +auto apply_on_var_node( + const OpDef&, + const VarNodeArray& inputs) { + mgb_assert(inputs.size() == 1); + return opr::Identity::make(inputs[0]); +} +OP_TRAIT_REG(Identity, Identity) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // identity + +namespace { namespace uniform_rng { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 1); + return opr::UniformRNG::make(inputs[0], op.param()); +} +OP_TRAIT_REG(UniformRNG, UniformRNG) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // uniform_rng + +namespace { namespace gaussian_rng { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 1); + return opr::GaussianRNG::make(inputs[0], op.param()); +} +OP_TRAIT_REG(GaussianRNG, GaussianRNG) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // gaussian_rng + +namespace { namespace roi_align { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 2); + return opr::ROIAlign::make(inputs[0], inputs[1], op.param()); +} +OP_TRAIT_REG(ROIAlign, ROIAlign) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // roi_align + +#if MGB_CUDA +namespace { namespace nvof { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 1); + return opr::NvOf::make(inputs[0], op.param()); +} +OP_TRAIT_REG(NvOf, NvOf) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // nvof +#endif + +namespace { namespace linspace { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 3); + cg::OperatorNodeConfig config{op.comp_node}; + return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config); +} +OP_TRAIT_REG(Linspace, Linspace) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // linspace + +namespace { namespace eye { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 1); + cg::OperatorNodeConfig config{op.comp_node}; + opr::Eye::Param param{op.k, op.dtype.enumv()}; + return opr::Eye::make(inputs[0], param, config); +} +OP_TRAIT_REG(Eye, Eye) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // eye + +namespace { namespace roi_pooling { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 3); + return opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()); +} +OP_TRAIT_REG(ROIPooling, ROIPooling) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // roi_pooling + +namespace { namespace remap { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 2); + return opr::Remap::make(inputs[0], inputs[1], op.param()); +} +OP_TRAIT_REG(Remap, Remap) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // remap + +namespace { namespace reshape { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 2); + return opr::Reshape::make(inputs[0], inputs[1], op.param()); +} +OP_TRAIT_REG(Reshape, Reshape) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // reshape + +namespace { +auto get_index( + const VarNodeArray& inputs, size_t vidx, + const std::vector>& mask) { + size_t length = mask.size(); + opr::Subtensor::IndexDesc ret(length); + for (size_t i = 0; i < length; ++ i) { + auto&& [axis, begin, end, step, idx] = mask[i]; + ret[i].axis = axis; + if (idx) { + ret[i].idx = inputs[vidx++]; + } else { + mgb_assert(begin || end || step); + if (begin) ret[i].begin = inputs[vidx++]; + if (end) ret[i].end = inputs[vidx++]; + if (step) ret[i].step = inputs[vidx++]; + } + } + mgb_assert(vidx == inputs.size()); + return ret; +} +#define IN1 inputs[0] +#define IN2 inputs[0], inputs[1] + +#define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \ +namespace NAME##_impl { \ +auto apply_on_var_node( \ + const OpDef& def, \ + const VarNodeArray& inputs) { \ + auto&& op = static_cast(def); \ + return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items)); \ +} \ +OP_TRAIT_REG(NAME, NAME) \ + .apply_on_var_node(apply_on_var_node) \ + .fallback(); \ +} + +FANCY_INDEXING_IMPL(Subtensor, 1) +FANCY_INDEXING_IMPL(SetSubtensor, 2) +FANCY_INDEXING_IMPL(IncrSubtensor, 2) +FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1) +FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2) +FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2) +FANCY_INDEXING_IMPL(MeshIndexing, 1) +FANCY_INDEXING_IMPL(IncrMeshIndexing, 2) +FANCY_INDEXING_IMPL(SetMeshIndexing, 2) +FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1) +FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2) +FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2) + +#undef FANCY_INDEXING_IMPL +#undef IN1 +#undef IN2 +} // anonymous namespace + +namespace { namespace fake_quant { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 3); + return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param()); +} +OP_TRAIT_REG(FakeQuant, FakeQuant) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // fake_quant +namespace { namespace elemwise_multi_type { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + OperatorNodeConfig config{op.dtype}; + return opr::ElemwiseMultiType::make(inputs, op.param(), config); +} +OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // fake_quant + +namespace { namespace svd { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 1); + return opr::SVD::make(inputs[0], op.param()); +} +OP_TRAIT_REG(SVD, SVD) + .apply_on_var_node(apply_on_var_node) + .fallback(); +}} // svd + +} // namespace mgb::imperative \ No newline at end of file diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp index a4d23de37..7123e13c6 100644 --- a/imperative/src/impl/ops/tensor_manip.cpp +++ b/imperative/src/impl/ops/tensor_manip.cpp @@ -9,7 +9,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "megbrain/imperative/ops/tensor_manip.h" +#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/opr/tensor_manip.h" #include "../op_trait.h" @@ -140,8 +140,4 @@ OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat) .fallback(); } // namespace -MGB_DYN_TYPE_OBJ_FINAL_IMPL(GetVarShape); -MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit); -MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackConcat); - } // namespace mgb::imperative diff --git a/imperative/src/impl/profiler.cpp b/imperative/src/impl/profiler.cpp index 65deeb96f..2ab3aa173 100644 --- a/imperative/src/impl/profiler.cpp +++ b/imperative/src/impl/profiler.cpp @@ -130,7 +130,7 @@ void Profiler::start(uint32_t flags) { // TODO: assign parent entry.parent = 0; // Record apply context and save to m_profile - entry.op = def.copy(); + entry.op = const_cast(def).shared_from_this(); for (auto&& input : inputs) { entry.inputs.push_back({m_tensor_recorder.record_tensor(input), shape2vector(input->layout()), @@ -172,31 +172,31 @@ void Profiler::start(uint32_t flags) { if (flags & PROFILE_FOOTPRINT) { hook_apply_on_var_node->apply_hook( [this](auto&& apply, const OpDef& def, - VarNodeArray inputs) -> cg::OperatorNodeBase* { - auto* operator_node = apply(def, std::move(inputs)); + VarNodeArray inputs) -> VarNodeArray { + auto vars = apply(def, std::move(inputs)); std::remove_reference_t top; { MGB_LOCK_GUARD(m_lock); if (m_entry_stack.empty()) { - return operator_node; + return vars; } top = m_entry_stack.top(); } auto [current_op, current_entry, thread_id] = top; if (current_op != &def || thread_id != std::this_thread::get_id()) { - return operator_node; + return vars; } auto&& footprint_result = - footprint.calc_footprint(operator_node); + footprint.calc_footprint(vars[0]->owner_opr()); current_entry->memory = footprint_result.memory; current_entry->computation = footprint_result.computation; #if MGB_ENABLE_JSON current_entry->param = footprint_result.param; #endif - return operator_node; + return vars; }); } m_hooker_list.push_back(std::move(hook_apply_on_physical_tensor)); diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp index 85fc8f39b..c48d8f571 100644 --- a/imperative/src/impl/proxy_graph.cpp +++ b/imperative/src/impl/proxy_graph.cpp @@ -590,7 +590,7 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr( for (size_t i = 0; i < inputs.size(); ++ i) { vinputs[i] = InputPlaceholder::make(*m_graph, *inputs[i]).node(); } - auto opr = OpDef::apply_on_var_node(opdef, vinputs); + auto opr = OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr(); mgb_assert(!opr->same_type()); for (auto &&i : opr->input()) { mgb_assert(i->owner_opr()->same_type()); @@ -639,7 +639,7 @@ ProxyGraph::make_backward_graph( return ret.first->second; }; auto inputs = make_input_place_holders(input_descs); - auto fwd = OpDef::apply_on_var_node(opdef, inputs); + auto fwd = OpDef::apply_on_var_node(opdef, inputs)[0]->owner_opr(); auto&& outputs = fwd->usable_output(); SmallVector output_descs; for (auto&& i : outputs) { @@ -799,7 +799,7 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(const OpDef& opdef, const SmallVector& inputs) { mgb_assert(!m_cur_opr); auto vinputs = make_input_place_holders(inputs); - return OpDef::apply_on_var_node(opdef, vinputs); + return OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr(); } VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector& inputs) { diff --git a/imperative/src/include/megbrain/imperative/op_def.h b/imperative/src/include/megbrain/imperative/op_def.h index d57f7edba..5e122f9c8 100644 --- a/imperative/src/include/megbrain/imperative/op_def.h +++ b/imperative/src/include/megbrain/imperative/op_def.h @@ -26,13 +26,12 @@ struct BackwardGraphResult { std::vector input_has_grad; }; -class OpDef : public Hashable { +class OpDef : public Hashable, + public std::enable_shared_from_this { mutable const OpTrait* m_trait = nullptr; public: virtual ~OpDef() = default; - virtual std::shared_ptr copy() const = 0; - static std::shared_ptr make_from_op_node( cg::OperatorNodeBase* node); @@ -40,7 +39,7 @@ public: const OpDef& def, const SmallVector& inputs); - static cg::OperatorNodeBase* apply_on_var_node( + static cg::VarNodeArray apply_on_var_node( const OpDef& def, const VarNodeArray& inputs); @@ -56,25 +55,17 @@ public: const OpTrait* trait() const; - virtual size_t hash() const { - mgb_throw(MegBrainError, "not implemented"); - } + virtual size_t hash() const; - virtual bool is_same_st(const Hashable&) const { - mgb_throw(MegBrainError, "not implemented"); - } + virtual bool is_same_st(const Hashable&) const; }; template class OpDefImplBase : public OpDef { public: - virtual std::shared_ptr copy() const override { - return std::shared_ptr(new T(this->cast_final_safe())); - } - template - static std::shared_ptr make(const Args& ...args) { - return std::shared_ptr(new T(args...)); + static std::shared_ptr make(Args&& ...args) { + return std::make_shared(std::forward(args)...); } }; diff --git a/imperative/src/include/megbrain/imperative/ops/cond_take.h b/imperative/src/include/megbrain/imperative/ops/autogen.h similarity index 54% rename from imperative/src/include/megbrain/imperative/ops/cond_take.h rename to imperative/src/include/megbrain/imperative/ops/autogen.h index bed3465ce..a17b495d5 100644 --- a/imperative/src/include/megbrain/imperative/ops/cond_take.h +++ b/imperative/src/include/megbrain/imperative/ops/autogen.h @@ -1,5 +1,5 @@ /** - * \file imperative/src/include/megbrain/imperative/ops/cond_take.h + * \file imperative/src/include/megbrain/imperative/ops/autogen.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -12,22 +12,15 @@ #pragma once #include "megbrain/imperative/op_def.h" +#include "megdnn/opr_param_defs.h" +#include "megbrain/opr/param_defs.h" -namespace mgb::imperative { - -class CondTake : public OpDefImplBase { - MGB_DYN_TYPE_OBJ_FINAL_DECL; -public: - CondTake() = default; +#include "megbrain/utils/hash.h" - size_t hash() const override { - return reinterpret_cast(dyn_typeinfo()); - } - - bool is_same_st(const Hashable& rhs) const override { - return rhs.dyn_typeinfo() == dyn_typeinfo(); - } +namespace mgb::imperative { -}; +// TODO: split into separate files to avoid recompiling all +// impl/ops/*.cpp on each modification of ops.td +#include "./opdef.h.inl" } // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/ops/batch_norm.h b/imperative/src/include/megbrain/imperative/ops/batch_norm.h deleted file mode 100644 index 0fc2fb3e7..000000000 --- a/imperative/src/include/megbrain/imperative/ops/batch_norm.h +++ /dev/null @@ -1,70 +0,0 @@ -/** - * \file imperative/src/include/megbrain/imperative/ops/batch_norm.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - */ - -#pragma once - -#include "megbrain/opr/dnn/batch_norm.h" -#include "megbrain/imperative/op_def.h" -#include "megbrain/utils/hash.h" - -namespace mgb::imperative { - -class BatchNorm : public OpDefImplBase { - MGB_DYN_TYPE_OBJ_FINAL_DECL; -public: - using Param = opr::BatchNorm::Param; - - Param::ParamDim param_dim; - Param::FwdMode fwd_mode; - double epsilon; - double avg_factor; - float scale; - float bias; - - BatchNorm() = default; - - BatchNorm(const Param::ParamDim& param_dim_, const Param::FwdMode& fwd_mode_, - double epsilon_, double avg_factor_, float scale_, float bias_) - : param_dim(param_dim_), - fwd_mode(fwd_mode_), - epsilon(epsilon_), - avg_factor(avg_factor_), - scale(scale_), - bias(bias_) {} - - size_t hash() const override { - XXHash xxhash{}; - auto append = [&xxhash](auto field){ - auto hash_val = HashTrait::eval(field); - xxhash.update(reinterpret_cast(&hash_val), sizeof(hash_val)); - }; - append(param_dim); - append(fwd_mode); - append(epsilon); - append(avg_factor); - append(scale); - append(bias); - return xxhash.digest(); - } - - bool is_same_st(const Hashable& rhs_) const override { - auto&& rhs = static_cast(rhs_); - return rhs.param_dim == param_dim - && rhs.fwd_mode == fwd_mode - && rhs.epsilon == epsilon - && rhs.avg_factor == avg_factor - && rhs.scale == scale - && rhs.bias == bias; - } - -}; - -} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/ops/broadcast.h b/imperative/src/include/megbrain/imperative/ops/broadcast.h deleted file mode 100644 index 1c2f7075d..000000000 --- a/imperative/src/include/megbrain/imperative/ops/broadcast.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * \file imperative/src/include/megbrain/imperative/ops/broadcast.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - */ - -#pragma once - -#include "megbrain/opr/tensor_manip.h" -#include "megbrain/imperative/ops/opr_attr.h" -#include "megbrain/imperative/op_def.h" - -namespace mgb::imperative { - -class Broadcast : public OpDefImplBase { - MGB_DYN_TYPE_OBJ_FINAL_DECL; -public: - Broadcast() = default; - - size_t hash() const override { - return reinterpret_cast(dyn_typeinfo()); - } - - bool is_same_st(const Hashable& rhs) const override { - return true; - } - -}; - -} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/ops/collective_comm.h b/imperative/src/include/megbrain/imperative/ops/collective_comm.h deleted file mode 100644 index f45fff652..000000000 --- a/imperative/src/include/megbrain/imperative/ops/collective_comm.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * \file imperative/src/include/megbrain/imperative/ops/collective_comm.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - */ - -#pragma once - -#include "megbrain/imperative/op_def.h" -#include "megbrain/opr/param_defs.h" - -namespace mgb { -namespace imperative { - -class CollectiveComm : public OpDefImplBase { - MGB_DYN_TYPE_OBJ_FINAL_DECL; - -public: - using Mode = megdnn::param::CollectiveComm::Mode; - - CollectiveComm() = default; - CollectiveComm(const std::string& key_, size_t nr_devices_, - uint32_t rank_, bool is_root_, bool local_grad_, - const std::string& addr_, uint32_t port_, - const Mode& mode_, - const DType& dtype_, const std::string& backend_, - const std::string& comp_node_) - : key(key_), - nr_devices(nr_devices_), - rank(rank_), - is_root(is_root_), - local_grad(local_grad_), - addr(addr_), - port(port_), - mode(mode_), - dtype(dtype_), - backend(backend_), - comp_node(comp_node_) {} - std::string key; - size_t nr_devices; - uint32_t rank; - bool is_root; - bool local_grad; - std::string addr; - uint32_t port; - Mode mode; - DType dtype; - std::string backend; - std::string comp_node; - - size_t hash() const override; - - bool is_same_st(const Hashable& another) const override; - auto as_tuple() const{ - return std::tuple(key, nr_devices, rank, is_root, - local_grad, addr, port, mode, dtype, - backend, comp_node); - } -}; - -} // namespace imperative -} // namespace mgb - -// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/imperative/src/include/megbrain/imperative/ops/elemwise.h b/imperative/src/include/megbrain/imperative/ops/elemwise.h deleted file mode 100644 index 5878f08fa..000000000 --- a/imperative/src/include/megbrain/imperative/ops/elemwise.h +++ /dev/null @@ -1,42 +0,0 @@ -/** - * \file imperative/src/include/megbrain/imperative/ops/elemwise.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - */ - -#pragma once - -#include "megbrain/opr/basic_arith.h" -#include "megbrain/imperative/op_def.h" - -namespace mgb::imperative { - -class Elemwise : public OpDefImplBase { - MGB_DYN_TYPE_OBJ_FINAL_DECL; -public: - using Mode = opr::Elemwise::Mode; - using ModeTrait = megdnn::Elemwise::ModeTrait; - - Mode mode; - - Elemwise() = default; - - Elemwise(const Mode& mode_): mode(mode_) {} - - size_t hash() const override { - return hash_pair_combine(mgb::hash(mode), reinterpret_cast(dyn_typeinfo())); - } - - bool is_same_st(const Hashable& rhs_) const override { - auto&& rhs = static_cast(rhs_); - return rhs.mode == mode; - } - -}; - -} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/ops/io_remote.h b/imperative/src/include/megbrain/imperative/ops/io_remote.h deleted file mode 100644 index 9ec6e4f43..000000000 --- a/imperative/src/include/megbrain/imperative/ops/io_remote.h +++ /dev/null @@ -1,77 +0,0 @@ -/** - * \file imperative/src/include/megbrain/imperative/ops/io_remote.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - */ - -#pragma once - -#include "megbrain/imperative/op_def.h" - -namespace mgb { -namespace imperative { - -class RemoteSend : public OpDefImplBase { - MGB_DYN_TYPE_OBJ_FINAL_DECL; - -public: - RemoteSend() = default; - RemoteSend(const std::string& key_, const std::string& addr_, - uint32_t port_, uint32_t rank_to_) - : key(key_), - addr(addr_), - port(port_), - rank_to(rank_to_) {} - std::string key; - std::string addr; - uint32_t port; - uint32_t rank_to; - - size_t hash() const override; - bool is_same_st(const Hashable& another) const override; - - auto as_tuple() const{ - return std::tuple(key, addr, port, rank_to); - } -}; - -class RemoteRecv : public OpDefImplBase { - MGB_DYN_TYPE_OBJ_FINAL_DECL; - -public: - RemoteRecv() = default; - RemoteRecv(const std::string& key_, const std::string& addr_, - uint32_t port_, uint32_t rank_from_, TensorShape shape_, - CompNode cn_, const DType& dtype_) - : key(key_), - addr(addr_), - port(port_), - rank_from(rank_from_), - cn(cn_), - shape(shape_), - dtype(dtype_) {} - std::string key; - std::string addr; - uint32_t port; - uint32_t rank_from; - CompNode cn; - TensorShape shape; - DType dtype; - - size_t hash() const override; - bool is_same_st(const Hashable& another) const override; - - auto as_tuple() const{ - return std::tuple(key, addr, port, rank_from, cn, dtype, shape.to_string()); - } -}; - -} // namespace imperative -} // namespace mgb - -// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/imperative/src/include/megbrain/imperative/ops/nms.h b/imperative/src/include/megbrain/imperative/ops/nms.h deleted file mode 100644 index ed66cd8e6..000000000 --- a/imperative/src/include/megbrain/imperative/ops/nms.h +++ /dev/null @@ -1,41 +0,0 @@ -/** - * \file imperative/src/include/megbrain/imperative/ops/nms.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - */ - -#pragma once - -#include "megbrain/imperative/op_def.h" - -namespace mgb::imperative { - -class NMSKeep : public OpDefImplBase { - MGB_DYN_TYPE_OBJ_FINAL_DECL; -public: - float iou_thresh; //!< IoU threshold for overlapping - uint32_t max_output; //!< max number of output boxes per batch - NMSKeep() = default; - NMSKeep(float iou_thresh_, uint32_t max_output_): - iou_thresh(iou_thresh_), max_output(max_output_) {} - - size_t hash() const override { - return hash_pair_combine( - hash_pair_combine(mgb::hash(iou_thresh), mgb::hash(max_output)), - reinterpret_cast(dyn_typeinfo())); - } - - bool is_same_st(const Hashable& rhs_) const override { - auto&& rhs = static_cast(rhs_); - return rhs.iou_thresh == iou_thresh - && rhs.max_output == max_output; - } - -}; - -} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/ops/tensor_manip.h b/imperative/src/include/megbrain/imperative/ops/tensor_manip.h deleted file mode 100644 index 8d3d44d38..000000000 --- a/imperative/src/include/megbrain/imperative/ops/tensor_manip.h +++ /dev/null @@ -1,99 +0,0 @@ -/** - * \file imperative/src/include/megbrain/imperative/ops/tensor_manip.h - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - */ - -#pragma once - -#include "megbrain/imperative/op_def.h" - -#include "megbrain/utils/hash.h" - -namespace mgb::imperative { - -class GetVarShape : public OpDefImplBase { - MGB_DYN_TYPE_OBJ_FINAL_DECL; -public: - GetVarShape() = default; - - size_t hash() const override { - return reinterpret_cast(dyn_typeinfo()); - } - - bool is_same_st(const Hashable& rhs) const override { - return rhs.dyn_typeinfo() == dyn_typeinfo(); - } -}; - -class ParamPackSplit : public OpDefImplBase { - MGB_DYN_TYPE_OBJ_FINAL_DECL; - -public: - ParamPackSplit() = default; - - ParamPackSplit(std::vector& offsets_, - std::vector>& shapes_) - : offsets(offsets_), shapes(shapes_) {} - - std::vector offsets; - std::vector> shapes; - - size_t hash() const override { - XXHash builder; - for (auto&& offset : offsets) { - builder.update(&offset, sizeof(offset)); - } - auto&& offset_cnt = offsets.size(); - builder.update(&offset_cnt, sizeof(offset_cnt)); - for (auto&& shape : shapes) { - for (auto&& dim_len : shape) { - builder.update(&dim_len, sizeof(dim_len)); - } - auto&& dim_cnt = shape.size(); - builder.update(&dim_cnt, sizeof(dim_cnt)); - } - auto&& shape_cnt = shapes.size(); - builder.update(&shape_cnt, sizeof(shape_cnt)); - return builder.digest(); - } - - bool is_same_st(const Hashable& rhs) const override { - auto&& pps = rhs.cast_final_safe(); - return offsets == pps.offsets && shapes == pps.shapes; - } -}; - -class ParamPackConcat : public OpDefImplBase { - MGB_DYN_TYPE_OBJ_FINAL_DECL; - -public: - ParamPackConcat() = default; - - ParamPackConcat(std::vector& offsets_) - : offsets(offsets_) {} - - std::vector offsets; - - size_t hash() const override { - XXHash builder; - for (auto&& offset : offsets) { - builder.update(&offset, sizeof(offset)); - } - auto&& offset_cnt = offsets.size(); - builder.update(&offset_cnt, sizeof(offset_cnt)); - return builder.digest(); - } - - bool is_same_st(const Hashable& rhs) const override { - auto&& ppc = rhs.cast_final_safe(); - return offsets == ppc.offsets; - } -}; - -} // namespace mgb::imperative diff --git a/imperative/src/test/backward_graph.cpp b/imperative/src/test/backward_graph.cpp index e07f8dd22..d673c7f62 100644 --- a/imperative/src/test/backward_graph.cpp +++ b/imperative/src/test/backward_graph.cpp @@ -29,18 +29,18 @@ TEST(TestImperative, BackwardGraphBasic) { using Param = opr::Elemwise::Param; Param param{Param::Mode::MUL}; - OprAttr attr{"Elemwise", {}, {}}; - attr.param.write_pod(param); + auto attr = OprAttr::make("Elemwise"); + attr->cast_final_safe().param.write_pod(param); SmallVector input_descs; for (auto&& i : inputs) { input_descs.push_back({i->layout(), i->comp_node()}); } - auto result = OpDef::make_backward_graph(attr, input_descs, {true, true}, {true}); + auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, {true}); auto&& save_for_backward = result.save_for_backward; auto&& input_has_grad = result.input_has_grad; - auto outputs = OpDef::apply_on_physical_tensor(attr, inputs); + auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs); inputs.push_back(outputs[0]); hvs.push_back(*gen({42})); inputs.push_back(Tensor::make(hvs.back())); @@ -82,16 +82,16 @@ TEST(TestImperative, BackwardGraphIdentity) { SmallVector inputs; inputs.push_back(a); - OprAttr attr{"Identity", {}, {}}; - attr.param.write_pod({}); + auto attr = OprAttr::make("Identity"); + attr->cast_final_safe().param.write_pod({}); SmallVector input_descs; input_descs.push_back({a->layout(), a->comp_node()}); - auto result = OpDef::make_backward_graph(attr, input_descs, {true}, {true}); + auto result = OpDef::make_backward_graph(*attr, input_descs, {true}, {true}); auto&& save_for_backward = result.save_for_backward; auto&& input_has_grad = result.input_has_grad; - auto outputs = OpDef::apply_on_physical_tensor(attr, inputs); + auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs); inputs.push_back(outputs[0]); inputs.push_back(dc); mgb_assert(save_for_backward.size() == inputs.size()); diff --git a/imperative/src/test/collective_comm.cpp b/imperative/src/test/collective_comm.cpp index 860450679..4c28fd659 100644 --- a/imperative/src/test/collective_comm.cpp +++ b/imperative/src/test/collective_comm.cpp @@ -10,7 +10,7 @@ */ #include "./helper.h" -#include "megbrain/imperative/ops/collective_comm.h" +#include "megbrain/imperative/ops/autogen.h" #include "megbrain/opr/mm_handler.h" using namespace mgb; @@ -32,12 +32,13 @@ TEST(TestImperative, AllReduceBasic) { } auto run = [&](std::shared_ptr hnd, uint32_t idx) { - imperative::CollectiveComm - def{"all_reduce", 2, idx, idx==0, false, server_addr, port, + auto def = + imperative::CollectiveComm::make( megdnn::param::CollectiveComm::Mode::ALL_REDUCE_SUM, - dtype::Float32(), "nccl", ""}; + "all_reduce", 2, idx, idx==0, false, server_addr, port, + dtype::Float32(), "nccl", ""); auto inp = Tensor::make(*hnd); - auto oup = OpDef::apply_on_physical_tensor(def, {inp}); + auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); HostTensorND host_v; host_v.copy_from(oup[0]->dev_tensor()).sync(); MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6); diff --git a/imperative/src/test/cond_take.cpp b/imperative/src/test/cond_take.cpp index 0a013c748..9eca1c4da 100644 --- a/imperative/src/test/cond_take.cpp +++ b/imperative/src/test/cond_take.cpp @@ -10,7 +10,7 @@ */ #include "./helper.h" -#include "megbrain/imperative/ops/cond_take.h" +#include "megbrain/imperative/ops/autogen.h" using namespace mgb; using namespace imperative; diff --git a/imperative/src/test/helper.cpp b/imperative/src/test/helper.cpp index 2c369ae61..2c2d19e30 100644 --- a/imperative/src/test/helper.cpp +++ b/imperative/src/test/helper.cpp @@ -119,7 +119,7 @@ void OprChecker::run(std::vector inp_keys) { }, inp_keys[i]); sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node(); } - auto sym_oup = OpDef::apply_on_var_node(*m_op, sym_inp)->usable_output(); + auto sym_oup = OpDef::apply_on_var_node(*m_op, sym_inp); size_t nr_oups = sym_oup.size(); ComputingGraph::OutputSpec oup_spec(nr_oups); SmallVector host_sym_oup(nr_oups); diff --git a/imperative/src/test/io_remote.cpp b/imperative/src/test/io_remote.cpp index fbfd63387..051a14587 100644 --- a/imperative/src/test/io_remote.cpp +++ b/imperative/src/test/io_remote.cpp @@ -10,7 +10,7 @@ */ #include "./helper.h" -#include "megbrain/imperative/ops/io_remote.h" +#include "megbrain/imperative/ops/autogen.h" #include "megbrain/opr/mm_handler.h" using namespace mgb; @@ -33,24 +33,19 @@ TEST(TestImperative, IORemote) { } auto run_send = [&](std::shared_ptr hnd) { - imperative::RemoteSend def{"io_remote_test", server_addr, port, 1}; + auto def = imperative::RemoteSend::make( + "io_remote_test", server_addr, port, 1); auto inp = Tensor::make(*hnd); - auto oup = OpDef::apply_on_physical_tensor(def, {inp}); + auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); }; auto run_recv = [&](std::shared_ptr hnd) { - // auto&& shape = std::initializer_list{vector_size}; - imperative::RemoteRecv def{"io_remote_test", - server_addr, - port, - 0, - { - vector_size, - }, - CompNode::load("gpu1"), - dtype::Float32()}; + auto def = imperative::RemoteRecv::make( + "io_remote_test", server_addr, port, 0, + CompNode::load("gpu1"), TensorShape{vector_size}, + dtype::Float32()); auto inp = Tensor::make(*hnd); - auto oup = OpDef::apply_on_physical_tensor(def, {inp}); + auto oup = OpDef::apply_on_physical_tensor(*def, {inp}); HostTensorND host_v; host_v.copy_from(oup[0]->dev_tensor()).sync(); MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6); diff --git a/imperative/tablegen/CMakeLists.txt b/imperative/tablegen/CMakeLists.txt new file mode 100644 index 000000000..5beb57952 --- /dev/null +++ b/imperative/tablegen/CMakeLists.txt @@ -0,0 +1,14 @@ +# mgb tablegen executable +set(TABLE_TARGET mgb-mlir-autogen) +add_executable(${TABLE_TARGET} autogen.cpp) +target_include_directories(${TABLE_TARGET} PRIVATE ${MLIR_LLVM_INCLUDE_DIR}) +target_link_libraries(${TABLE_TARGET} PRIVATE LLVMTableGen MLIRTableGen LLVMSupport) +set(MGB_TABLEGEN_EXE ${TABLE_TARGET}) + +# generate megbrain opdef c header and python bindings +set(LLVM_TARGET_DEFINITIONS ${MGE_IR_DIR}/ops.td) +tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header") +tablegen(MGB opdef.cpp.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-body") +tablegen(MGB opdef.py.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-binding") +add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl param_defs_tblgen) +set(MGB_OPDEF_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE) diff --git a/imperative/tablegen/autogen.cpp b/imperative/tablegen/autogen.cpp new file mode 100644 index 000000000..f2dc3dc53 --- /dev/null +++ b/imperative/tablegen/autogen.cpp @@ -0,0 +1,383 @@ +#include +#include +#include + +#include "./helper.h" + +using llvm::raw_ostream; +using llvm::RecordKeeper; + +enum ActionType { + None, + CppHeader, + CppBody, + Pybind +}; + +// NOLINTNEXTLINE +llvm::cl::opt action( + llvm::cl::desc("Action to perform:"), + llvm::cl::values(clEnumValN(CppHeader, "gen-cpp-header", + "Generate operator cpp header"), + clEnumValN(CppBody, "gen-cpp-body", + "Generate operator cpp body"), + clEnumValN(Pybind, "gen-python-binding", + "Generate pybind11 python bindings"))); + +using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; +using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; +using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin; +using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin; +using MgbOp = mlir::tblgen::MgbOpBase; +using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin; + +llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) { + // Note: we have already registered the corresponding attr wrappers + // for following basic ctypes so we needn't handle them here + /* auto&& attr_type_name = attr.getAttrDefName(); + if (attr_type_name == "UI32Attr") { + return "uint32_t"; + } + if (attr_type_name == "UI64Attr") { + return "uint64_t"; + } + if (attr_type_name == "I32Attr") { + return "int32_t"; + } + if (attr_type_name == "F32Attr") { + return "float"; + } + if (attr_type_name == "F64Attr") { + return "double"; + } + if (attr_type_name == "StrAttr") { + return "std::string"; + } + if (attr_type_name == "BoolAttr") { + return "bool"; + }*/ + + auto&& attr = llvm::cast(attr_); + if (auto e = llvm::dyn_cast(&attr)) { + return e->getEnumName(); + } + return attr.getUnderlyingType(); +} + +static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { + os << formatv( + "class {0} : public OpDefImplBase<{0}> {{\n" + " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n" + "public:\n", + op.getCppClassName() + ); + // handle enum alias + for (auto &&i : op.getMgbAttributes()) { + if (auto attr = llvm::dyn_cast(&i.attr)) { + os << formatv( + " using {0} = {1};\n", + attr->getEnumName(), attr->getUnderlyingType() + ); + } + } + for (auto &&i : op.getMgbAttributes()) { + auto defaultValue = i.attr.getDefaultValue().str(); + if (!defaultValue.empty()) { + defaultValue = formatv(" = {0}", defaultValue); + } + os << formatv( + " {0} {1}{2};\n", + attr_to_ctype(i.attr), i.name, defaultValue + ); + } + + auto gen_ctor = [&](auto&& paramList, auto&& memInitList, auto&& body) { + os << formatv( + " {0}({1}){2}{3}\n", + op.getCppClassName(), paramList, memInitList, body + ); + }; + + gen_ctor("", "", " = default;"); + + if (!op.getMgbAttributes().empty()) { + std::vector paramList, initList; + for (auto &&i : op.getMgbAttributes()) { + paramList.push_back(formatv( + "{0} {1}_", attr_to_ctype(i.attr), i.name + )); + initList.push_back(formatv( + "{0}({0}_)", i.name + )); + } + gen_ctor(llvm::join(paramList, ", "), + ": " + llvm::join(initList, ", "), + " {}"); + } + + auto packedParams = op.getPackedParams(); + if (!packedParams.empty()) { + std::vector paramList, initList; + for (auto &&p : packedParams) { + auto&& paramFields = p.getFields(); + auto&& paramType = p.getFullName(); + auto&& paramName = formatv("packed_param_{0}", paramList.size()); + paramList.push_back( + paramFields.empty() ? paramType.str() + : formatv("{0} {1}", paramType, paramName) + ); + for (auto&& i : paramFields) { + initList.push_back(formatv( + "{0}({1}.{0})", i.name, paramName + )); + } + } + for (auto&& i : op.getExtraArguments()) { + paramList.push_back(formatv( + "{0} {1}_", attr_to_ctype(i.attr), i.name + )); + initList.push_back(formatv( + "{0}({0}_)", i.name + )); + } + gen_ctor(llvm::join(paramList, ", "), + initList.empty() ? "" : ": " + llvm::join(initList, ", "), + " {}"); + } + + if (!packedParams.empty()) { + for (auto&& p : packedParams) { + auto accessor = p.getAccessor(); + if (!accessor.empty()) { + os << formatv( + " {0} {1}() const {{\n", + p.getFullName(), accessor + ); + std::vector fields; + for (auto&& i : p.getFields()) { + fields.push_back(i.name); + } + os << formatv( + " return {{{0}};\n", + llvm::join(fields, ", ") + ); + os << " }\n"; + } + } + } + + if (auto decl = op.getExtraOpdefDecl()) { + os << decl.getValue(); + } + + os << formatv( + "};\n\n" + ); +} + +static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { + auto&& className = op.getCppClassName(); + os << formatv( + "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className + ); + auto formatMethImpl = [&](auto&& meth) { + return formatv( + "{0}_{1}_impl", className, meth + ); + }; + std::vector methods; + if (auto hashable = llvm::dyn_cast(&op)) { + os << "namespace {\n"; + + // generate hash() + mlir::tblgen::FmtContext ctx; + os << formatv( + "size_t {0}(const OpDef& def_) {{\n", + formatMethImpl("hash") + ); + os << formatv( + " auto op_ = def_.cast_final_safe<{0}>();\n" + " static_cast(op_);\n", + className + ); + ctx.withSelf("op_"); + os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx); + os << "}\n"; + + // generate is_same_st() + os << formatv( + "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n", + formatMethImpl("is_same_st") + ); + os << formatv( + " auto a_ = lhs_.cast_final_safe<{0}>(),\n" + " b_ = rhs_.cast_final_safe<{0}>();\n" + " static_cast(a_);\n" + " static_cast(b_);\n", + className + ); + os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_"); + os << "}\n"; + + os << "} // anonymous namespace\n"; + + methods.push_back("hash"); + methods.push_back("is_same_st"); + } + if (!methods.empty()) { + os << formatv( + "OP_TRAIT_REG({0}, {0})", op.getCppClassName() + ); + for (auto&& i : methods) { + os << formatv( + "\n .{0}({1})", i, formatMethImpl(i) + ); + } + os << ";\n\n"; + } +} + +struct PybindContext { + std::unordered_map enumAlias; +}; + +static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext& ctx) { + auto class_name = op.getCppClassName(); + os << formatv( + "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n", + class_name + ); + for (auto&& i : op.getMgbAttributes()) { + if (auto attr = llvm::dyn_cast(&i.attr)) { + unsigned int enumID; + if (auto alias = llvm::dyn_cast(attr)) { + auto&& aliasBase = alias->getAliasBase(); + enumID = + llvm::cast(aliasBase) + .getBaseRecord()->getID(); + } else { + enumID = attr->getBaseRecord()->getID(); + } + auto&& enumAlias = ctx.enumAlias; + auto&& iter = enumAlias.find(enumID); + if (iter == enumAlias.end()) { + os << formatv( + "py::enum_<{0}::{1}>({0}Inst, \"{1}\")", + class_name, attr->getEnumName() + ); + std::vector body; + for (auto&& i: attr->getEnumMembers()) { + os << formatv( + "\n .value(\"{2}\", {0}::{1}::{2})", + class_name, attr->getEnumName(), i + ); + body.push_back(formatv( + "if (str == \"{2}\") return {0}::{1}::{2};", + class_name, attr->getEnumName(), i + )); + } + os << formatv( + "\n .def(py::init([](const std::string& in) {" + "\n auto&& str = normalize_enum(in);" + "\n {0}" + "\n throw py::cast_error(\"invalid enum value \" + in);" + "\n }));\n", + llvm::join(body, "\n ") + ); + os << formatv( + "py::implicitly_convertible();\n\n", + class_name, attr->getEnumName() + ); + enumAlias.emplace(enumID, formatv( + "{0}Inst.attr(\"{1}\")", class_name, attr->getEnumName() + )); + } else { + os << formatv( + "{0}Inst.attr(\"{1}\") = {2};\n\n", + class_name, attr->getEnumName(), iter->second + ); + } + } + } + // generate op class binding + os << formatv("{0}Inst", class_name); + bool hasDefaultCtor = op.getMgbAttributes().empty(); + if (!hasDefaultCtor) { + os << "\n .def(py::init<"; + std::vector targs; + for (auto &&i : op.getMgbAttributes()) { + targs.push_back(i.attr.getReturnType()); + } + os << llvm::join(targs, ", "); + os << ">()"; + for (auto &&i : op.getMgbAttributes()) { + os << formatv(", py::arg(\"{0}\")", i.name); + auto defaultValue = i.attr.getDefaultValue(); + if (!defaultValue.empty()) { + os << formatv(" = {0}", defaultValue); + } else { + hasDefaultCtor = true; + } + } + os << ")"; + } + if (hasDefaultCtor) { + os << "\n .def(py::init<>())"; + } + for (auto &&i : op.getMgbAttributes()) { + os << formatv( + "\n .def_readwrite(\"{0}\", &{1}::{0})", + i.name, class_name + ); + } + os << ";\n\n"; +} + +static void for_each_operator(raw_ostream &os, RecordKeeper &keeper, + std::function callback) { + auto op_base_class = keeper.getClass("Op"); + ASSERT(op_base_class, "could not find base class Op"); + for (auto&& i: keeper.getDefs()) { + auto&& r = i.second; + if (r->isSubClassOf(op_base_class)) { + auto op = mlir::tblgen::Operator(r.get()); + if (op.getDialectName().str() == "mgb") { + std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl; + callback(os, llvm::cast(op)); + } + } + } +} + +static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) { + for_each_operator(os, keeper, gen_op_def_c_header_single); + return false; +} + +static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) { + for_each_operator(os, keeper, gen_op_def_c_body_single); + return false; +} + +static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) { + PybindContext ctx; + using namespace std::placeholders; + for_each_operator(os, keeper, + std::bind(gen_op_def_pybind11_single, _1, _2, std::ref(ctx))); + return false; +} + +int main(int argc, char **argv) { + llvm::InitLLVM y(argc, argv); + llvm::cl::ParseCommandLineOptions(argc, argv); + if (action == ActionType::CppHeader) { + return TableGenMain(argv[0], &gen_op_def_c_header); + } + if (action == ActionType::CppBody) { + return TableGenMain(argv[0], &gen_op_def_c_body); + } + if (action == ActionType::Pybind) { + return TableGenMain(argv[0], &gen_op_def_pybind11); + } + return -1; +} \ No newline at end of file diff --git a/imperative/tablegen/helper.h b/imperative/tablegen/helper.h new file mode 100644 index 000000000..ea086f21e --- /dev/null +++ b/imperative/tablegen/helper.h @@ -0,0 +1,228 @@ +#include +#include + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/Signals.h" +#include "llvm/TableGen/Main.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" +#include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/Format.h" +#include "mlir/TableGen/Operator.h" + +using llvm::formatv; +using llvm::StringRef; +using llvm::Record; + +#define ASSERT(stmt, msg) \ + if (!(stmt)) { \ + std::cerr << "\033[1;31m" \ + << "tablegen autogen abort due to: " << msg \ + << "\033[0m" << std::endl; \ + exit(1); \ + } + +namespace mlir { +namespace tblgen { +template +struct MgbInterface : public ConcreteType { + MgbInterface() = delete; + MgbInterface(const MgbInterface&) = delete; + MgbInterface(MgbInterface&&) = delete; + ~MgbInterface() = delete; +}; + +struct MgbAttrWrapperBase : public MgbInterface { +private: + struct RecordVisitor : public MgbInterface { + public: + static bool classof(const Constraint*) { + return true; + } + + const llvm::Record* getDef() const { + return def; + } + }; +public: + static bool classof(const Attribute* attr) { + return attr->isSubClassOf("MgbAttrWrapperBase"); + } + + const llvm::Record* getBaseRecord() const { + auto baseAttr = getBaseAttr(); + return llvm::cast(baseAttr).getDef(); + } + llvm::StringRef getUnderlyingType() const { + return def->getValueAsString("underlyingType"); + } +}; + +struct MgbEnumAttrMixin : public MgbAttrWrapperBase { + static bool classof(const Attribute* attr) { + return attr->getBaseAttr().isSubClassOf("MgbEnumAttrMixin"); + } + + llvm::StringRef getParentNamespace() const { + return getBaseRecord()->getValueAsString("parentNamespce"); + } + llvm::StringRef getEnumName() const { + return getBaseRecord()->getValueAsString("enumName"); + } + std::vector getEnumMembers() const { + return getBaseRecord()->getValueAsListOfStrings("enumMembers"); + } +}; + +struct MgbHashableAttrMixin : public MgbAttrWrapperBase { + static bool classof(const Attribute* attr) { + return attr->getBaseAttr().isSubClassOf("MgbHashableAttrMixin"); + } + + llvm::StringRef getHashFunctionTemplate() const { + return getBaseRecord()->getValueAsString("hashFunction"); + } + llvm::StringRef getCmpFunctionTemplate() const { + return getBaseRecord()->getValueAsString("cmpFunction"); + } +}; + +struct MgbAliasAttrMixin : public MgbAttrWrapperBase { + static bool classof(const Attribute* attr) { + return attr->getBaseAttr().isSubClassOf("MgbAliasAttrMixin"); + } + + Attribute getAliasBase() const { + return Attribute(getBaseRecord()->getValueAsDef("aliasBase")); + } +}; + +class MgbPackedParam { +public: + MgbPackedParam(Record* def_): def(def_) { + auto&& dag = def->getValueAsDag("fields"); + for (size_t i = 0; i < dag->getNumArgs(); ++ i) { + fields.push_back({ + dag->getArgNameStr(i), + Attribute(llvm::cast(dag->getArg(i))) + }); + } + } + + llvm::StringRef getFullName() const { + return def->getValueAsString("fullName"); + } + std::vector getFields() const { + return fields; + } + llvm::StringRef getAccessor() const { + return def->getValueAsString("paramAccessor"); + } +private: + std::vector fields; + Record* def; +}; + +struct MgbOpBase : public MgbInterface { + static bool isPackedParam(Record* def) { + return def->isSubClassOf("MgbPackedParamBase"); + } + +public: + static bool classof(const Operator* op) { + return op->getDef().isSubClassOf("MgbOp"); + } + + std::vector getMgbAttributes() const { + std::vector ret; + for (auto&& i: getAttributes()) { + if (isa(i.attr)) { + ret.push_back(i); + } + } + return ret; + } + std::vector getExtraArguments() const { + std::vector ret; + auto&& dag = getDef().getValueAsDag("extraArguments"); + for (size_t i = 0; i < dag->getNumArgs(); ++ i) { + ret.push_back({ + dag->getArgNameStr(i), + Attribute(llvm::cast(dag->getArg(i))) + }); + } + return ret; + } + llvm::Optional getExtraOpdefDecl() const { + return getDef().getValueAsOptionalString("extraOpdefDecl"); + } + std::vector getPackedParams() const { + std::vector ret; + for (auto&& i : getDef().getValueAsListOfDefs("dnnParams")) { + if (isPackedParam(i)) { + ret.emplace_back(i); + } + } + return ret; + } +}; + +struct MgbHashableOpMixin : public MgbOpBase { +private: + std::string getDefaultHashFunction() const { + std::string body = " size_t val = mgb::hash($_self.dyn_typeinfo());\n"; + if (!getMgbAttributes().empty()) { + auto getHashFunc = [&](auto&& iter) { + auto&& attr = llvm::cast(iter.attr); + return attr.getHashFunctionTemplate(); + }; + mlir::tblgen::FmtContext ctx; + for (auto&& it: getMgbAttributes()) { + body += formatv( + " val = mgb::hash_pair_combine(val, {0});\n", + mlir::tblgen::tgfmt(getHashFunc(it), &ctx, "$_self." + it.name) + ); + } + } + body += " return val;\n"; + return body; + } + std::string getDefaultCmpFunction() const { + std::string body; + if (!getMgbAttributes().empty()) { + mlir::tblgen::FmtContext ctx; + for (auto&& it : getMgbAttributes()) { + auto&& attr = llvm::cast(it.attr); + body += formatv( + " if ({0}) return false;\n", + mlir::tblgen::tgfmt(attr.getCmpFunctionTemplate(), + &ctx, "$0." + it.name, "$1." + it.name) + ); + } + } + body += " return true;\n"; + return body; + } +public: + static bool classof(const Operator* op) { + return op->getDef().isSubClassOf("MgbHashableOpMixin"); + } + + std::string getHashFunctionTemplate() const { + if (auto f = getDef().getValueAsOptionalString("hashFunction")) { + return f.getValue().str(); + } + return getDefaultHashFunction(); + } + std::string getCmpFunctionTemplate() const { + if (auto f = getDef().getValueAsOptionalString("cmpFunction")) { + return f.getValue().str(); + } + return getDefaultCmpFunction(); + } +}; + +} // namespace tblgen +} // namespace mlir \ No newline at end of file diff --git a/imperative/test/CMakeLists.txt b/imperative/test/CMakeLists.txt index 280da981a..03a92575f 100644 --- a/imperative/test/CMakeLists.txt +++ b/imperative/test/CMakeLists.txt @@ -11,7 +11,7 @@ endif() # TODO: turn python binding into a static/object library add_executable(imperative_test ${SOURCES} ${SRCS}) -target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include) +target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR}) # Python binding target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) diff --git a/src/core/include/megbrain/ir/base.td b/src/core/include/megbrain/ir/base.td new file mode 100644 index 000000000..82e638d58 --- /dev/null +++ b/src/core/include/megbrain/ir/base.td @@ -0,0 +1,257 @@ +/** + * \file src/core/include/megbrain/ir/base.td + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#ifndef MGB_BASE +#define MGB_BASE + +include "mlir/IR/OpBase.td" + +def Mgb_Dialect : Dialect { + let name = "mgb"; + let cppNamespace = "mgb::dialect"; +} + +// -- mgb Attr mixin +class MgbAttrWrapperBase { + string underlyingType = className; + int recursionDepth = 0; +} + +class MgbHashableAttrMixin { + string hashFunction = "mgb::hash($0)"; + // return 0 for eq, else for ne + string cmpFunction = "$0 != $1"; +} + +class MgbEnumAttrMixin members> { + string parentNamespace = namespace; + string enumName = name; + list enumMembers = members; +} + +class MgbAttrWrapper; +class MgbAliasAttrMixin { + Attr aliasBase = base; +} + +// -- mgb custom Attr +// TODO: CPred and description +class MgbAttrWrapper: + Attr, "TODO">, MgbAttrWrapperBase { + let returnType = underlyingType; +} + +class HashableAttr: + MgbAttrWrapper, MgbHashableAttrMixin; + +// -- basic types +class MgbIntegerAttrBase : HashableAttr { + let storageType = "::mlir::IntegerAttr"; +} + +class MgbSignlessIntegerAttrBase : MgbIntegerAttrBase { + let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getInt())"; + let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 4), $0)"; +} + +class MgbSignedIntegerAttrBase : MgbIntegerAttrBase { + let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getSInt())"; + let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 4, true), $0)"; +} + +class MgbUnsignedIntegerAttrBase : MgbIntegerAttrBase { + let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getUInt())"; + let constBuilderCall = "$_builder.getIntegerAttr($_builder.getIntegerType(sizeof(" # underlyingType # ") * 4, false), $0)"; +} + +def MgbI8Attr: MgbSignlessIntegerAttrBase<"int8_t">; +def MgbI32Attr: MgbSignlessIntegerAttrBase<"int32_t">; +def MgbI64Attr: MgbSignlessIntegerAttrBase<"int64_t">; +def MgbUI32Attr: MgbUnsignedIntegerAttrBase<"uint32_t">; +def MgbUI64Attr: MgbUnsignedIntegerAttrBase<"uint64_t">; +def MgbSizeTAddr: MgbUnsignedIntegerAttrBase<"size_t">; + +class MgbFloatAttrBase : HashableAttr { + let storageType = "::mlir::FloatAttr"; + let convertFromStorage = "static_cast<" # underlyingType # ">($_self.getValueAsDouble())"; + let constBuilderCall = "$_builder.getFloatAttr($_builder.get" # DType # "Type(), $0)"; +} + +def MgbF32Attr : MgbFloatAttrBase<"float", "F32">; +def MgbF64Attr : MgbFloatAttrBase<"double", "F64">; + +def MgbBoolAttr : HashableAttr<"bool"> { + let storageType = "::mlir::BoolAttr"; + let constBuilderCall = "$_builder.getBoolAttr($0)"; +} + +def MgbStringAttr : HashableAttr<"std::string"> { + let storageType = "::mlir::StringAttr"; + let convertFromStorage = "$_self.getValue().str()"; + let constBuilderCall = "$_builder.getStringAttr($0)"; // llvm::StringRef implicit ctor +} + +class MgbArrayAttr: + HashableAttr<"std::vector<" # elem.underlyingType # ">"> { + let storageType = "::mlir::ArrayAttr"; + let recursionDepth = !add(elem.recursionDepth, 1); + let convertFromStorage = + "[&] {\n" + " " # underlyingType # " ret" # recursionDepth # ";\n" + " std::for_each($_self.begin(), $_self.end(), [&](auto&& i" # recursionDepth # ") {\n" + " ret" # recursionDepth # ".push_back(\n" + " " # !subst("$_self", "i" # recursionDepth # ".template cast<" # elem.storageType # ">()", "" # elem.convertFromStorage) # "\n" + " );\n" + " });\n" + " return ret" # recursionDepth # ";}()"; + let constBuilderCall = + "[&] {\n" + " std::vector ret" # recursionDepth # ";\n" + " std::for_each($0.begin(), $0.end(), [&](auto&& i" # recursionDepth # ") {\n" + " ret" # recursionDepth # ".push_back(\n" + " " # !subst("$0", "i" # recursionDepth, "" # elem.constBuilderCall) # "\n" + " );\n" + " });\n" + " return $_builder.getArrayAttr(ret" # recursionDepth # ");" + "}()"; +} + +defvar EmptyStrList = !listsplat("", 0); +class StrListAppend l, string s> { + list r = !listconcat(l, !listsplat(s, 1)); +} + +class TupleConvertFromStorage { + string r = !subst( + "$_self", + "$_self[" # !cast(idx) # "].template cast<"# attr.storageType #">()", + "" # attr.convertFromStorage); +} + +class TupleConstBuilderCall { + string r = !subst( + "$0", + "std::get<" # !cast(idx) # ">($0)", + "" # attr.constBuilderCall); +} + +class ApplyTupleConvertFromStorage args> { + list r = !foldl( + EmptyStrList, args, l, arg, StrListAppend.r>.r); +} + +class ApplyTupleConstBuilderCall args> { + list r = !foldl( + EmptyStrList, args, l, arg, StrListAppend.r>.r); +} + +class MgbTupleAttr args>: + HashableAttr<"std::tuple<" # StrJoin.result # ">"> { + let storageType = "::mlir::ArrayAttr"; + let convertFromStorage = "std::make_tuple(" # StrJoin.r>.result # ")"; + let constBuilderCall = "$_builder.getArrayAttr({" # StrJoin.r>.result # "})"; +} + +// -- enum types +class MgbEnumAttr members>: + HashableAttr, MgbEnumAttrMixin { + let storageType = "::mlir::IntegerAttr"; + let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())"; + let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast($0))"; + let hashFunction = "mgb::enumhash()($0)"; +} + +class MgbEnumAliasAttr: + MgbEnumAttr, MgbAliasAttrMixin; + +// -- other types +def MgbDTypeAttr: HashableAttr<"::megdnn::DType"> { + let storageType = "::mlir::IntegerAttr"; + let convertFromStorage = underlyingType # "::from_enum(static_cast<::megdnn::DTypeEnum>($_self.getInt()))"; + let constBuilderCall = "$_builder.getI32IntegerAttr(static_cast($0.enumv()))"; + let hashFunction = "mgb::hash($0.handle())"; +} + +def MgbCompNodeAttr: HashableAttr<"::mgb::CompNode"> { + let storageType = "::mlir::StringAttr"; + let convertFromStorage = underlyingType # "::load($_self.getValue().str())"; + let constBuilderCall = "$_builder.getStringAttr($0.to_string_logical())"; +} + +def MgbTensorShapeAttr: HashableAttr<"::megdnn::TensorShape"> { + let storageType = "::mlir::ArrayAttr"; + let hashFunction = "mgb::PODHash::perform($0.shape, $0.ndim)"; + let cmpFunction = "!$0.eq_shape($1)"; + defvar elemInst = MgbSizeTAddr; + let convertFromStorage = + "[&] {\n" + " " # underlyingType # " ret;\n" + " std::for_each($_self.begin(), $_self.end(), [&ret](auto&& i) {\n" + " ret[ret.ndim ++] = " # !subst("$_self", "i.template cast<"# elemInst.storageType #">()", "" # elemInst.convertFromStorage) # ";\n" + " });\n" + " return ret;}()"; + let constBuilderCall = + "[&] {\n" + " std::vector ret;\n" + " for (size_t i = 0; i < $0.ndim; ++ i) {\n" + " ret.push_back(\n" + " " # !subst("$0", "$0[i]", "" # elemInst.constBuilderCall) # "\n" + " );\n" + " }\n" + " return $_builder.getArrayAttr(ret);" + "}()"; +} + +class MgbDefaultValuedAttr: + DefaultValuedAttr, MgbAttrWrapperBase { + // Note: this class is similar to DefaultValuedAttr but with extra + // meta informations which are used by mgb dialect tblgen, so this + // has to be kept up to date with class MgbAttrWrapperMixin + let recursionDepth = attr.recursionDepth; +} + +// -- dnn params +class MgbParamBase { + string paramType = className; + string fullName = "::megdnn::param::" # paramType; + dag fields = ?; +} + +class MgbPackedParamBase: + MgbParamBase { + string paramAccessor = accessor; +} + +// -- mgb ops +class MgbHashableOpMixin { + string hashFunction = ?; + string cmpFunction = ?; +} + +class MgbOp params=[], list traits=[]>: + Op { + dag inputs = (ins); + dag extraArguments = (ins); + // TODO: remove it + code extraOpdefDecl = ?; + + let arguments = !con( + !foldl(inputs, params, args, param, !con(args, param.fields)), + extraArguments); + + list dnnParams = params; +} + +class MgbHashableOp params=[], list traits=[]>: + MgbOp, MgbHashableOpMixin; + +#endif // MGB_BASE diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td new file mode 100644 index 000000000..c96d3cda8 --- /dev/null +++ b/src/core/include/megbrain/ir/ops.td @@ -0,0 +1,240 @@ +/** + * \file src/core/include/megbrain/ir/ops.td + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#ifndef MGB_OPS +#define MGB_OPS + +include "base.td" +include "param_defs.td" + +include "mlir/Interfaces/SideEffectInterfaces.td" + +def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { + let inputs = (ins Variadic:$input); + let results = (outs AnyType); +} + +def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>; + +def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> { + let inputs = (ins AnyType:$inputs); + let extraArguments = (ins + TypeAttr:$idtype, + MgbDTypeAttr:$dtype + ); + let results = (outs AnyType); +} + +def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam]>; + +def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam]>; + +def Dot: MgbHashableOp<"Dot", [EmptyParam]>; + +def SVD: MgbHashableOp<"SVD", [SVDParam]>; + +def Convolution : MgbHashableOp<"Convolution", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>; + +def ConvolutionBackwardData: MgbHashableOp<"ConvolutionBackwardData", [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>; + +def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>; + +def Pooling: MgbHashableOp<"Pooling", [PoolingParam]>; + +def AdaptivePooling : MgbHashableOp<"AdaptivePooling", [AdaptivePoolingParam]>; + +def ROIPooling: MgbHashableOp<"ROIPooling", [ROIPoolingParam]>; + +def ConvBias : MgbHashableOp<"ConvBias", [ConvBiasParam, ExecutionPolicyParamBase<"policy">]> { + let extraArguments = (ins + MgbDTypeAttr:$dtype + ); +} + +def BatchConvBias : MgbHashableOp<"BatchConvBias", [BatchConvBiasParam, ExecutionPolicyParamBase<"policy">]> { + let extraArguments = (ins + MgbDTypeAttr:$dtype + ); +} + +def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>; + +def ROIAlign: MgbHashableOp<"ROIAlign", [ROIAlignParam]>; + +def WarpPerspective: MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>; + +def Remap: MgbHashableOp<"Remap", [RemapParam]>; + +def IndexingOneHot: MgbHashableOp<"IndexingOneHot", [AxisParam]>; + +def IndexingSetOneHot: MgbHashableOp<"IndexingSetOneHot", [AxisParam]>; + +def Copy: MgbHashableOp<"Copy"> { + let extraArguments = (ins + MgbCompNodeAttr:$comp_node + ); +} + +def Argsort: MgbHashableOp<"Argsort", [ArgsortParam]>; + +def Argmax : MgbHashableOp<"Argmax", [AxisParam]>; + +def Argmin : MgbHashableOp<"Argmin", [AxisParam]>; + +def CondTake : MgbHashableOp<"CondTake">; + +def TopK: MgbHashableOp<"TopK", [TopKParam]>; + +def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>; + +def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> { + let hashFunction = [{return mgb::hash($_self.dyn_typeinfo());}]; + let cmpFunction = [{return true;}]; +} + +def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> { + let hashFunction = [{ + return mgb::hash_pair_combine( + mgb::hash($_self.dyn_typeinfo()), + mgb::hash_pair_combine(mgb::hash($_self.mean), mgb::hash($_self.std))); + }]; + let cmpFunction = [{return $0.mean == $1.mean && $0.std == $1.std;}]; +} + +def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> { + let extraArguments = (ins + MgbCompNodeAttr:$comp_node + ); +} + +def Eye: MgbHashableOp<"Eye", [EyeParam]> { + let extraArguments = (ins + MgbCompNodeAttr:$comp_node + ); +} + +def GetVarShape : MgbHashableOp<"GetVarShape">; + +def Concat: MgbHashableOp<"Concat", [AxisParam]> { + let extraArguments = (ins + MgbCompNodeAttr:$comp_node + ); +} + +def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]>; + +def Identity: MgbHashableOp<"Identity">; + +def CollectiveComm : MgbHashableOp<"CollectiveComm", [CollectiveCommParam]> { + let extraArguments = (ins + MgbStringAttr:$key, + MgbUI32Attr:$nr_devices, + MgbUI32Attr:$rank, + MgbBoolAttr:$is_root, + MgbBoolAttr:$local_grad, + MgbStringAttr:$addr, + MgbUI32Attr:$port, + MgbDTypeAttr:$dtype, + MgbStringAttr:$backend, + MgbStringAttr:$comp_node + ); +} + +def RemoteSend : MgbHashableOp<"RemoteSend"> { + let extraArguments = (ins + MgbStringAttr:$key, + MgbStringAttr:$addr, + MgbUI32Attr:$port, + MgbUI32Attr:$rank_to + ); +} + +def RemoteRecv : MgbHashableOp<"RemoteRecv"> { + let extraArguments = (ins + MgbStringAttr:$key, + MgbStringAttr:$addr, + MgbUI32Attr:$port, + MgbUI32Attr:$rank_from, + MgbCompNodeAttr:$cn, + MgbTensorShapeAttr:$shape, + MgbDTypeAttr:$dtype + ); +} + +def NMSKeep : MgbHashableOp<"NMSKeep"> { + let extraArguments = (ins + MgbF32Attr:$iou_thresh, + MgbUI32Attr:$max_output + ); +} + +def ParamPackSplit : MgbHashableOp<"ParamPackSplit"> { + let extraArguments = (ins + MgbArrayAttr:$offsets, + MgbArrayAttr>:$shapes + ); +} + +def ParamPackConcat : MgbHashableOp<"ParamPackConcat"> { + let extraArguments = (ins + MgbArrayAttr:$offsets + ); +} + +def Dimshuffle: MgbHashableOp<"Dimshuffle"> { + let inputs = (ins AnyMemRef:$input); + let extraArguments = (ins MgbArrayAttr:$pattern); + let results = (outs AnyMemRef); +} + +def Reshape: MgbHashableOp<"Reshape", [OptionalAxisV1Param]>; + +// TODO: merge Add/Remove Axis into AxisAddRemove as megbrain? +def AddAxis: MgbHashableOp<"AddAxis"> { + let extraArguments = (ins + MgbArrayAttr:$axis + ); +} +def RemoveAxis: MgbHashableOp<"RemoveAxis"> { + let extraArguments = (ins + MgbArrayAttr:$axis + ); +} + +class FancyIndexingBase: MgbHashableOp { + let extraArguments = (ins + MgbArrayAttr>:$items + ); +} + +def Subtensor: FancyIndexingBase<"Subtensor">; +def SetSubtensor: FancyIndexingBase<"SetSubtensor">; +def IncrSubtensor: FancyIndexingBase<"IncrSubtensor">; +def IndexingMultiAxisVec: FancyIndexingBase<"IndexingMultiAxisVec">; +def IndexingSetMultiAxisVec: FancyIndexingBase<"IndexingSetMultiAxisVec">; +def IndexingIncrMultiAxisVec: FancyIndexingBase<"IndexingIncrMultiAxisVec">; +def MeshIndexing: FancyIndexingBase<"MeshIndexing">; +def IncrMeshIndexing: FancyIndexingBase<"IncrMeshIndexing">; +def SetMeshIndexing: FancyIndexingBase<"SetMeshIndexing">; +def BatchedMeshIndexing: FancyIndexingBase<"BatchedMeshIndexing">; +def BatchedIncrMeshIndexing: FancyIndexingBase<"BatchedIncrMeshIndexing">; +def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">; + +def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>; +def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> { + let extraArguments = (ins + MgbDTypeAttr:$dtype + ); +} + +#endif // MGB_OPS diff --git a/third_party/prepare.sh b/third_party/prepare.sh index c38e282ca..25a02dab7 100755 --- a/third_party/prepare.sh +++ b/third_party/prepare.sh @@ -47,3 +47,4 @@ pushd MegRay/third_party >/dev/null popd >/dev/null git submodule update --init pybind11 +git submodule update --init llvm-project -- GitLab