提交 69e3e322 编写于 作者: M Megvii Engine Team

feat(imperative): auto generated opdef header and python binding

GitOrigin-RevId: d2f22ad5fe0b15f45afa1ea31af8874e8b18fef9
上级 0398a786
...@@ -230,6 +230,10 @@ endif() ...@@ -230,6 +230,10 @@ endif()
# FIXME At present, there are some conflicts between the LLVM that halide # FIXME At present, there are some conflicts between the LLVM that halide
# depends on and the LLVM that MLIR depends on. Should be fixed in subsequent # depends on and the LLVM that MLIR depends on. Should be fixed in subsequent
# versions. # versions.
if(MGE_BUILD_IMPERATIVE_RT)
set(MGE_WITH_HALIDE OFF)
message(WARNING "cannot use HALIDE when building IMPERATIVE_RT")
endif()
if(MGE_WITH_JIT_MLIR) if(MGE_WITH_JIT_MLIR)
if(MGE_WITH_HALIDE) if(MGE_WITH_HALIDE)
message(FATAL_ERROR "please set MGE_WITH_HALIDE to OFF with MGE_WITH_JIT_MLIR enabled") message(FATAL_ERROR "please set MGE_WITH_HALIDE to OFF with MGE_WITH_JIT_MLIR enabled")
...@@ -310,7 +314,7 @@ if(MGE_INFERENCE_ONLY) ...@@ -310,7 +314,7 @@ if(MGE_INFERENCE_ONLY)
set(MGE_BUILD_IMPERATIVE_RT OFF) set(MGE_BUILD_IMPERATIVE_RT OFF)
endif() endif()
if(MGE_WITH_JIT_MLIR) if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT)
include(cmake/llvm-project.cmake) include(cmake/llvm-project.cmake)
endif() endif()
...@@ -750,7 +754,7 @@ target_include_directories(mgb_opr_param_defs ...@@ -750,7 +754,7 @@ target_include_directories(mgb_opr_param_defs
add_dependencies(mgb_opr_param_defs _mgb_opr_param_defs) add_dependencies(mgb_opr_param_defs _mgb_opr_param_defs)
install(TARGETS mgb_opr_param_defs EXPORT ${MGE_EXPORT_TARGETS}) install(TARGETS mgb_opr_param_defs EXPORT ${MGE_EXPORT_TARGETS})
if(MGE_WITH_JIT_MLIR) if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT)
# generate param_defs.td # generate param_defs.td
set(MGE_GENFILE_DIR ${PROJECT_BINARY_DIR}/src/genfiles) set(MGE_GENFILE_DIR ${PROJECT_BINARY_DIR}/src/genfiles)
set(MGE_GEN_IR_DIR ${PROJECT_BINARY_DIR}/src/core/include/megbrain/ir) set(MGE_GEN_IR_DIR ${PROJECT_BINARY_DIR}/src/core/include/megbrain/ir)
...@@ -800,12 +804,6 @@ if(TARGET _imperative_rt) ...@@ -800,12 +804,6 @@ if(TARGET _imperative_rt)
COMMAND ${CMAKE_COMMAND} -E create_symlink COMMAND ${CMAKE_COMMAND} -E create_symlink
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/$<TARGET_FILE_NAME:${MODULE_NAME}> ${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/$<TARGET_FILE_NAME:${MODULE_NAME}>
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/$<TARGET_FILE_NAME:${MODULE_NAME}> ${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/$<TARGET_FILE_NAME:${MODULE_NAME}>
COMMAND ${CMAKE_COMMAND} -E create_symlink
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/generated_ops.py
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/generated_ops.py
COMMAND ${CMAKE_COMMAND} -E create_symlink
${CMAKE_CURRENT_BINARY_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/param_defs.py
${CMAKE_CURRENT_SOURCE_DIR}/imperative/python/${PACKAGE_NAME}/core/ops/_internal/param_defs.py
DEPENDS _imperative_rt DEPENDS _imperative_rt
VERBATIM VERBATIM
) )
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import collections
import textwrap
import os
import hashlib
import struct
import io
from gen_param_defs import member_defs, ParamDef, IndentWriterBase
class ConverterWriter(IndentWriterBase):
    """Writer that turns operator param definitions into TableGen records.

    The ``_on_*`` methods are callbacks; they are presumably dispatched by
    ``IndentWriterBase._process`` for each param and each of its members
    (TODO confirm in ``gen_param_defs``). Member callbacks collect one
    ``"<attr>:$<name>"`` entry per member in ``_current_tparams`` and the
    record itself is written by :meth:`_on_param_end`.
    """

    # True while the members of a legacy param are being visited; legacy
    # params do not produce a param record.
    _skip_current_param = False
    # The param most recently passed to _on_param_begin.
    _last_param = None
    # "<attr>:$<name>" entries collected for the param being processed.
    _current_tparams = None
    # Whether the current param is written as a packed param base class;
    # switched off as soon as a DTypeEnum field is encountered.
    _packed = None
    # Names of the const members seen so far in the current param.
    _const = None

    # C type name -> TableGen attribute class, for the types whose default
    # value is passed through unchanged.
    _PLAIN_ATTRS = {
        'uint32_t': 'MgbUI32Attr',
        'uint64_t': 'MgbUI64Attr',
        'int32_t': 'MgbI32Attr',
        'float': 'MgbF32Attr',
        'double': 'MgbF64Attr',
        'bool': 'MgbBoolAttr',
    }

    def __call__(self, fout, defs):
        """Write all of ``defs`` to ``fout`` inside an ``MGB_PARAM`` include guard."""
        super().__call__(fout)
        self._write("// %s", self._get_header())
        self._write("#ifndef MGB_PARAM")
        self._write("#define MGB_PARAM")
        self._process(defs)
        self._write("#endif // MGB_PARAM")

    def _ctype2attr(self, ctype, value):
        """Map a C type name to ``(attribute class, default value expression)``.

        A ``DTypeEnum`` field additionally marks the current param as not
        packed. Raises ``RuntimeError`` for any other unrecognised type.
        """
        attr = self._PLAIN_ATTRS.get(ctype)
        if attr is not None:
            return attr, value
        if ctype == 'DTypeEnum':
            self._packed = False
            return 'MgbDTypeAttr', 'megdnn::DType::from_enum(megdnn::{})'.format(value)
        raise RuntimeError("unknown ctype")

    def _on_param_begin(self, p):
        """Reset the per-param state; legacy params are only flagged as skipped."""
        self._last_param = p
        if p.is_legacy:
            self._skip_current_param = True
            return
        self._packed = True
        self._current_tparams = []
        self._const = set()

    def _on_param_end(self, p):
        """Write the record for ``p`` from the collected member entries."""
        if self._skip_current_param:
            self._skip_current_param = False
            return
        if self._packed:
            self._write("class {0}ParamBase<string accessor> : MgbPackedParamBase<\"{0}\", accessor> {{".format(p.name), indent=1)
        else:
            self._write("def {0}Param: MgbParamBase<\"{0}\"> {{".format(p.name), indent=1)
        self._write("let fields = (ins", indent=1)
        # One entry per line, each continuation line starting with the
        # writer's current indent string.
        self._write(",\n{}".format(self._cur_indent).join(self._current_tparams))
        self._write(");", indent=-1)
        self._write("}\n", indent=-1)
        if self._packed:
            self._write("def {0}Param : {0}ParamBase<\"param\">;\n".format(p.name))
        self._current_tparams = None
        self._packed = None
        self._const = None

    def _wrapped_with_default_value(self, attr, default):
        """Return ``attr`` wrapped in ``MgbDefaultValuedAttr`` with ``default``."""
        return 'MgbDefaultValuedAttr<{}, \"{}\">'.format(attr, default)

    def _on_member_enum(self, e):
        """Write the enum attribute def and record the enum as a param field."""
        p = self._last_param

        # Note: always generate llvm Record def for enum attribute even it was not
        # directly used by any operator, or other enum couldn't alias to this enum
        td_class = "{}{}".format(p.name, e.name)
        fullname = "::megdnn::param::{}".format(p.name)
        quoted_members = ','.join('\"{}\"'.format(str(m)) for m in e.members)
        enum_def = "MgbEnumAttr<\"{}\", \"{}\", [".format(fullname, e.name)
        enum_def += quoted_members
        enum_def += "]>"
        self._write("def {} : {};".format(td_class, enum_def))

        if self._skip_current_param:
            return

        # wrapped with default value
        default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.default)
        wrapped = self._wrapped_with_default_value(td_class, default_val)
        self._current_tparams.append("{}:${}".format(wrapped, e.name_field))

    def _on_member_enum_alias(self, e):
        """Write an alias attribute def and record the alias as a param field."""
        p = self._last_param
        if self._skip_current_param:
            return

        # write enum attr def
        td_class = "{}{}".format(p.name, e.name)
        fullname = "::megdnn::param::{}".format(p.name)
        base_td_class = "{}{}".format(e.src_class, e.src_name)
        enum_def = "MgbEnumAliasAttr<\"{}\", \"{}\", {}>".format(fullname, e.name, base_td_class)
        self._write("def {} : {};".format(td_class, enum_def))

        # wrapped with default value
        default_val = "static_cast<{}::{}>({})".format(fullname, e.name, e.get_default())
        wrapped = self._wrapped_with_default_value(td_class, default_val)
        self._current_tparams.append("{}:${}".format(wrapped, e.name_field))

    def _on_member_field(self, f):
        """Record a plain field as a param field with its default value."""
        if self._skip_current_param:
            return
        attr, value = self._ctype2attr(f.dtype.cname, str(f.default))
        # A default that names a const member of this param is qualified
        # with the param's namespace.
        if str(value) in self._const:
            value = '::megdnn::param::{}::{}'.format(self._last_param.name, value)
        wrapped = self._wrapped_with_default_value(attr, value)
        self._current_tparams.append("{}:${}".format(wrapped, f.name))

    def _on_const_field(self, f):
        """Remember the name of a const member of the current param."""
        # While a legacy param is visited ``_const`` is None (it is only
        # created for non-legacy params), so skip like the other callbacks
        # do instead of failing on ``None.add``.
        if self._skip_current_param:
            return
        self._const.add(str(f.name))
def main():
    """Command-line entry point: convert a param definition file to TableGen.

    Takes two positional arguments, ``input`` (the definition script) and
    ``output`` (the file to write).
    """
    arg_parser = argparse.ArgumentParser('generate op param tablegen file')
    arg_parser.add_argument('input')
    arg_parser.add_argument('output')
    opts = arg_parser.parse_args()

    with open(opts.input) as src:
        source_text = src.read()

    # NOTE(review): the input file is executed as Python code; only run this
    # tool on trusted definition files.
    exec(source_text, {'pdef': ParamDef, 'Doc': member_defs.Doc})

    # SHA-256 of the input text, handed to the writer via set_input_hash.
    digest = hashlib.sha256(source_text.encode(encoding='UTF-8')).hexdigest()

    converter = ConverterWriter()
    with open(opts.output, 'w') as fout:
        converter.set_input_hash(digest)(fout, ParamDef.all_param_defs)


if __name__ == "__main__":
    main()
...@@ -8,9 +8,7 @@ file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/sr ...@@ -8,9 +8,7 @@ file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/sr
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1")
file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl")
file(GLOB_RECURSE PYTHON_SRCS python/${PACKAGE_NAME}/*.py) file(GLOB_RECURSE PYTHON_SRCS python/${PACKAGE_NAME}/*.py)
list(REMOVE_ITEM PYTHON_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/generated_ops.py ${CMAKE_CURRENT_SOURCE_DIR}/python/megengine/core/ops/_internal/param_defs.py)
file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h
${PROJECT_SOURCE_DIR}/src/core/include/* ${PROJECT_SOURCE_DIR}/src/core/include/*
${PROJECT_SOURCE_DIR}/src/opr/include/* ${PROJECT_SOURCE_DIR}/src/opr/include/*
...@@ -19,33 +17,8 @@ file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h ...@@ -19,33 +17,8 @@ file(GLOB_RECURSE ALL_HEADERS src/cpp/megbrain_pubapi.h
${PROJECT_SOURCE_DIR}/dnn/include/*) ${PROJECT_SOURCE_DIR}/dnn/include/*)
set(MEGENGINE_DIR ${CMAKE_CURRENT_BINARY_DIR}/python/) set(MEGENGINE_DIR ${CMAKE_CURRENT_BINARY_DIR}/python/)
set(GEN_OPS_DIR ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal)
file(MAKE_DIRECTORY ${GEN_OPS_DIR})
set(GEN_OPS_FILE ${GEN_OPS_DIR}/generated_ops.py)
set(GEN_OP_PARAMS_FILE ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/ops/_internal/param_defs.py)
set(GEN_OP_PARAMS_TEMPLATE ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/ops.tpl.py)
##################### generate python opr_param_defs.py ############## add_subdirectory(tablegen)
file(COPY ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS)
file(APPEND ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${CONTENTS})
add_custom_command(
OUTPUT ${GEN_OPS_FILE}
COMMAND ${CMAKE_COMMAND} -E touch ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE}
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/${PACKAGE_NAME} ${MEGENGINE_DIR}/${PACKAGE_NAME}
COMMAND ${CMAKE_COMMAND} -E remove -f ${MEGENGINE_DIR}/${PACKAGE_NAME}/core/${MODULE_NAME}.so ${GEN_OPS_FILE} ${GEN_OP_PARAMS_FILE}
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/python/tools/gen_ops.py ${OPR_DECL_SRCS} -o ${GEN_OPS_FILE}
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/python/test ${MEGENGINE_DIR}/${PACKAGE_NAME}/test
COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/dnn/scripts/gen_param_defs.py -t py --imperative ${CMAKE_CURRENT_BINARY_DIR}/opr_param_defs.py ${GEN_OP_PARAMS_FILE}
DEPENDS ${OPR_DECL_SRCS} ${PYTHON_SRCS} ${ALL_HEADERS} ${GEN_OP_PARAMS_TEMPLATE}
VERBATIM
)
add_custom_target(gen_opr_py DEPENDS ${GEN_OPS_FILE})
##################### end of opdef generation #########################
add_custom_target(_version_ld SOURCES ${MGE_VERSION_SCRIPT}) add_custom_target(_version_ld SOURCES ${MGE_VERSION_SCRIPT})
...@@ -73,7 +46,7 @@ else() ...@@ -73,7 +46,7 @@ else()
endif() endif()
endif() endif()
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR})
target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME})
target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter)
if(CXX_SUPPORT_WCLASS_MEMACCESS) if(CXX_SUPPORT_WCLASS_MEMACCESS)
...@@ -87,7 +60,7 @@ if (APPLE OR MSVC OR WIN32) ...@@ -87,7 +60,7 @@ if (APPLE OR MSVC OR WIN32)
message(VERBOSE "overwriting SUFFIX at macos and windows before config by set_target_properties") message(VERBOSE "overwriting SUFFIX at macos and windows before config by set_target_properties")
pybind11_extension(${MODULE_NAME}) pybind11_extension(${MODULE_NAME})
endif() endif()
add_dependencies(${MODULE_NAME} gen_opr_py _version_ld) add_dependencies(${MODULE_NAME} mgb_opdef _version_ld)
if(MGE_WITH_TEST AND MGE_ENABLE_RTTI) if(MGE_WITH_TEST AND MGE_ENABLE_RTTI)
add_subdirectory(test) add_subdirectory(test)
......
...@@ -19,7 +19,6 @@ from ..ops.builtin import ( ...@@ -19,7 +19,6 @@ from ..ops.builtin import (
IndexingMultiAxisVec, IndexingMultiAxisVec,
IndexingSetMultiAxisVec, IndexingSetMultiAxisVec,
OpDef, OpDef,
OprAttr,
Reduce, Reduce,
Reshape, Reshape,
SetSubtensor, SetSubtensor,
...@@ -31,8 +30,6 @@ from ..tensor.function import Function ...@@ -31,8 +30,6 @@ from ..tensor.function import Function
from ..tensor.tensor import Tensor from ..tensor.tensor import Tensor
from ..tensor.tensor_wrapper import TensorWrapper from ..tensor.tensor_wrapper import TensorWrapper
_reduce_sum_param = Reduce(mode="SUM").to_c().param[0]
@functools.singledispatch @functools.singledispatch
def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad): def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad):
...@@ -41,17 +38,18 @@ def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad): ...@@ -41,17 +38,18 @@ def builtin_op_get_backward_fn(op: OpDef, inputs, outputs, input_requires_grad):
@builtin_op_get_backward_fn.register(OpDef) @builtin_op_get_backward_fn.register(OpDef)
def _(op: OpDef, inputs, outputs, input_requires_grad): def _(op: OpDef, inputs, outputs, input_requires_grad):
if isinstance(op, OprAttr): if isinstance(op, Reshape):
grad_fn = _oprAttr_grad_fn.get(op.type, None) grad_fn = reshape_grad_fn
if grad_fn is None: elif isinstance(op, Subtensor):
if op.type == Reduce.name and op.param[0] == _reduce_sum_param: grad_fn = subtensor_grad_fn
grad_fn = reduce_sum_grad_fn elif isinstance(op, IndexingMultiAxisVec):
else: grad_fn = indexingMultiAxisVec_grad_fn
grad_fn = default_grad_fn
elif isinstance(op, Broadcast) or ( elif isinstance(op, Broadcast) or (
isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD isinstance(op, Elemwise) and op.mode == Elemwise.Mode.ADD
): ):
grad_fn = elemwise_add_grad_fn grad_fn = elemwise_add_grad_fn
elif isinstance(op, Reduce) and op.mode.name == "SUM":
grad_fn = reduce_sum_grad_fn
else: else:
grad_fn = default_grad_fn grad_fn = default_grad_fn
return grad_fn(op, inputs, outputs, input_requires_grad) return grad_fn(op, inputs, outputs, input_requires_grad)
...@@ -152,9 +150,7 @@ def reshape_grad_fn(op, inputs, outputs, input_requires_grad): ...@@ -152,9 +150,7 @@ def reshape_grad_fn(op, inputs, outputs, input_requires_grad):
# override for Subtensor # override for Subtensor
def subtensor_grad_fn(op, inputs, outputs, input_requires_grad): def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
grad_op = OprAttr() grad_op = SetSubtensor(op.items)
grad_op.type = SetSubtensor.name
grad_op.param = op.param
input_shape = get_shape(inputs[0]) input_shape = get_shape(inputs[0])
params = inputs[1:] params = inputs[1:]
...@@ -175,9 +171,7 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad): ...@@ -175,9 +171,7 @@ def subtensor_grad_fn(op, inputs, outputs, input_requires_grad):
# override for IndexingMultiAxisVec # override for IndexingMultiAxisVec
def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad): def indexingMultiAxisVec_grad_fn(op, inputs, outputs, input_requires_grad):
grad_op = OprAttr() grad_op = IndexingSetMultiAxisVec(op.items)
grad_op.type = IndexingSetMultiAxisVec.name
grad_op.param = op.param
input_shape = get_shape(inputs[0]) input_shape = get_shape(inputs[0])
params = inputs[1:] params = inputs[1:]
...@@ -209,10 +203,3 @@ def reduce_sum_grad_fn(op, inputs, outputs, input_requires_grad): ...@@ -209,10 +203,3 @@ def reduce_sum_grad_fn(op, inputs, outputs, input_requires_grad):
return (broadcast_to(dy, input_shape) if input_requires_grad[0] else None,) return (broadcast_to(dy, input_shape) if input_requires_grad[0] else None,)
return backward, [True] return backward, [True]
_oprAttr_grad_fn = {
Reshape.name: reshape_grad_fn,
Subtensor.name: subtensor_grad_fn,
IndexingMultiAxisVec.name: indexingMultiAxisVec_grad_fn,
}
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .generated_ops import *
from .misc_ops import *
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import warnings
from ..._imperative_rt.ops import OprAttr
from . import param_defs
def make_param(param, ptype, kwargs):
    """Return an instance of ``ptype`` built from ``param`` or from ``kwargs``.

    * ``param`` already a ``ptype`` instance: returned unchanged.
    * any other non-``None`` ``param``: passed as the single positional
      argument of ``ptype``; this requires ``ptype`` to declare exactly one
      slot.
    * ``param`` is ``None``: every key of ``kwargs`` that matches a name in
      ``ptype.__slots__`` is removed from ``kwargs`` and forwarded to
      ``ptype`` as a keyword argument.
    """
    if param is None:
        absent = object()  # marks "key not present in kwargs"
        selected = {}
        for slot in ptype.__slots__:
            value = kwargs.pop(slot, absent)
            if value is absent:
                continue
            selected[slot] = value
        return ptype(**selected)

    if isinstance(param, ptype):
        return param

    positional = [param]
    expected = len(ptype.__slots__)
    assert len(positional) == expected, "{} needs {} params, but {} are provided".format(
        ptype, expected, len(positional)
    )
    return ptype(*positional)
class PodOpVisitor:
    """Base class for Python-side op descriptions converted to ``OprAttr``.

    Subclasses set ``name`` (stored as the ``type`` of the produced
    ``OprAttr``) and ``param_names`` (names of the instance attributes that
    are serialized into its ``param`` bytes).
    """

    # name -> most recently defined subclass declaring that name
    __name2subclass = {}
    # cached result of to_c()
    __c = None

    name = None
    param_names = []
    config = None

    def __init__(self, config, **params):
        """Store ``config`` and the params; keys must equal ``param_names``."""
        self.config = config
        assert set(params) == set(self.param_names)
        self.__dict__.update(params)

    def __init_subclass__(cls, **kwargs):
        """Register ``cls`` under its ``name``.

        Warns when the name is already taken by a class that ``cls`` does
        not derive from; the registry entry is overwritten either way.
        """
        super().__init_subclass__(**kwargs)  # python 3.5 does not have this
        name = cls.name
        if name in cls.__name2subclass:
            if not issubclass(cls, cls.__name2subclass[name]):
                warnings.warn("Multiple subclasses for bultin op: %s" % name)
        cls.__name2subclass[name] = cls

    def to_c(self):
        """Return the ``OprAttr`` for this op, building and caching it on first use."""
        if self.__c:
            return self.__c
        op = OprAttr()
        op.type = self.name
        if self.config is not None:
            op.config = self.config
        # first 4 bytes is TAG, has to remove them currently
        op.param = b"".join(self.__dict__[k].serialize()[4:] for k in self.param_names)
        self.__c = op
        return op

    def __eq__(self, rhs):
        """Compare by the converted ``OprAttr`` objects.

        Objects that cannot be converted with ``to_c`` are not comparable;
        ``NotImplemented`` is returned for them so that ``==`` evaluates to
        ``False`` instead of raising ``AttributeError``.
        """
        rhs_to_c = getattr(rhs, "to_c", None)
        if rhs_to_c is None:
            return NotImplemented
        return self.to_c() == rhs_to_c()

    def __repr__(self):
        """Return ``ClassName(key=value, ...)``.

        Once the op has been converted by :meth:`to_c` only
        ``ClassName(<binary data>)`` is shown.
        """
        name = self.__class__.__name__

        if self.__c:
            return "{}(<binary data>)".format(name)

        kwargs = {}
        for i in self.param_names:
            p = self.__dict__[i]
            if isinstance(p, param_defs._ParamDefBase):
                # Param structs are flattened: one entry per slot, with enum
                # values shown by their name.
                for k in p.__slots__:
                    v = getattr(p, k)
                    if isinstance(v, param_defs._EnumBase):
                        v = v.name
                    kwargs[k] = repr(v)
            else:
                kwargs[i] = repr(p)
        if self.config:
            if len(self.config.comp_node_arr) == 1:
                kwargs["device"] = "'%s'" % self.config.comp_node
        return "{}({})".format(
            name, ", ".join("{}={}".format(k, v) for k, v in kwargs.items())
        )
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import ctypes
from ..._imperative_rt import OperatorNodeConfig as Config
from . import param_defs
from .helper import PodOpVisitor, make_param
# Names exported by this module. `_gen_indexing_defs` (defined further down)
# appends the name of every indexing op class it generates to this list.
__all__ = ["ConvolutionBackwardData", "Dimshuffle", "Reshape", "AxisAddRemove"]
class TensorShape:
    """Holder for the dimension limit used to size the ctypes arrays below."""

    # Upper bound on the number of dimensions; presumably mirrors the
    # native TensorShape limit (TODO confirm).
    MAX_NDIM = 7
class ConvolutionBackwardData(PodOpVisitor):
    """Op description registered under the name ``ConvolutionBackwardDataV1``.

    Keyword arguments that match slots of ``param_defs.Convolution`` or
    ``param_defs.ExecutionPolicy`` are consumed to build the two param
    structs when those are not passed explicitly.
    """

    param_names = (
        "param",
        "execution_polity",
    )

    name = "ConvolutionBackwardDataV1"

    def __init__(
        self,
        *,
        param=None,
        execution_polity=None,
        name=None,
        comp_node=None,
        config=None,
        dtype=None,
        **kwargs
    ):
        # Start from the given config (or a fresh one) and apply only the
        # overrides that were actually provided.
        cfg = config if config else Config()
        overrides = (("name", name), ("comp_node", comp_node), ("dtype", dtype))
        for attr, value in overrides:
            if value:
                setattr(cfg, attr, value)
        self.config = cfg

        self.param = make_param(param, param_defs.Convolution, kwargs)
        self.execution_polity = make_param(
            execution_polity, param_defs.ExecutionPolicy, kwargs
        )
        # make_param pops every key it uses, so anything left is unknown.
        assert not kwargs, "extra kwargs: {}".format(kwargs)
class Dimshuffle(PodOpVisitor):
    """Op description registered under the name ``Dimshuffle``."""

    name = "Dimshuffle"
    param_names = ("pattern",)

    class Pattern(ctypes.Structure):
        """ctypes layout of the pattern: length, fixed-size axis array, ndim."""

        Pattern_Array = ctypes.c_int32 * TensorShape.MAX_NDIM
        _fields_ = [
            ("length", ctypes.c_uint32),
            ("pattern", Pattern_Array),
            ("ndim", ctypes.c_uint32),
        ]

        def serialize(self):
            """Return the struct bytes preceded by a zeroed 32-bit word."""
            header = bytes(ctypes.c_uint32(0))
            return header + bytes(self)

    def __init__(self, pattern, ndim=0):
        """Build the pattern struct; an ``"x"`` entry is stored as ``-1``."""
        assert isinstance(pattern, collections.abc.Iterable)
        assert len(pattern) <= TensorShape.MAX_NDIM
        axes = Dimshuffle.Pattern.Pattern_Array()
        for pos, entry in enumerate(pattern):
            if entry == "x":
                code = -1
            else:
                code = int(entry)
            axes[pos] = ctypes.c_int32(code)
        self.pattern = Dimshuffle.Pattern(len(pattern), axes, ndim)
class Reshape(PodOpVisitor):
    """Op description registered under the name ``ReshapeV1``."""

    name = "ReshapeV1"
    param_names = ("unspec_axis",)

    def __init__(self, unspec_axis=None):
        """Store ``unspec_axis`` wrapped in ``param_defs.OptionalAxisV1``.

        When ``unspec_axis`` is ``None`` the param is constructed without
        arguments.
        """
        ctor_args = () if unspec_axis is None else (unspec_axis,)
        self.unspec_axis = param_defs.OptionalAxisV1(*ctor_args)
class AxisNum(ctypes.Structure):
    """ctypes struct holding a single C ``int`` field, ``m_num``."""

    _fields_ = [("m_num", ctypes.c_int)]
class AxisDesc(ctypes.Structure):
    """ctypes struct pairing a method code with the axis it applies to."""

    class Method(ctypes.c_int):
        """Method codes stored in the ``method`` field."""

        ADD_1 = 0
        REMOVE = 1

    _fields_ = [
        ("method", Method),
        ("axis", AxisNum),
    ]

    @classmethod
    def make_add(cls, axis):
        """Return a descriptor with method ``ADD_1`` for ``axis``."""
        target = AxisNum(axis)
        return cls(cls.Method.ADD_1, target)

    @classmethod
    def make_remove(cls, axis):
        """Return a descriptor with method ``REMOVE`` for ``axis``."""
        target = AxisNum(axis)
        return cls(cls.Method.REMOVE, target)
class AxisAddRemove(PodOpVisitor):
    """Op description registered under the name ``AxisAddRemove``."""

    name = "AxisAddRemove"
    param_names = ("param",)

    # Keep the descriptor type reachable as AxisAddRemove.AxisDesc; the
    # module-level name is deleted right after this class.
    AxisDesc = AxisDesc

    class Param(ctypes.Structure):
        """ctypes struct: a descriptor count followed by a fixed-size array."""

        MAX_DESC_SIZE = TensorShape.MAX_NDIM * 2
        _fields_ = [
            ("nr_desc", ctypes.c_uint32),
            ("desc", AxisDesc * MAX_DESC_SIZE),
        ]

        def __init__(self, *args):
            """Store the given descriptors in order and record their count."""
            super().__init__()
            self.nr_desc = len(args)
            for slot in range(len(args)):
                self.desc[slot] = args[slot]

        def serialize(self):
            """Return the struct bytes preceded by a zeroed 32-bit word."""
            header = bytes(ctypes.c_uint32(0))
            return header + bytes(self)

    def __init__(self, param):
        """Store ``param``, which must be an ``AxisAddRemove.Param``."""
        assert isinstance(param, self.Param)
        self.param = param


del AxisDesc
class IndexingOpBase(PodOpVisitor):
    """Common base of the indexing op descriptions generated below."""

    param_names = ("index_desc",)

    class IndexDescMaskDump(ctypes.Structure):
        """ctypes struct: an item count followed by a fixed-size item array."""

        class Item(ctypes.Structure):
            """One entry: an 8-bit axis and four boolean flags."""

            _fields_ = [
                ("axis", ctypes.c_int8),
                ("begin", ctypes.c_bool),
                ("end", ctypes.c_bool),
                ("step", ctypes.c_bool),
                ("idx", ctypes.c_bool),
            ]

        Item_Array = Item * TensorShape.MAX_NDIM

        _fields_ = [
            ("nr_item", ctypes.c_uint8),
            ("items", Item_Array),
        ]

        def serialize(self):
            """Return the struct bytes preceded by a zeroed 32-bit word."""
            header = bytes(ctypes.c_uint32(0))
            return header + bytes(self)

    def __init__(self, items):
        """Build ``index_desc`` from a sequence of 5-element tuples or lists.

        Each element supplies the ``Item`` fields in declaration order:
        axis, begin, end, step, idx.
        """
        dump_type = IndexingOpBase.IndexDescMaskDump
        count = len(items)
        assert count <= TensorShape.MAX_NDIM
        packed = dump_type.Item_Array()
        for pos, fields in enumerate(items):
            assert isinstance(fields, (tuple, list)) and len(fields) == 5
            packed[pos] = dump_type.Item(*fields)
        self.index_desc = dump_type(count, packed)
def _gen_indexing_defs(*names):
    """Create and export one ``IndexingOpBase`` subclass per given name.

    Each class is named after the string, carries it as its ``name``
    attribute, is bound as a module global and is appended to ``__all__``.
    """
    module_ns = globals()
    for op_name in names:
        op_cls = type(op_name, (IndexingOpBase,), {"name": op_name})
        module_ns[op_name] = op_cls
        __all__.append(op_name)
# Generate the indexing op classes. Each string below becomes both the name
# of a module-level IndexingOpBase subclass and that class's `name` attribute.
_gen_indexing_defs(
"Subtensor",
"SetSubtensor",
"IncrSubtensor",
"IndexingMultiAxisVec",
"IndexingSetMultiAxisVec",
"IndexingIncrMultiAxisVec",
"MeshIndexing",
"IncrMeshIndexing",
"SetMeshIndexing",
"BatchedMeshIndexing",
"BatchedIncrMeshIndexing",
"BatchedSetMeshIndexing",
)
...@@ -11,25 +11,12 @@ from typing import Union ...@@ -11,25 +11,12 @@ from typing import Union
from ..._imperative_rt import OpDef, ops from ..._imperative_rt import OpDef, ops
from ...tensor.core import OpBase, TensorBase, TensorWrapperBase, apply from ...tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
from .._internal import all_ops
from .._internal.helper import PodOpVisitor
# register OpDef as a "virtual subclass" of OpBase, so any of registered # register OpDef as a "virtual subclass" of OpBase, so any of registered
# apply(OpBase, ...) rules could work well on OpDef # apply(OpBase, ...) rules could work well on OpDef
OpBase.register(OpDef) OpBase.register(OpDef)
# forward to apply(OpDef, ...) __all__ = ["OpDef"]
@apply.register()
def _(op: PodOpVisitor, *args: Union[TensorBase, TensorWrapperBase]):
return apply(op.to_c(), *args)
__all__ = ["OpDef", "PodOpVisitor"]
for k, v in all_ops.__dict__.items():
if isinstance(v, type) and issubclass(v, PodOpVisitor):
globals()[k] = v
__all__.append(k)
for k, v in ops.__dict__.items(): for k, v in ops.__dict__.items():
if isinstance(v, type) and issubclass(v, OpDef): if isinstance(v, type) and issubclass(v, OpDef):
......
...@@ -90,7 +90,7 @@ def _reshape(x, shape): ...@@ -90,7 +90,7 @@ def _reshape(x, shape):
if unspec_axis is None: if unspec_axis is None:
op = builtin.Reshape() op = builtin.Reshape()
else: else:
op = builtin.Reshape(unspec_axis=unspec_axis) op = builtin.Reshape(axis=unspec_axis)
(x,) = apply(op, x, shape) (x,) = apply(op, x, shape)
return x return x
...@@ -144,8 +144,6 @@ def _logical_binary_elwise(mode, rev=False): ...@@ -144,8 +144,6 @@ def _logical_binary_elwise(mode, rev=False):
def _remove_axis(inp: Tensor, axis) -> Tensor: def _remove_axis(inp: Tensor, axis) -> Tensor:
Param = builtin.AxisAddRemove.Param
def get_axes(): def get_axes():
if axis is None: if axis is None:
return [i for i, s in enumerate(inp.shape) if s == 1] return [i for i, s in enumerate(inp.shape) if s == 1]
...@@ -159,8 +157,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor: ...@@ -159,8 +157,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
axis = sorted(i + inp.ndim if i < 0 else i for i in axis) axis = sorted(i + inp.ndim if i < 0 else i for i in axis)
axis = [a - i for i, a in enumerate(axis)] axis = [a - i for i, a in enumerate(axis)]
param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis)) op = builtin.RemoveAxis(axis=axis)
op = builtin.AxisAddRemove(param=param)
(result,) = apply(op, inp) (result,) = apply(op, inp)
if len(axis) == inp.ndim: if len(axis) == inp.ndim:
setscalar(result) setscalar(result)
......
...@@ -134,7 +134,7 @@ def astype(x, dtype): ...@@ -134,7 +134,7 @@ def astype(x, dtype):
dtype = np.dtype(dtype) dtype = np.dtype(dtype)
if not is_equal(x.dtype, dtype): if not is_equal(x.dtype, dtype):
isscalar = x.__wrapped__._data._isscalar isscalar = x.__wrapped__._data._isscalar
(x,) = apply(builtin.TypeCvt(param=dtype), x) (x,) = apply(builtin.TypeCvt(dtype=dtype), x)
x.__wrapped__._data._isscalar = isscalar x.__wrapped__._data._isscalar = isscalar
return x return x
......
...@@ -8,7 +8,6 @@ ...@@ -8,7 +8,6 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple from typing import Optional, Tuple
from ..core._imperative_rt.ops import CollectiveCommMode
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
from ..core.autodiff.grad import ( from ..core.autodiff.grad import (
Tracer, Tracer,
...@@ -110,17 +109,20 @@ def collective_comm(inp, mode, group, device): ...@@ -110,17 +109,20 @@ def collective_comm(inp, mode, group, device):
assert isinstance(group, Group) assert isinstance(group, Group)
if group is None: if group is None:
return inp return inp
op = CollectiveComm() addr, port = get_mm_server_addr()
op.key = group.key op = CollectiveComm(
op.nr_devices = group.size key=group.key,
op.rank = group.rank nr_devices=group.size,
op.is_root = op.rank == 0 rank=group.rank,
op.local_grad = False is_root=(group.rank == 0),
op.addr, op.port = get_mm_server_addr() local_grad=False,
op.mode = mode addr=addr,
op.dtype = inp.dtype port=port,
op.backend = get_backend() mode=mode,
op.comp_node = device dtype=inp.dtype,
backend=get_backend(),
comp_node=device,
)
return apply(op, inp)[0] return apply(op, inp)[0]
...@@ -134,7 +136,7 @@ def reduce_sum( ...@@ -134,7 +136,7 @@ def reduce_sum(
:param group: communication group. :param group: communication group.
:param device: execution device. :param device: execution device.
""" """
mode = CollectiveCommMode.REDUCE_SUM mode = CollectiveComm.Mode.REDUCE_SUM
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
...@@ -148,7 +150,7 @@ def broadcast( ...@@ -148,7 +150,7 @@ def broadcast(
:param group: communication group. :param group: communication group.
:param device: execution device. :param device: execution device.
""" """
mode = CollectiveCommMode.BROADCAST mode = CollectiveComm.Mode.BROADCAST
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
...@@ -162,7 +164,7 @@ def all_gather( ...@@ -162,7 +164,7 @@ def all_gather(
:param group: communication group. :param group: communication group.
:param device: execution device. :param device: execution device.
""" """
mode = CollectiveCommMode.ALL_GATHER mode = CollectiveComm.Mode.ALL_GATHER
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
...@@ -176,7 +178,7 @@ def reduce_scatter_sum( ...@@ -176,7 +178,7 @@ def reduce_scatter_sum(
:param group: communication group. :param group: communication group.
:param device: execution device. :param device: execution device.
""" """
mode = CollectiveCommMode.REDUCE_SCATTER_SUM mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
...@@ -190,7 +192,7 @@ def all_reduce_sum( ...@@ -190,7 +192,7 @@ def all_reduce_sum(
:param group: communication group. :param group: communication group.
:param device: execution device. :param device: execution device.
""" """
mode = CollectiveCommMode.ALL_REDUCE_SUM mode = CollectiveComm.Mode.ALL_REDUCE_SUM
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
...@@ -204,7 +206,7 @@ def all_reduce_max( ...@@ -204,7 +206,7 @@ def all_reduce_max(
:param group: communication group. :param group: communication group.
:param device: execution device. :param device: execution device.
""" """
mode = CollectiveCommMode.ALL_REDUCE_MAX mode = CollectiveComm.Mode.ALL_REDUCE_MAX
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
...@@ -218,7 +220,7 @@ def all_reduce_min( ...@@ -218,7 +220,7 @@ def all_reduce_min(
:param group: communication group. :param group: communication group.
:param device: execution device. :param device: execution device.
""" """
mode = CollectiveCommMode.ALL_REDUCE_MIN mode = CollectiveComm.Mode.ALL_REDUCE_MIN
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
...@@ -232,7 +234,7 @@ def gather( ...@@ -232,7 +234,7 @@ def gather(
:param group: communication group. :param group: communication group.
:param device: execution device. :param device: execution device.
""" """
mode = CollectiveCommMode.GATHER mode = CollectiveComm.Mode.GATHER
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
...@@ -246,7 +248,7 @@ def scatter( ...@@ -246,7 +248,7 @@ def scatter(
:param group: communication group. :param group: communication group.
:param device: execution device. :param device: execution device.
""" """
mode = CollectiveCommMode.SCATTER mode = CollectiveComm.Mode.SCATTER
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
...@@ -260,7 +262,7 @@ def all_to_all( ...@@ -260,7 +262,7 @@ def all_to_all(
:param group: communication group. :param group: communication group.
:param device: execution device. :param device: execution device.
""" """
mode = CollectiveCommMode.ALL_TO_ALL mode = CollectiveComm.Mode.ALL_TO_ALL
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
......
...@@ -73,27 +73,7 @@ __all__ = [ ...@@ -73,27 +73,7 @@ __all__ = [
] ]
# Subclass of Elemwise.Mode that adds conversion helpers so a mode can be
# given either as a string (matched against member names regardless of letter
# case) or as any value the class constructor accepts.
# NOTE(review): indentation of this block was lost in extraction; the nesting
# described in the comments below is inferred from the statement order.
class _ElemwiseMode(Elemwise.Mode):
# Translate a string into the member whose upper-cased name equals the
# upper-cased string. Non-string values, and strings that match no member
# name, are returned unchanged.
@classmethod
def __normalize(cls, val):
if isinstance(val, str):
# Build the upper-cased-name -> member lookup table on first use and cache
# it on the class.
if not hasattr(cls, "__member_upper_dict__"):
cls.__member_upper_dict__ = {
k.upper(): v for k, v in cls.__members__.items()
}
# Fall back to the original string when no member name matches.
val = cls.__member_upper_dict__.get(val.upper(), val)
return val
# Normalize `val` via __normalize, then return it as-is when it is already an
# instance of cls; otherwise return cls(val).
@classmethod
def convert(cls, val):
val = cls.__normalize(val)
if isinstance(val, cls):
return val
return cls(val)
def _elwise(*args, mode): def _elwise(*args, mode):
mode = _ElemwiseMode.convert(mode)
op = builtin.Elemwise(mode) op = builtin.Elemwise(mode)
tensor_args = list( tensor_args = list(
filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args)
......
...@@ -13,7 +13,6 @@ import numbers ...@@ -13,7 +13,6 @@ import numbers
from typing import Optional, Sequence, Tuple, Union from typing import Optional, Sequence, Tuple, Union
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import utils from ..core.tensor import utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
...@@ -601,9 +600,9 @@ def argsort(inp: Tensor, descending: bool = False) -> Tensor: ...@@ -601,9 +600,9 @@ def argsort(inp: Tensor, descending: bool = False) -> Tensor:
""" """
assert len(inp.shape) <= 2, "Input should be 1d or 2d" assert len(inp.shape) <= 2, "Input should be 1d or 2d"
if descending: if descending:
order = P.Argsort.Order.DESCENDING order = "DESCENDING"
else: else:
order = P.Argsort.Order.ASCENDING order = "ASCENDING"
op = builtin.Argsort(order=order) op = builtin.Argsort(order=order)
if len(inp.shape) == 1: if len(inp.shape) == 1:
...@@ -643,9 +642,9 @@ def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]: ...@@ -643,9 +642,9 @@ def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
""" """
assert len(inp.shape) <= 2, "Input should be 1d or 2d" assert len(inp.shape) <= 2, "Input should be 1d or 2d"
if descending: if descending:
order = P.Argsort.Order.DESCENDING order = "DESCENDING"
else: else:
order = P.Argsort.Order.ASCENDING order = "ASCENDING"
op = builtin.Argsort(order=order) op = builtin.Argsort(order=order)
if len(inp.shape) == 1: if len(inp.shape) == 1:
...@@ -695,13 +694,12 @@ def topk( ...@@ -695,13 +694,12 @@ def topk(
if descending: if descending:
inp = -inp inp = -inp
Mode = P.TopK.Mode
if kth_only: if kth_only:
mode = Mode.KTH_ONLY mode = "KTH_ONLY"
elif no_sort: elif no_sort:
mode = Mode.VALUE_IDX_NOSORT mode = "VALUE_IDX_NOSORT"
else: else:
mode = Mode.VALUE_IDX_SORTED mode = "VALUE_IDX_SORTED"
op = builtin.TopK(mode=mode) op = builtin.TopK(mode=mode)
if not isinstance(k, (TensorBase, TensorWrapperBase)): if not isinstance(k, (TensorBase, TensorWrapperBase)):
......
...@@ -12,7 +12,6 @@ from typing import Optional, Sequence, Tuple, Union ...@@ -12,7 +12,6 @@ from typing import Optional, Sequence, Tuple, Union
from ..core._imperative_rt import CompNode from ..core._imperative_rt import CompNode
from ..core._trace_option import use_symbolic_shape from ..core._trace_option import use_symbolic_shape
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.ops.builtin import BatchNorm from ..core.ops.builtin import BatchNorm
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor import megbrain_graph, utils from ..core.tensor import megbrain_graph, utils
...@@ -121,11 +120,11 @@ def conv2d( ...@@ -121,11 +120,11 @@ def conv2d(
``in_channels`` and ``out_channels`` must be divisible by ``groups``, ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be `(groups, out_channel // groups, and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`. in_channels // groups, height, width)`.
:type conv_mode: string or :class:`P.Convolution.Mode` :type conv_mode: string or :class:`Convolution.Mode`
:param conv_mode: supports "CROSS_CORRELATION". Default: :param conv_mode: supports "CROSS_CORRELATION". Default:
"CROSS_CORRELATION" "CROSS_CORRELATION"
:type compute_mode: string or :type compute_mode: string or
:class:`P.Convolution.ComputeMode` :class:`Convolution.ComputeMode`
:param compute_mode: when set to "DEFAULT", no special requirements will be :param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32", placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only "Float32" would be used for accumulator and intermediate result, but only
...@@ -139,8 +138,8 @@ def conv2d( ...@@ -139,8 +138,8 @@ def conv2d(
pad_h, pad_w = expand_hw(padding) pad_h, pad_w = expand_hw(padding)
dilate_h, dilate_w = expand_hw(dilation) dilate_h, dilate_w = expand_hw(dilation)
Sparse = P.Convolution.Sparse Sparse = builtin.Convolution.Sparse
sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP sparse_type = "DENSE" if groups == 1 else "GROUP"
op = builtin.Convolution( op = builtin.Convolution(
stride_h=stride_h, stride_h=stride_h,
stride_w=stride_w, stride_w=stride_w,
...@@ -187,11 +186,11 @@ def conv_transpose2d( ...@@ -187,11 +186,11 @@ def conv_transpose2d(
``in_channels`` and ``out_channels`` must be divisible by groups, ``in_channels`` and ``out_channels`` must be divisible by groups,
and the shape of weight should be `(groups, out_channel // groups, and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`. Default: 1 in_channels // groups, height, width)`. Default: 1
:type conv_mode: string or :class:`P.Convolution.Mode` :type conv_mode: string or :class:`Convolution.Mode`
:param conv_mode: supports "CROSS_CORRELATION". Default: :param conv_mode: supports "CROSS_CORRELATION". Default:
"CROSS_CORRELATION" "CROSS_CORRELATION"
:type compute_mode: string or :type compute_mode: string or
:class:`P.Convolution.ComputeMode` :class:`Convolution.ComputeMode`
:param compute_mode: when set to "DEFAULT", no special requirements will be :param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32", placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only "Float32" would be used for accumulator and intermediate result, but only
...@@ -240,8 +239,6 @@ def local_conv2d( ...@@ -240,8 +239,6 @@ def local_conv2d(
pad_h, pad_w = expand_hw(padding) pad_h, pad_w = expand_hw(padding)
dilate_h, dilate_w = expand_hw(dilation) dilate_h, dilate_w = expand_hw(dilation)
Sparse = P.Convolution.Sparse
op = builtin.GroupLocal( op = builtin.GroupLocal(
stride_h=stride_h, stride_h=stride_h,
stride_w=stride_w, stride_w=stride_w,
...@@ -251,7 +248,7 @@ def local_conv2d( ...@@ -251,7 +248,7 @@ def local_conv2d(
dilate_w=dilate_w, dilate_w=dilate_w,
mode=conv_mode, mode=conv_mode,
compute_mode="DEFAULT", compute_mode="DEFAULT",
sparse=Sparse.DENSE, sparse="DENSE",
) )
inp, weight = utils.convert_inputs(inp, weight) inp, weight = utils.convert_inputs(inp, weight)
(output,) = apply(op, inp, weight) (output,) = apply(op, inp, weight)
...@@ -696,19 +693,14 @@ def batch_norm( ...@@ -696,19 +693,14 @@ def batch_norm(
if not training: if not training:
op = builtin.BatchNorm( op = builtin.BatchNorm(
BatchNorm.ParamDim.DIM_1C11, BatchNorm.FwdMode.INFERENCE, eps, 1.0, 1.0, 0.0 fwd_mode=BatchNorm.FwdMode.INFERENCE, epsilon=eps, param_dim="DIM_1C11"
) )
ret = apply(op, inp, weight, bias, running_mean, running_var)[-1] ret = apply(op, inp, weight, bias, running_mean, running_var)[-1]
return ret return ret
else: else:
op = builtin.BatchNorm( op = builtin.BatchNorm(
BatchNorm.ParamDim.DIM_1C11, avg_factor=1 - momentum, epsilon=eps, param_dim="DIM_1C11"
BatchNorm.FwdMode.TRAINING,
eps,
1.0 - momentum,
1.0,
0.0,
) )
if has_mean or has_var: if has_mean or has_var:
running_mean = make_full_if_none(running_mean, 0) running_mean = make_full_if_none(running_mean, 0)
...@@ -1638,8 +1630,7 @@ def conv1d( ...@@ -1638,8 +1630,7 @@ def conv1d(
pad_h = padding pad_h = padding
dilate_h = dilation dilate_h = dilation
Sparse = P.Convolution.Sparse sparse_type = "DENSE" if groups == 1 else "GROUP"
sparse_type = Sparse.DENSE if groups == 1 else Sparse.GROUP
op = builtin.Convolution( op = builtin.Convolution(
stride_h=stride_h, stride_h=stride_h,
stride_w=1, stride_w=1,
......
...@@ -41,12 +41,12 @@ def conv_bias_activation( ...@@ -41,12 +41,12 @@ def conv_bias_activation(
``in_channels`` and ``out_channels`` must be divisible by ``groups``, ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be `(groups, out_channel // groups, and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`. in_channels // groups, height, width)`.
:type conv_mode: string or :class:`P.Convolution.Mode`. :type conv_mode: string or :class:`Convolution.Mode`.
:param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
'CROSS_CORRELATION' 'CROSS_CORRELATION'
:param dtype: support for ``np.dtype``, Default: np.int8 :param dtype: support for ``np.dtype``, Default: np.int8
:type compute_mode: string or :type compute_mode: string or
:class:`P.Convolution.ComputeMode`. :class:`Convolution.ComputeMode`.
:param compute_mode: when set to "DEFAULT", no special requirements will be :param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32", placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype. "Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype.
...@@ -56,7 +56,7 @@ def conv_bias_activation( ...@@ -56,7 +56,7 @@ def conv_bias_activation(
sh, sw = _pair_nonzero(stride) sh, sw = _pair_nonzero(stride)
dh, dw = _pair_nonzero(dilation) dh, dw = _pair_nonzero(dilation)
sparse_type = "DENSE" if groups == 1 else "GROUP" sparse_type = "DENSE" if groups == 1 else "GROUP"
op = builtin.ConvBiasForward( op = builtin.ConvBias(
stride_h=sh, stride_h=sh,
stride_w=sw, stride_w=sw,
pad_h=ph, pad_h=ph,
...@@ -101,12 +101,12 @@ def batch_conv_bias_activation( ...@@ -101,12 +101,12 @@ def batch_conv_bias_activation(
``in_channels`` and ``out_channels`` must be divisible by ``groups``, ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
and the shape of weight should be `(groups, out_channel // groups, and the shape of weight should be `(groups, out_channel // groups,
in_channels // groups, height, width)`. in_channels // groups, height, width)`.
:type conv_mode: string or :class:`P.Convolution.Mode`. :type conv_mode: string or :class:`Convolution.Mode`.
:param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default:
'CROSS_CORRELATION' 'CROSS_CORRELATION'
:param dtype: support for ``np.dtype``, Default: np.int8 :param dtype: support for ``np.dtype``, Default: np.int8
:type compute_mode: string or :type compute_mode: string or
:class:`P.Convolution.ComputeMode`. :class:`Convolution.ComputeMode`.
:param compute_mode: when set to "DEFAULT", no special requirements will be :param compute_mode: when set to "DEFAULT", no special requirements will be
placed on the precision of intermediate results. When set to "FLOAT32", placed on the precision of intermediate results. When set to "FLOAT32",
"Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype. "Float32" would be used for accumulator and intermediate result, but only effective when input and output are of Float16 dtype.
...@@ -116,7 +116,7 @@ def batch_conv_bias_activation( ...@@ -116,7 +116,7 @@ def batch_conv_bias_activation(
sh, sw = _pair_nonzero(stride) sh, sw = _pair_nonzero(stride)
dh, dw = _pair_nonzero(dilation) dh, dw = _pair_nonzero(dilation)
sparse_type = "DENSE" if groups == 1 else "GROUP" sparse_type = "DENSE" if groups == 1 else "GROUP"
op = builtin.BatchConvBiasForward( op = builtin.BatchConvBias(
stride_h=sh, stride_h=sh,
stride_w=sw, stride_w=sw,
pad_h=ph, pad_h=ph,
......
...@@ -16,7 +16,6 @@ import numpy as np ...@@ -16,7 +16,6 @@ import numpy as np
from ..core._imperative_rt import CompNode from ..core._imperative_rt import CompNode
from ..core._wrap import device as as_device from ..core._wrap import device as as_device
from ..core.ops import builtin from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const from ..core.ops.special import Const
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis from ..core.tensor.tensor_wrapper import _broadcast, _remove_axis
...@@ -722,7 +721,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: ...@@ -722,7 +721,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor:
[1 0]] [1 0]]
""" """
return inp.transpose(pattern) return inp.transpose(list(-1 if _ == "x" else _ for _ in pattern))
def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
...@@ -756,10 +755,6 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: ...@@ -756,10 +755,6 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
return inp.reshape(target_shape) return inp.reshape(target_shape)
AxisAddRemove = builtin.AxisAddRemove
AxisDesc = AxisAddRemove.AxisDesc
def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor:
r""" r"""
Reshapes the tensor by flattening the sub-tensor from dimension ``start_axis`` to dimension ``end_axis``. Reshapes the tensor by flattening the sub-tensor from dimension ``start_axis`` to dimension ``end_axis``.
...@@ -826,7 +821,6 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: ...@@ -826,7 +821,6 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
(1, 2) (1, 2)
""" """
Param = builtin.AxisAddRemove.Param
def get_axes(): def get_axes():
try: try:
...@@ -839,8 +833,7 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: ...@@ -839,8 +833,7 @@ def expand_dims(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
ndim = inp.ndim + len(axis) ndim = inp.ndim + len(axis)
axis = sorted(i + ndim if i < 0 else i for i in axis) axis = sorted(i + ndim if i < 0 else i for i in axis)
param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_add, axis)) op = builtin.AddAxis(axis=axis)
op = builtin.AxisAddRemove(param=param)
(result,) = apply(op, inp) (result,) = apply(op, inp)
return result return result
......
...@@ -21,9 +21,10 @@ import numpy as np ...@@ -21,9 +21,10 @@ import numpy as np
from ..core._imperative_rt import GraphProfiler from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.ops import ( from ..core._imperative_rt.ops import (
CollectiveComm, CollectiveComm,
OprAttr, GaussianRNG,
RemoteRecv, RemoteRecv,
RemoteSend, RemoteSend,
UniformRNG,
VirtualDep, VirtualDep,
) )
from ..core._trace_option import set_symbolic_shape from ..core._trace_option import set_symbolic_shape
...@@ -182,14 +183,7 @@ class trace: ...@@ -182,14 +183,7 @@ class trace:
record = self._seq[self._pc] record = self._seq[self._pc]
op_, ihandles, ohandles = record op_, ihandles, ohandles = record
if op != op_: if op != op_:
# FIXME: will be removed once better rng implementation is done raise TraceMismatchError("op different from last time")
if isinstance(op, OprAttr) and (
op.type in ("UniformRNG", "GaussianRNG") and op.type == op_.type
):
if op.param[8:] != op_.param[8:]:
raise TraceMismatchError("op different from last time")
else:
raise TraceMismatchError("op different from last time")
if len(ihandles) != len(args): if len(ihandles) != len(args):
raise TraceMismatchError("op input size different from last time") raise TraceMismatchError("op input size different from last time")
......
...@@ -10,7 +10,6 @@ from typing import Tuple, Union ...@@ -10,7 +10,6 @@ from typing import Tuple, Union
import numpy as np import numpy as np
from ..core.ops._internal import param_defs as P
from ..functional import conv1d, conv2d, conv_transpose2d, local_conv2d, relu from ..functional import conv1d, conv2d, conv_transpose2d, local_conv2d, relu
from ..functional.types import _pair, _pair_nonzero from ..functional.types import _pair, _pair_nonzero
from ..tensor import Parameter from ..tensor import Parameter
...@@ -156,8 +155,6 @@ class Conv1d(_ConvNd): ...@@ -156,8 +155,6 @@ class Conv1d(_ConvNd):
(2, 1, 2) (2, 1, 2)
""" """
_conv_mode_type = P.Convolution.Mode
_compute_mode_type = P.Convolution.ComputeMode
def __init__( def __init__(
self, self,
...@@ -176,8 +173,8 @@ class Conv1d(_ConvNd): ...@@ -176,8 +173,8 @@ class Conv1d(_ConvNd):
stride = stride stride = stride
padding = padding padding = padding
dilation = dilation dilation = dilation
self.conv_mode = self._conv_mode_type.convert(conv_mode) self.conv_mode = conv_mode
self.compute_mode = self._compute_mode_type.convert(compute_mode) self.compute_mode = compute_mode
super().__init__( super().__init__(
in_channels, in_channels,
out_channels, out_channels,
...@@ -302,9 +299,6 @@ class Conv2d(_ConvNd): ...@@ -302,9 +299,6 @@ class Conv2d(_ConvNd):
""" """
_conv_mode_type = P.Convolution.Mode
_compute_mode_type = P.Convolution.ComputeMode
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
...@@ -322,8 +316,8 @@ class Conv2d(_ConvNd): ...@@ -322,8 +316,8 @@ class Conv2d(_ConvNd):
stride = _pair_nonzero(stride) stride = _pair_nonzero(stride)
padding = _pair(padding) padding = _pair(padding)
dilation = _pair_nonzero(dilation) dilation = _pair_nonzero(dilation)
self.conv_mode = self._conv_mode_type.convert(conv_mode) self.conv_mode = conv_mode
self.compute_mode = self._compute_mode_type.convert(compute_mode) self.compute_mode = compute_mode
super().__init__( super().__init__(
in_channels, in_channels,
out_channels, out_channels,
...@@ -414,9 +408,6 @@ class ConvTranspose2d(_ConvNd): ...@@ -414,9 +408,6 @@ class ConvTranspose2d(_ConvNd):
effective when input and output are of float16 dtype. effective when input and output are of float16 dtype.
""" """
_conv_mode_type = P.Convolution.Mode
_compute_mode_type = P.Convolution.ComputeMode
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
...@@ -434,8 +425,8 @@ class ConvTranspose2d(_ConvNd): ...@@ -434,8 +425,8 @@ class ConvTranspose2d(_ConvNd):
stride = _pair_nonzero(stride) stride = _pair_nonzero(stride)
padding = _pair(padding) padding = _pair(padding)
dilation = _pair_nonzero(dilation) dilation = _pair_nonzero(dilation)
self.conv_mode = self._conv_mode_type.convert(conv_mode) self.conv_mode = conv_mode
self.compute_mode = self._compute_mode_type.convert(compute_mode) self.compute_mode = compute_mode
super().__init__( super().__init__(
in_channels, in_channels,
out_channels, out_channels,
...@@ -509,8 +500,6 @@ class LocalConv2d(Conv2d): ...@@ -509,8 +500,6 @@ class LocalConv2d(Conv2d):
in_channels // groups, *kernel_size, out_channels // groups)`. in_channels // groups, *kernel_size, out_channels // groups)`.
""" """
_conv_mode_type = P.Convolution.Mode
def __init__( def __init__(
self, self,
in_channels: int, in_channels: int,
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core.ops._internal import param_defs as P
from ..functional.elemwise import _elwise from ..functional.elemwise import _elwise
from ..tensor import Tensor from ..tensor import Tensor
from .module import Module from .module import Module
......
...@@ -41,8 +41,8 @@ class Conv2d(Float.Conv2d, QATModule): ...@@ -41,8 +41,8 @@ class Conv2d(Float.Conv2d, QATModule):
float_module.dilation, float_module.dilation,
float_module.groups, float_module.groups,
float_module.bias is not None, float_module.bias is not None,
float_module.conv_mode.name, float_module.conv_mode,
float_module.compute_mode.name, float_module.compute_mode,
) )
qat_module.weight = float_module.weight qat_module.weight = float_module.weight
qat_module.bias = float_module.bias qat_module.bias = float_module.bias
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core.ops._internal import param_defs as P
from ...functional.elemwise import _elemwise_multi_type from ...functional.elemwise import _elemwise_multi_type
from ...tensor import Tensor from ...tensor import Tensor
from ..qat import elemwise as QAT from ..qat import elemwise as QAT
...@@ -15,11 +14,9 @@ from .module import QuantizedModule ...@@ -15,11 +14,9 @@ from .module import QuantizedModule
class Elemwise(QuantizedModule): class Elemwise(QuantizedModule):
r"""Quantized version of :class:`~.qat.elemwise.Elemwise`.""" r"""Quantized version of :class:`~.qat.elemwise.Elemwise`."""
_elemwise_multi_type_mode = P.ElemwiseMultiType.Mode
def __init__(self, method, dtype=None): def __init__(self, method, dtype=None):
super().__init__() super().__init__()
self.method = self._elemwise_multi_type_mode.convert("Q" + method) self.method = "Q" + method
self.output_dtype = dtype self.output_dtype = dtype
def forward(self, *inps): def forward(self, *inps):
......
...@@ -15,7 +15,7 @@ from typing import Iterable, List, Optional ...@@ -15,7 +15,7 @@ from typing import Iterable, List, Optional
from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry
from ..core._imperative_rt import ProfilerImpl as _Profiler from ..core._imperative_rt import ProfilerImpl as _Profiler
from ..core._imperative_rt.imperative import sync from ..core._imperative_rt.imperative import sync
from ..core._imperative_rt.ops import CollectiveCommMode from ..core._imperative_rt.ops import CollectiveComm
def _make_dict(**kwargs): def _make_dict(**kwargs):
...@@ -194,7 +194,7 @@ class Profiler: ...@@ -194,7 +194,7 @@ class Profiler:
_type_map = { _type_map = {
OperatorNodeConfig: lambda x: _print_opnode_config(x), OperatorNodeConfig: lambda x: _print_opnode_config(x),
bytes: lambda x: base64.encodebytes(x).decode("ascii"), bytes: lambda x: base64.encodebytes(x).decode("ascii"),
CollectiveCommMode: lambda x: str(x), CollectiveComm.Mode: lambda x: str(x),
} }
_dumper_map = { _dumper_map = {
......
...@@ -421,9 +421,7 @@ void init_graph_rt(py::module m) { ...@@ -421,9 +421,7 @@ void init_graph_rt(py::module m) {
common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) { common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) {
cg::VarNodeArray vinputs(inputs.begin(), inputs.end()); cg::VarNodeArray vinputs(inputs.begin(), inputs.end());
auto opr = OpDef::apply_on_var_node(def, vinputs); return to_tuple(OpDef::apply_on_var_node(def, vinputs));
auto outputs = opr->usable_output();
return to_tuple(outputs);
}, },
py::arg(), py::arg(), py::arg("graph") = py::none()); py::arg(), py::arg(), py::arg("graph") = py::none());
......
...@@ -109,9 +109,6 @@ void init_imperative_rt(py::module m) { ...@@ -109,9 +109,6 @@ void init_imperative_rt(py::module m) {
py::class_<OpDef, std::shared_ptr<OpDef>>(m, "OpDef") py::class_<OpDef, std::shared_ptr<OpDef>>(m, "OpDef")
.def("ctype", [](const OpDef& opdef) { .def("ctype", [](const OpDef& opdef) {
if (auto attr = opdef.try_cast_final<OprAttr>()) {
return attr->type.c_str();
}
return opdef.dyn_typeinfo()->name; return opdef.dyn_typeinfo()->name;
}) })
.def("__eq__", [](const OpDef& lhs, const OpDef& rhs) { .def("__eq__", [](const OpDef& lhs, const OpDef& rhs) {
......
...@@ -14,41 +14,29 @@ ...@@ -14,41 +14,29 @@
#include "megbrain/imperative.h" #include "megbrain/imperative.h"
#include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/tensor_manip.h"
#include "megbrain/imperative/ops/collective_comm.h"
#include "megbrain/imperative/ops/io_remote.h"
#include "megbrain/imperative/ops/cond_take.h"
#include "megbrain/imperative/ops/nms.h"
#include "megbrain/imperative/ops/elemwise.h"
#include "megbrain/imperative/ops/batch_norm.h"
#include "megbrain/imperative/ops/broadcast.h"
#include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/autogen.h"
namespace py = pybind11; namespace py = pybind11;
namespace {
// Returns a copy of `in` in which every character has been passed through
// toupper(). The input string is not modified.
auto normalize_enum(const std::string& in) {
    std::string ret;
    // The result has exactly as many characters as the input, so allocate once.
    ret.reserve(in.size());
    for (const char c : in) {
        // toupper() is only defined for EOF and values representable as
        // unsigned char; a plain (possibly negative) char must be converted
        // first to avoid undefined behaviour on non-ASCII bytes.
        ret += static_cast<char>(toupper(static_cast<unsigned char>(c)));
    }
    return ret;
}
} // anonymous namespace
void init_ops(py::module m) { void init_ops(py::module m) {
using namespace mgb::imperative; using namespace mgb::imperative;
py::class_<OprAttr, std::shared_ptr<OprAttr>, OpDef>(m, "OprAttr")
.def(py::init<>())
.def_readwrite("type", &OprAttr::type)
.def_readwrite("param", &OprAttr::param)
.def_readwrite("config", &OprAttr::config)
.def_property("param",
[](const OprAttr& attr) -> py::bytes {
return std::string(attr.param.begin(), attr.param.end());
},
[] (OprAttr& attr, py::bytes data) {
auto s = py::cast<std::string>(data);
attr.param.clear();
attr.param.insert(attr.param.end(), s.begin(), s.end());
});
py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph") py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph")
.def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc, .def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc,
const mgb::SmallVector<py::object>& inputs) { const mgb::SmallVector<py::object>& inputs) {
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) { auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) {
return py::cast<mgb::SmallVector<py::object>>(pyf(op.copy(), inputs)); return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs));
}; };
auto c = [pyc](const TensorPtr& tensor) { auto c = [pyc](const TensorPtr& tensor) {
return pyc(tensor->dev_tensor()); return pyc(tensor->dev_tensor());
...@@ -56,162 +44,8 @@ void init_ops(py::module m) { ...@@ -56,162 +44,8 @@ void init_ops(py::module m) {
return self.graph().interpret<py::object>(f, c, inputs); return self.graph().interpret<py::object>(f, c, inputs);
}); });
py::class_<GetVarShape, std::shared_ptr<GetVarShape>, OpDef>(m, "GetVarShape")
.def(py::init());
#define V(m) .value(#m, CollectiveComm::Mode::m)
py::enum_<CollectiveComm::Mode>(m, "CollectiveCommMode")
V(REDUCE_SUM)
V(BROADCAST)
V(ALL_GATHER)
V(REDUCE_SCATTER_SUM)
V(ALL_REDUCE_SUM)
V(ALL_REDUCE_MAX)
V(ALL_REDUCE_MIN)
V(ALL_REDUCE_PROD)
V(GATHER)
V(SCATTER)
V(ALL_TO_ALL);
#undef V
py::class_<CollectiveComm, std::shared_ptr<CollectiveComm>, OpDef>(m, "CollectiveComm")
.def(py::init<>())
.def_readwrite("key", &CollectiveComm::key)
.def_readwrite("nr_devices", &CollectiveComm::nr_devices)
.def_readwrite("rank", &CollectiveComm::rank)
.def_readwrite("is_root", &CollectiveComm::is_root)
.def_readwrite("local_grad", &CollectiveComm::local_grad)
.def_readwrite("addr", &CollectiveComm::addr)
.def_readwrite("port", &CollectiveComm::port)
.def_readwrite("mode", &CollectiveComm::mode)
.def_readwrite("dtype", &CollectiveComm::dtype)
.def_readwrite("backend", &CollectiveComm::backend)
.def_readwrite("comp_node", &CollectiveComm::comp_node);
py::class_<RemoteSend, std::shared_ptr<RemoteSend>, OpDef>(m, "RemoteSend")
.def(py::init<>())
.def_readwrite("key", &RemoteSend::key)
.def_readwrite("addr", &RemoteSend::addr)
.def_readwrite("port", &RemoteSend::port)
.def_readwrite("rank_to", &RemoteSend::rank_to);
py::class_<RemoteRecv, std::shared_ptr<RemoteRecv>, OpDef>(m, "RemoteRecv")
.def(py::init<>())
.def_readwrite("key", &RemoteRecv::key)
.def_readwrite("addr", &RemoteRecv::addr)
.def_readwrite("port", &RemoteRecv::port)
.def_readwrite("rank_from", &RemoteRecv::rank_from)
.def_readwrite("shape", &RemoteRecv::shape)
.def_readwrite("cn", &RemoteRecv::cn)
.def_readwrite("dtype", &RemoteRecv::dtype);
py::class_<ParamPackSplit, std::shared_ptr<ParamPackSplit>, OpDef>(m, "ParamPackSplit")
.def(py::init<>())
.def_readwrite("offsets", &ParamPackSplit::offsets)
.def_readwrite("shapes", &ParamPackSplit::shapes);
py::class_<ParamPackConcat, std::shared_ptr<ParamPackConcat>, OpDef>(m, "ParamPackConcat")
.def(py::init<>())
.def_readwrite("offsets", &ParamPackConcat::offsets);
py::class_<VirtualDep, std::shared_ptr<VirtualDep>, OpDef>(m, "VirtualDep") py::class_<VirtualDep, std::shared_ptr<VirtualDep>, OpDef>(m, "VirtualDep")
.def(py::init<>()); .def(py::init<>());
py::class_<CondTake, std::shared_ptr<CondTake>, OpDef>(m, "CondTake") #include "opdef.py.inl"
.def(py::init<>());
py::class_<NMSKeep, std::shared_ptr<NMSKeep>, OpDef>(m, "NMSKeep")
.def(py::init<float, uint32_t>())
.def_readwrite("iou_thresh", &NMSKeep::iou_thresh)
.def_readwrite("max_output", &NMSKeep::max_output);
py::class_<Elemwise, std::shared_ptr<Elemwise>, OpDef> elemwise(m, "Elemwise");
elemwise.def(py::init<Elemwise::Mode>())
.def_readwrite("mode", &Elemwise::mode);
#define V(m) .value(#m, Elemwise::Mode::m)
py::enum_<Elemwise::Mode>(elemwise, "Mode")
V(RELU)
V(ABS)
V(ACOS)
V(ASIN)
V(CEIL)
V(COS)
V(EXP)
V(EXPM1)
V(FLOOR)
V(LOG)
V(LOG1P)
V(NEGATE)
V(SIGMOID)
V(SIN)
V(TANH)
V(ABS_GRAD)
V(ADD)
V(FLOOR_DIV)
V(MAX)
V(MIN)
V(MOD)
V(MUL)
V(POW)
V(SIGMOID_GRAD)
V(SUB)
V(SWITCH_GT0)
V(TANH_GRAD)
V(TRUE_DIV)
V(LOG_SUM_EXP)
V(LT)
V(LEQ)
V(EQ)
V(SHL)
V(SHR)
V(COND_LEQ_MOV)
V(FUSE_MUL_ADD3)
V(FUSE_MUL_ADD4)
V(FUSE_ADD_RELU)
V(FUSE_ADD_SIGMOID)
V(FUSE_ADD_TANH)
V(FAST_TANH)
V(FAST_TANH_GRAD)
V(ROUND)
V(RMULH)
V(ATAN2)
V(ERF)
V(ERFINV)
V(ERFC)
V(ERFCINV)
V(H_SWISH)
V(H_SWISH_GRAD)
V(FUSE_ADD_H_SWISH)
V(NOT)
V(AND)
V(OR)
V(XOR);
#undef V
py::class_<BatchNorm, std::shared_ptr<BatchNorm>, OpDef> batchnorm(m, "BatchNorm");
batchnorm.def(py::init<const BatchNorm::Param::ParamDim&, const BatchNorm::Param::FwdMode&, double, double, float, float>())
.def_readwrite("param_dim", &BatchNorm::param_dim)
.def_readwrite("fwd_mode", &BatchNorm::fwd_mode)
.def_readwrite("epsilon", &BatchNorm::epsilon)
.def_readwrite("avg_factor", &BatchNorm::avg_factor)
.def_readwrite("scale", &BatchNorm::scale)
.def_readwrite("bias", &BatchNorm::bias);
#define V(m) .value(#m, BatchNorm::Param::ParamDim::m)
py::enum_<BatchNorm::Param::ParamDim>(batchnorm, "ParamDim")
V(DIM_11HW)
V(DIM_1CHW)
V(DIM_1C11);
#undef V
#define V(m) .value(#m, BatchNorm::Param::FwdMode::m)
py::enum_<BatchNorm::Param::FwdMode>(batchnorm, "FwdMode")
V(TRAINING)
V(INFERENCE);
#undef V
py::class_<Broadcast, std::shared_ptr<Broadcast>, OpDef>(m, "Broadcast")
.def(py::init<>());
} }
...@@ -113,7 +113,7 @@ def test_quint8_typecvt(): ...@@ -113,7 +113,7 @@ def test_quint8_typecvt():
data = np.random.random(shape).astype(np.float32) * 5 - 1 data = np.random.random(shape).astype(np.float32) * 5 - 1
def typecvt(x, dt=None): def typecvt(x, dt=None):
(y,) = apply(ops.TypeCvt(param=dt), x) (y,) = apply(ops.TypeCvt(dtype=dt), x)
return y return y
# convert to quint8 # convert to quint8
...@@ -194,7 +194,7 @@ def test_quint4_typecvt(): ...@@ -194,7 +194,7 @@ def test_quint4_typecvt():
data = np.random.random(shape).astype(np.float32) * 5 - 1 data = np.random.random(shape).astype(np.float32) * 5 - 1
def typecvt(x, dt=None): def typecvt(x, dt=None):
(y,) = apply(ops.TypeCvt(param=dt), x) (y,) = apply(ops.TypeCvt(dtype=dt), x)
return y return y
# convert to quint4 # convert to quint4
......
...@@ -11,10 +11,9 @@ import collections ...@@ -11,10 +11,9 @@ import collections
import numpy as np import numpy as np
import pytest import pytest
import megengine.core.ops.builtin
import megengine.core.tensor.raw_tensor import megengine.core.tensor.raw_tensor
from megengine.core._trace_option import use_symbolic_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops._internal import all_ops from megengine.core.ops import builtin
from megengine.core.tensor import Tensor from megengine.core.tensor import Tensor
from megengine.core.tensor.core import apply from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import RawTensor, as_raw_tensor from megengine.core.tensor.raw_tensor import RawTensor, as_raw_tensor
...@@ -105,7 +104,7 @@ def canonize_inputs(inputs, *, config): ...@@ -105,7 +104,7 @@ def canonize_inputs(inputs, *, config):
need_cvt = False need_cvt = False
for i in old_inputs: for i in old_inputs:
if isinstance(i, RawTensor): if isinstance(i, RawTensor):
get_comp_node = lambda cn=i.device.to_c(): cn get_comp_node = lambda cn=i.device: cn
else: else:
need_cvt = True need_cvt = True
inputs.append(i) inputs.append(i)
...@@ -193,91 +192,91 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): ...@@ -193,91 +192,91 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
def transpose(*args, **kwargs): def transpose(*args, **kwargs):
op = all_ops.Dimshuffle(**kwargs).to_c() op = builtin.Dimshuffle(**kwargs)
return invoke_op(op, args) return invoke_op(op, args)
def broadcast(input, tshape): def broadcast(input, tshape):
op = all_ops.Broadcast().to_c() op = builtin.Broadcast()
return invoke_op(op, (input, tshape), canonize_reshape) return invoke_op(op, (input, tshape), canonize_reshape)
def subtensor(input, tuple_val): def subtensor(input, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val) input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.Subtensor(items).to_c() op = builtin.Subtensor(items)
return invoke_op(op, (input, *tensors)) return invoke_op(op, (input, *tensors))
def set_subtensor(input, value, tuple_val): def set_subtensor(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val) input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.SetSubtensor(items).to_c() op = builtin.SetSubtensor(items)
return invoke_op(op, (input, value, *tensors)) return invoke_op(op, (input, value, *tensors))
def incr_subtensor(input, value, tuple_val): def incr_subtensor(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val) input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.IncrSubtensor(items).to_c() op = builtin.IncrSubtensor(items)
return invoke_op(op, (input, value, *tensors)) return invoke_op(op, (input, value, *tensors))
def advance_indexing(input, tuple_val): def advance_indexing(input, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val) input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.IndexingMultiAxisVec(items).to_c() op = builtin.IndexingMultiAxisVec(items)
return invoke_op(op, (input, *tensors)) return invoke_op(op, (input, *tensors))
def set_advance_indexing(input, value, tuple_val): def set_advance_indexing(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val) input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.IndexingSetMultiAxisVec(items).to_c() op = builtin.IndexingSetMultiAxisVec(items)
return invoke_op(op, (input, value, *tensors)) return invoke_op(op, (input, value, *tensors))
def incr_advance_indexing(input, value, tuple_val): def incr_advance_indexing(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val) input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.IndexingIncrMultiAxisVec(items).to_c() op = builtin.IndexingIncrMultiAxisVec(items)
return invoke_op(op, (input, value, *tensors)) return invoke_op(op, (input, value, *tensors))
def mesh_indexing(input, tuple_val): def mesh_indexing(input, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val) input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.MeshIndexing(items).to_c() op = builtin.MeshIndexing(items)
return invoke_op(op, (input, *tensors)) return invoke_op(op, (input, *tensors))
def set_mesh_indexing(input, value, tuple_val): def set_mesh_indexing(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val) input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.SetMeshIndexing(items).to_c() op = builtin.SetMeshIndexing(items)
return invoke_op(op, (input, value, *tensors)) return invoke_op(op, (input, value, *tensors))
def incr_mesh_indexing(input, value, tuple_val): def incr_mesh_indexing(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val) input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.IncrMeshIndexing(items).to_c() op = builtin.IncrMeshIndexing(items)
return invoke_op(op, (input, value, *tensors)) return invoke_op(op, (input, value, *tensors))
def batched_mesh_indexing(input, tuple_val): def batched_mesh_indexing(input, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val) input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.BatchedMeshIndexing(items).to_c() op = builtin.BatchedMeshIndexing(items)
return invoke_op(op, (input, *tensors)) return invoke_op(op, (input, *tensors))
def batched_set_mesh_indexing(input, value, tuple_val): def batched_set_mesh_indexing(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val) input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.BatchedSetMeshIndexing(items).to_c() op = builtin.BatchedSetMeshIndexing(items)
return invoke_op(op, (input, value, *tensors)) return invoke_op(op, (input, value, *tensors))
def batched_incr_mesh_indexing(input, value, tuple_val): def batched_incr_mesh_indexing(input, value, tuple_val):
input, tensors, items = unpack_getitem(input, tuple_val) input, tensors, items = unpack_getitem(input, tuple_val)
op = all_ops.BatchedIncrMeshIndexing(items).to_c() op = builtin.BatchedIncrMeshIndexing(items)
return invoke_op(op, (input, value, *tensors)) return invoke_op(op, (input, value, *tensors))
def test_transpose(): def test_transpose():
x = np.arange(10).reshape(2, 5).astype("int32") x = np.arange(10).reshape(2, 5).astype("int32")
xx = as_raw_tensor(x) xx = as_raw_tensor(x)
(yy,) = transpose(xx, pattern="1x0") (yy,) = transpose(xx, pattern=[1, -1, 0])
np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy()) np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy())
......
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from io import StringIO
import re
import argparse
import subprocess
import os
import textwrap
import inspect
def camel2underscore(
        name, *,
        first_cap_re=re.compile('([A-Z])([A-Z][a-z]+)'),
        all_cap_re=re.compile('([a-z])([A-Z]+)')):
    """Convert a CamelCase identifier to lower-case underscore form.

    :param name: identifier to convert
    :param first_cap_re: pattern used to split a run of capitals from a
        following capitalised word (``HTTPServer`` -> ``HTTP_Server``)
    :param all_cap_re: pattern used to split a lower-case letter from the
        capitals that follow it (``convBias`` -> ``conv_Bias``)
    :return: the converted name; an all-upper-case *name* is simply lowered
    """
    if not name.isupper():
        # two passes: first detach acronyms, then split at lower->upper edges
        acronyms_split = first_cap_re.sub(r'\1_\2', name)
        name = all_cap_re.sub(r'\1_\2', acronyms_split)
    return name.lower()
def caller_lineno(level=1):
    """Return ``"<filename>:<lineno>"`` for a frame further up the call stack.

    :param level: how many frames above the direct caller of this function
        to report; ``0`` is the direct caller
    """
    frame = inspect.stack()[level + 1]
    return '{}:{}'.format(frame.filename, frame.lineno)
class Doc:
    """Bundle an identifier with its documentation.

    ``str()`` of an instance yields the wrapped identifier, so a ``Doc`` can
    be used where a plain name string is expected.
    """

    # class-level fallback for the wrapped identifier
    _id = None

    def __init__(self, id_, doc, typestr=None, default=None):
        """
        :param id_: the identifier being documented
        :param doc: documentation text
        :param typestr: optional type description
        :param default: optional default value
        """
        self._id, self.doc = id_, doc
        self.typestr, self.default = typestr, default

    def __str__(self):
        return self._id
# Code-generation context: collects operator declarations made through
# decl_opr()/decl_raw_opr() and renders Python class definitions for them
# into an in-memory text buffer.
# NOTE(review): leading indentation has been lost in this copy of the file,
# so the nesting of some statements below cannot be determined from here;
# comments avoid claiming a particular nesting where it is ambiguous.
class Context:
# output buffer; replaced by a fresh StringIO per instance in __init__
fout = None
def __init__(self):
self.fout = StringIO()
# current indentation depth (in units of 4 spaces) applied by write()
self.indent = 0
# names of declarations that were not generated
self.skipped = []
# (name, params, has_out_dtype, version) tuples already declared
self.generated_signature = set()
# operator name -> OprItem recorded by decl_opr()
self.generated_opr = dict()
# Append *text* to the output buffer: dedent it, re-indent it by
# (self.indent + indent) levels of 4 spaces, %-format it with *fmt* and make
# sure it ends with a newline.
# Formatting is applied after indenting, and *text* is always used as a
# %-format string, even when *fmt* is empty.
def write(self, text, *fmt, indent=0):
text = textwrap.dedent(text)
text = textwrap.indent(text, ' '*4*(self.indent + indent))
text = text % fmt
if not text.endswith('\n'):
text += '\n'
self.fout.write(text)
# Build the parameter list string of a generated __init__: 'self', a bare
# '*' (so everything after it is keyword-only), one '<name>=None' per param,
# optional config-related arguments and optional '**kwargs'.
# NOTE(review): whether the has_out_dtype check is nested inside the
# have_config check cannot be told from this copy -- confirm against the
# original file.
def _gen_signature(self, params, *, have_config=True,
has_out_dtype=False):
sig = ['self', '*']
for i, _ in params:
sig.append('{}=None'.format(i))
if have_config:
sig.extend(['name=None', 'comp_node=None', 'config=None'])
if has_out_dtype:
sig.append('dtype=None')
if params:
sig.append('**kwargs')
# a trailing bare '*' would be a syntax error in the generated code
if sig[-1] == '*':
sig.pop()
return ', '.join(sig)
# Emit the config set-up code followed by a line that assigns
# 'inputs = helper.<convert_inputs>(<arg>, config=config)'.
# When convert_inputs_args is None, <arg> is derived from *inputs*: a first
# name starting with '*' is passed through without the star, otherwise all
# names are wrapped in a list literal.
def _write_canonize_inputs(self, inputs, convert_inputs,
convert_inputs_args=None,
has_out_dtype=False):
self._write_gen_config(has_out_dtype)
inputs = list(map(str, inputs))
if convert_inputs_args is None:
if inputs[0][0] == '*':
arg = inputs[0][1:]
else:
arg = '[{}]'.format(', '.join(inputs))
else:
arg = convert_inputs_args
self.write('inputs = helper.%s(%s, config=config)',
convert_inputs, arg)
# Emit code that creates a default Config when none was passed, copies the
# name/comp_node (and, with has_out_dtype, dtype) arguments into it, and
# stores it on the generated object as self.config.
def _write_gen_config(self, has_out_dtype=False):
self.write('''\
config = config or Config()
if name:
config.name = name
if comp_node:
config.comp_node = comp_node
''')
if has_out_dtype:
self.write('''\
if dtype:
config.dtype = dtype
''')
self.write('self.config = config')
# Emit one 'self.<pname> = helper.make_param(...)' line per (name, type)
# pair, plus an assertion in the generated code that no kwargs are left.
def _write_make_params(self, params):
for pname, ptype in params:
self.write('self.%s = helper.make_param(%s, param_defs.%s, kwargs)',
pname, pname, ptype)
self.write('assert not kwargs, "extra kwargs: {}".format(kwargs)')
# Emit a docstring for a generated operator: the description (a Doc whose
# identifier is None, or a plain string wrapped at 75 columns), then
# :param:/:type: entries for the inputs and params, then the fixed entries
# for comp_node, name and config.
def _write_doc(self, inputs, params, desc):
self.write('"""')
if isinstance(desc, Doc):
assert desc._id is None
self.write(desc.doc)
elif desc:
for i in textwrap.wrap(desc, 75):
self.write(i)
self.write('')
for i in inputs:
name = str(i)
typestr = ':class:`.Tensor`'
# a leading '*' marks a list-of-tensors input
if name[0] == '*':
name = name[1:]
typestr = 'list of ' + typestr
if isinstance(i, Doc):
self.write(':param %s: %s', name, i.doc)
if i.typestr is not None:
typestr = i.typestr
if typestr:
if not isinstance(i, Doc):
self.write(':param %s: ', name)
self.write(':type %s: %s', name, typestr)
for pname, ptype in params:
self.write(':param %s: ', pname)
self.write(':type %s: :class:`~megbrain.opr_param_defs.%s`',
pname, ptype)
self.write(':param comp_node: see doc for *config*')
self.write(':param name: see doc for *config*')
self.write(
':param config: give a :class:`.OperatorNodeConfig` object to set '
'operator name and comp node. This can also be achieved by passing '
'*comp_node* and *name* separately.')
self.write('"""')
# Emit the tail of a generated function: build the op definition, create
# the op, optionally select a subset of outputs by index, and return the
# converted outputs.
def _write_return(self, name, outputs):
self.write('opdef = helper.PodOpVisitor("%s", config, params)', name)
self.write('outputs = helper.create_op(opdef, inputs)')
if outputs:
self.write('outputs = [outputs[i] for i in %s]',
list(map(int, outputs)))
self.write('return helper.convert_outputs(outputs)')
# Record an operator declaration for later code generation.
# Declarations that pass a custom *body*, and declarations whose
# (name, params, has_out_dtype, version) signature was already seen, are
# only added to self.skipped.
def decl_opr(self, name, *, inputs, params, desc=None, pyname=None,
canonize_input_vars=None,
canonize_input_vars_args=None, body=None,
outputs=None, version=0, has_out_dtype=False):
"""
:param inputs: name of variable inputs; a name starting with `*' means
a list of vars
:type inputs: list of str
:param params: (param name, param type) pairs; it can be a single
string representing the param type, and param name defaults to
'param'
:type params: list of pair of str, or str
:param pyname: python function name
:param body: extra statements to be placed before calling _create_opr
:param outputs: the indices of output vars to be selected from raw opr
result
"""
# plain record of what was declared for one operator
class OprItem:
def __init__(self, inputs, desc, params, version, has_out_dtype):
self.inputs = inputs
self.desc = desc
self.params = params
self.version = version
self.has_out_dtype = has_out_dtype
if body:
self.skipped.append(name)
return
# params is made hashable (frozenset) unless it is already a string
signature = (name, params if isinstance(params, str) else frozenset(params), has_out_dtype, version)
if signature in self.generated_signature:
self.skipped.append(name)
return
else:
self.generated_signature.add(signature)
body = body or []
# a bare type name is shorthand for a single param called 'param'
if isinstance(params, str):
params = [('param', params)]
assert params
if name in self.generated_opr:
org_opr = self.generated_opr[name]
# a higher version replaces the stored one, provided its description
# and input docs agree with the stored ones
if version > org_opr.version:
def compare_doc(a, b):
if isinstance(a, str):
return a == b
else:
assert isinstance(a, Doc)
return a.doc == b.doc
assert compare_doc(desc, org_opr.desc)
assert len(inputs) == len(org_opr.inputs)
for i, j in zip(inputs, org_opr.inputs):
assert compare_doc(i, j)
self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype)
else:
self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype)
# Render every recorded operator as a 'class <name>(PodOpVisitor)' with
# class attributes param_names and name (suffixed 'V<version>' for a
# non-zero version) and an __init__ that sets up config and params.
def write_generated_oprs(self):
for opr, opr_item in self.generated_opr.items():
name = opr
params = opr_item.params
version = opr_item.version
has_out_dtype = opr_item.has_out_dtype
# provenance comment: with the default level, caller_lineno() reports the
# frame that called write_generated_oprs()
self.write('# %s', caller_lineno())
self.write('class %s(PodOpVisitor):', name)
self.indent += 1
param_names, _ = zip(*params)
self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names)))
self.write('name = "%s"', '{}V{}'.format(name, version) if version else name)
self.write('\n')
self.write('def __init__(%s):',
self._gen_signature(params,
has_out_dtype=has_out_dtype))
self.indent += 1
self._write_gen_config(has_out_dtype=has_out_dtype)
self.write('\n')
self._write_make_params(params)
self.write('\n')
# undo the two indent increments made for the class and __init__ bodies
self.indent -= 2
# Raw declarations are not generated here; the name is only recorded as
# skipped.
def decl_raw_opr(self, name, *, inputs, inputs_cvt=[], body=None,
desc=None, local_defs=[], have_config=True, params=None, has_out_dtype=False):
self.skipped.append(name)
# Return all text written so far.
def get_str(self):
return self.fout.getvalue()
# Return the source text of a list literal containing the recorded
# operator names, one quoted name per line.
def all_list(self):
buf = StringIO()
print(
'[',
*(' "%s",' % i for i in self.generated_opr),
']',
sep='\n',
file=buf
)
return buf.getvalue()
def main():
    """Command-line entry point.

    Executes every decl file given on the command line with ``decl_opr``,
    ``decl_raw_opr``, ``Doc`` and ``camel2underscore`` in scope, renders the
    collected operators, and fills the ``ops.tpl.py`` template that sits
    next to this script, writing the result to ``--output``. The names of
    skipped declarations are printed at the end.
    """
    parser = argparse.ArgumentParser(
        description='generate operator function def code from decl file')
    parser.add_argument('inputs', nargs='+')
    # Without an output path the final open() could only fail with a
    # TypeError after all the work was done, so reject that up front.
    parser.add_argument('--output', '-o', required=True)
    args = parser.parse_args()

    gen = Context()
    exec_globals = {
        'decl_opr': gen.decl_opr,
        'decl_raw_opr': gen.decl_raw_opr,
        'Doc': Doc,
        'camel2underscore': camel2underscore,
    }
    for i in args.inputs:
        print('generate ops from {}'.format(i))
        with open(i) as fin:
            # NOTE: decl files are executed as Python code; only run this
            # tool on trusted inputs.
            exec(compile(fin.read(), i, 'exec'), exec_globals)
    gen.write_generated_oprs()

    # Recording the commit is best-effort: fall back to a placeholder when
    # git is missing or the script is not inside a repository, but do not
    # swallow KeyboardInterrupt/SystemExit as a bare ``except`` would.
    try:
        git_commit = subprocess.check_output(
            ['git', 'rev-parse', 'HEAD'], universal_newlines=True,
            cwd=os.path.dirname(os.path.realpath(__file__))).strip()
    except (OSError, subprocess.SubprocessError):
        git_commit = 'NOT_A_GIT_REPO'

    def relpath(*args):
        # path relative to the directory containing this script
        d = os.path.dirname(__file__)
        return os.path.join(d, *args)

    with open(relpath('ops.tpl.py')) as fin:
        with open(args.output, 'w') as fout:
            fout.write(fin.read()
                       .replace('{%all%}', gen.all_list())
                       .replace('{%body%}', gen.get_str())
                       .replace('{%git_commit%}', git_commit))

    print('Skipped:')
    print(*gen.skipped, sep='\n')


if __name__ == '__main__':
    main()
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""This python module contains functions to apply the operators defined by
megbrain.
.. note::
Most of the functions are automatically generated. Their signatures
contain a ``param`` argument (or several such arguments, e.g.
:func:`convolution` has ``param`` and ``execution_policy``) and also
accept keyword arguments. Such a function can be called either by
providing a param object of the appropriate type, or by passing the
arguments needed by the constructor of the param object as keyword arguments.
Furthermore, for a param that needs an enumeration member, the enum name
can be used to refer to the enum object.
For example, the following statements are equivalent::
elemwise([a, b], mode='max')
elemwise([a, b], mode=opr_param_defs.Elemwise.Mode.MAX)
elemwise([a, b], param=opr_param_defs.Elemwise('max'))
"""
__git_commit__ = "{%git_commit%}"
import collections
from . import helper
from .helper import PodOpVisitor
from . import param_defs
from ..._imperative_rt import OperatorNodeConfig as Config
__all__ = {%all%}
{%body%}
...@@ -36,7 +36,7 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor( ...@@ -36,7 +36,7 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
return def.trait()->apply_on_physical_tensor(def, inputs); return def.trait()->apply_on_physical_tensor(def, inputs);
} }
cg::OperatorNodeBase* OpDef::apply_on_var_node( VarNodeArray OpDef::apply_on_var_node(
const OpDef& def, const OpDef& def,
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
return def.trait()->apply_on_var_node(def, inputs); return def.trait()->apply_on_var_node(def, inputs);
...@@ -56,6 +56,14 @@ BackwardGraphResult OpDef::make_backward_graph( ...@@ -56,6 +56,14 @@ BackwardGraphResult OpDef::make_backward_graph(
return def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad); return def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
} }
size_t OpDef::hash() const {
return trait()->hash(*this);
}
bool OpDef::is_same_st(const Hashable& rhs) const {
return trait()->is_same_st(*this, static_cast<const OpDef&>(rhs));
}
const OpTrait* OpDef::trait() const { const OpTrait* OpDef::trait() const {
if (!m_trait) { if (!m_trait) {
m_trait = OpTrait::find_by_typeinfo(dyn_typeinfo()); m_trait = OpTrait::find_by_typeinfo(dyn_typeinfo());
......
...@@ -23,7 +23,7 @@ namespace detail { ...@@ -23,7 +23,7 @@ namespace detail {
struct StaticData { struct StaticData {
std::list<OpTrait> registries; std::list<OpTrait> registries;
std::unordered_map<const char*, OpTrait*> name2reg; std::unordered_map<std::string, OpTrait*> name2reg;
std::unordered_map<Typeinfo*, OpTrait*> type2reg; std::unordered_map<Typeinfo*, OpTrait*> type2reg;
}; };
......
...@@ -30,6 +30,32 @@ struct OpMeth<RType(Args...)>: public thin_function<RType(Args...)> { ...@@ -30,6 +30,32 @@ struct OpMeth<RType(Args...)>: public thin_function<RType(Args...)> {
return this->Base::operator ()(args...); return this->Base::operator ()(args...);
} }
}; };
// Converter from a result type T to VarNodeArray.
// The primary template derives from std::false_type, marking T as
// unsupported; each specialization below derives from std::true_type and
// provides the conversion as operator().
template<typename T>
struct ToVarNodeArray: std::false_type {};

// A single SymbolVar becomes a one-element array holding its node.
template<>
struct ToVarNodeArray<SymbolVar>: std::true_type {
    VarNodeArray operator()(const SymbolVar& var) {
        return {var.node()};
    }
};

// A SymbolVarArray is converted via cg::to_var_node_array.
template<>
struct ToVarNodeArray<SymbolVarArray>: std::true_type {
    VarNodeArray operator()(const SymbolVarArray& vars) {
        return cg::to_var_node_array(vars);
    }
};

// A fixed-size std::array of SymbolVar is passed to cg::to_var_node_array
// as an iterator range.
template<size_t N>
struct ToVarNodeArray<std::array<SymbolVar, N>>: std::true_type {
    VarNodeArray operator()(const std::array<SymbolVar, N>& vars) {
        return cg::to_var_node_array({vars.begin(), vars.end()});
    }
};

// An operator node is converted to the result of its usable_output().
template<>
struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type {
    VarNodeArray operator()(const cg::OperatorNodeBase* owner) {
        return owner->usable_output();
    }
};
} // detail } // detail
using OpDefMaker = detail::OpMeth< using OpDefMaker = detail::OpMeth<
...@@ -42,6 +68,8 @@ using InferOutputAttrsFallible = detail::OpMeth< ...@@ -42,6 +68,8 @@ using InferOutputAttrsFallible = detail::OpMeth<
decltype(OpDef::infer_output_attrs_fallible)>; decltype(OpDef::infer_output_attrs_fallible)>;
using GradMaker = detail::OpMeth< using GradMaker = detail::OpMeth<
decltype(OpDef::make_backward_graph)>; decltype(OpDef::make_backward_graph)>;
using HashFunc = detail::OpMeth<size_t(const OpDef&)>;
using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;
struct OpTrait { struct OpTrait {
const char* name; const char* name;
...@@ -50,6 +78,8 @@ struct OpTrait { ...@@ -50,6 +78,8 @@ struct OpTrait {
ApplyOnVarNode apply_on_var_node; ApplyOnVarNode apply_on_var_node;
InferOutputAttrsFallible infer_output_attrs_fallible; InferOutputAttrsFallible infer_output_attrs_fallible;
GradMaker make_backward_graph; GradMaker make_backward_graph;
HashFunc hash;
IsSame is_same_st;
OpTrait(const char* name); OpTrait(const char* name);
static OpTrait* find_by_name(const char* name); static OpTrait* find_by_name(const char* name);
static OpTrait* find_by_typeinfo(Typeinfo* type); static OpTrait* find_by_typeinfo(Typeinfo* type);
...@@ -61,7 +91,9 @@ struct OpTrait { ...@@ -61,7 +91,9 @@ struct OpTrait {
cb(apply_on_physical_tensor) \ cb(apply_on_physical_tensor) \
cb(apply_on_var_node) \ cb(apply_on_var_node) \
cb(infer_output_attrs_fallible) \ cb(infer_output_attrs_fallible) \
cb(make_backward_graph) cb(make_backward_graph) \
cb(hash) \
cb(is_same_st)
struct OpTraitRegistry { struct OpTraitRegistry {
OpTrait* trait; OpTrait* trait;
...@@ -97,6 +129,15 @@ struct OpTraitRegistry { ...@@ -97,6 +129,15 @@ struct OpTraitRegistry {
void do_insert(Typeinfo* type); void do_insert(Typeinfo* type);
static OpTraitRegistry do_insert(const char* name); static OpTraitRegistry do_insert(const char* name);
// Overload accepting a plain function pointer whose return type T has a
// detail::ToVarNodeArray specialization. The function is wrapped in a
// lambda that converts its result with that specialization, and the lambda
// is forwarded to another apply_on_var_node overload.
template<typename T,
         typename To = detail::ToVarNodeArray<T>,
         typename = std::enable_if_t<To::value>>
OpTraitRegistry& apply_on_var_node(T (*f)(const OpDef&, const VarNodeArray&)) {
    // only the function pointer needs to be captured
    return apply_on_var_node([f](const OpDef& def, const VarNodeArray& vars) {
        return To()(f(def, vars));
    });
}
}; };
} // namespace imperative } // namespace imperative
......
/** /**
* \file imperative/src/include/megbrain/imperative/ops/broadcast.h * \file imperative/src/impl/ops/autogen.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...@@ -9,27 +9,38 @@ ...@@ -9,27 +9,38 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#pragma once #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/tensor_manip.h" #include "../op_trait.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/op_def.h"
namespace mgb::imperative { using namespace megdnn;
class Broadcast : public OpDefImplBase<Broadcast> { // FIXME: remove this when mgb::hash support tuple_hash
MGB_DYN_TYPE_OBJ_FINAL_DECL; namespace mgb {
namespace {
// Build a new tuple from the elements of `tup` at positions Is+1 for every
// index Is in the given index sequence, i.e. the selection is shifted one
// position past the first element.
template<typename Tuple, size_t ...Is>
auto tail(Tuple tup, std::index_sequence<Is...>) {
    return std::make_tuple(std::get<1 + Is>(tup)...);
}
} // anonymous namespace
template<typename T, typename ...Args>
class HashTrait<std::tuple<T, Args...>> {
constexpr static size_t length = sizeof...(Args);
public: public:
Broadcast() = default; static size_t eval(const std::tuple<T, Args...> &t) {
const T& val = std::get<0>(t);
size_t hash() const override { if constexpr (!length) {
return reinterpret_cast<std::uintptr_t>(dyn_typeinfo()); return mgb::hash(val);
} else {
return mgb::hash_pair_combine(mgb::hash(val),
mgb::hash(tail(t, std::make_index_sequence<length - 1>{})));
}
} }
};
} // namespace mgb
bool is_same_st(const Hashable& rhs) const override { namespace mgb::imperative {
return true;
}
}; #include "./opdef.cpp.inl"
} // namespace mgb::imperative } // namespace mgb::imperative
\ No newline at end of file
...@@ -9,7 +9,8 @@ ...@@ -9,7 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "megbrain/imperative/ops/batch_norm.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "../op_trait.h" #include "../op_trait.h"
namespace mgb { namespace mgb {
...@@ -19,9 +20,7 @@ namespace { ...@@ -19,9 +20,7 @@ namespace {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::BatchNorm>(); auto* node = &node_->cast_final_safe<opr::BatchNorm>();
auto&& param = node->param(); return BatchNorm::make(node->param());
return BatchNorm::make(param.param_dim, param.fwd_mode, param.epsilon,
param.avg_factor, param.scale, param.bias);
} }
cg::OperatorNodeBase* apply_on_var_node( cg::OperatorNodeBase* apply_on_var_node(
...@@ -33,13 +32,11 @@ cg::OperatorNodeBase* apply_on_var_node( ...@@ -33,13 +32,11 @@ cg::OperatorNodeBase* apply_on_var_node(
"BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp); "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
if (nr_inp == 3) { if (nr_inp == 3) {
return opr::BatchNorm::make( return opr::BatchNorm::make(
inputs[0], inputs[1], inputs[2], inputs[0], inputs[1], inputs[2], bn_opr.param())[0]
{bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0]
.node()->owner_opr(); .node()->owner_opr();
} else { } else {
return opr::BatchNorm::make( return opr::BatchNorm::make(
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param())[0]
{bn_opr.param_dim, bn_opr.fwd_mode, bn_opr.epsilon, bn_opr.avg_factor, bn_opr.scale, bn_opr.bias})[0]
.node()->owner_opr(); .node()->owner_opr();
} }
} }
...@@ -52,7 +49,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -52,7 +49,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
mgb_assert(nr_inp == 3 ||nr_inp == 5, mgb_assert(nr_inp == 3 ||nr_inp == 5,
"BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp); "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
// need running mean/variance // need running mean/variance
bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::Param::FwdMode::TRAINING; bool need_stat = (nr_inp == 5) && op_def.fwd_mode == BatchNorm::FwdMode::TRAINING;
size_t nr_out = need_stat? 5 : 3; size_t nr_out = need_stat? 5 : 3;
SmallVector<LogicalTensorDesc> out_shapes(nr_out); SmallVector<LogicalTensorDesc> out_shapes(nr_out);
auto&& i0 = inputs[0]; auto&& i0 = inputs[0];
...@@ -76,8 +73,6 @@ OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm) ...@@ -76,8 +73,6 @@ OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm)
.fallback(); .fallback();
} // anonymous namespace } // anonymous namespace
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNorm);
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
......
...@@ -9,7 +9,9 @@ ...@@ -9,7 +9,9 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "megbrain/imperative/ops/broadcast.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/tensor_manip.h"
#include "../op_trait.h" #include "../op_trait.h"
namespace mgb { namespace mgb {
...@@ -87,8 +89,6 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) ...@@ -87,8 +89,6 @@ OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
.fallback(); .fallback();
} // anonymous namespace } // anonymous namespace
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Broadcast);
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "megbrain/utils/hash.h" #include "megbrain/utils/hash.h"
#endif // MGB_ENABLE_OPR_MM #endif // MGB_ENABLE_OPR_MM
#include "megbrain/imperative/ops/collective_comm.h" #include "megbrain/imperative/ops/autogen.h"
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {
...@@ -61,8 +61,8 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node) { ...@@ -61,8 +61,8 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node) {
auto [addr, port] = split_address(group_client->get_addr()); auto [addr, port] = split_address(group_client->get_addr());
auto comp_node = node->config().get_single_comp_node().to_string_logical(); auto comp_node = node->config().get_single_comp_node().to_string_logical();
return std::make_shared<CollectiveComm>( return std::make_shared<CollectiveComm>(
comm.key(), comm.nr_devices(), comm.rank(), comm.is_root(), comm.param().mode, comm.key(), comm.nr_devices(), comm.rank(),
comm.local_grad(), addr, std::stoi(port), comm.param().mode, comm.is_root(), comm.local_grad(), addr, std::stoi(port),
comm.dtype(), comm.backend(), comp_node); comm.dtype(), comm.backend(), comp_node);
} }
...@@ -73,35 +73,6 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm) ...@@ -73,35 +73,6 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm)
} // anonymous namespace } // anonymous namespace
#endif // MGB_ENABLE_OPR_MM #endif // MGB_ENABLE_OPR_MM
// Two CollectiveComm ops are the same when `another` is a CollectiveComm
// and both as_tuple() results compare equal; any other type is not equal.
bool CollectiveComm::is_same_st(const Hashable& another) const{
    if (auto* other = another.try_cast_final<CollectiveComm>()) {
        return as_tuple() == other->as_tuple();
    }
    return false;
}
size_t CollectiveComm::hash() const{
XXHash xxhash{};
auto append = [&xxhash](auto field){
auto hash_val = HashTrait<decltype(field)>::eval(field);
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
};
append(key);
append(nr_devices);
append(rank);
append(is_root);
append(local_grad);
append(addr);
append(port);
append(mode);
append(backend);
append(comp_node);
return xxhash.digest();
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm);
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
......
...@@ -9,8 +9,7 @@ ...@@ -9,8 +9,7 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "megbrain/imperative/ops/cond_take.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/misc.h" #include "megbrain/opr/misc.h"
#include "../dnn_op_helper.h" #include "../dnn_op_helper.h"
#include "../op_trait.h" #include "../op_trait.h"
...@@ -19,8 +18,6 @@ using namespace megdnn; ...@@ -19,8 +18,6 @@ using namespace megdnn;
namespace mgb::imperative { namespace mgb::imperative {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake);
namespace { namespace {
class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy { class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy {
......
...@@ -9,7 +9,9 @@ ...@@ -9,7 +9,9 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "megbrain/imperative/ops/elemwise.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"
#include "../op_trait.h" #include "../op_trait.h"
namespace mgb { namespace mgb {
...@@ -33,7 +35,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -33,7 +35,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) { const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op_def = def.cast_final_safe<Elemwise>(); auto&& op_def = def.cast_final_safe<Elemwise>();
auto trait = Elemwise::ModeTrait::from_mode(op_def.mode); auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
mgb_assert(inputs.size() == trait.arity, mgb_assert(inputs.size() == trait.arity,
"%s expects %u inputs; got %zu actually", trait.name, "%s expects %u inputs; got %zu actually", trait.name,
trait.arity, inputs.size()); trait.arity, inputs.size());
...@@ -70,8 +72,6 @@ OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise) ...@@ -70,8 +72,6 @@ OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
.fallback(); .fallback();
} // anonymous namespace } // anonymous namespace
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Elemwise);
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "megbrain/opr/mm_handler.h" #include "megbrain/opr/mm_handler.h"
#endif // MGB_ENABLE_OPR_MM #endif // MGB_ENABLE_OPR_MM
#include "megbrain/imperative/ops/io_remote.h" #include "megbrain/imperative/ops/autogen.h"
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {
...@@ -60,45 +60,5 @@ OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv) ...@@ -60,45 +60,5 @@ OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv)
} // anonymous namespace } // anonymous namespace
#endif // MGB_ENABLE_OPR_MM #endif // MGB_ENABLE_OPR_MM
// Structural equality: two RemoteSend ops are the same when their
// as_tuple() views compare equal.
bool RemoteSend::is_same_st(const Hashable& another) const {
    const auto& other = another.cast_final<RemoteSend>();
    return as_tuple() == other.as_tuple();
}
// Digest over key, addr, port and rank_to, fed to XXHash in that order.
size_t RemoteSend::hash() const {
    XXHash hasher;
    // Hash one field through HashTrait and mix the resulting value in.
    auto feed = [&hasher](auto value) {
        auto part = HashTrait<decltype(value)>::eval(value);
        hasher.update(reinterpret_cast<void*>(&part), sizeof(part));
    };
    feed(key);
    feed(addr);
    feed(port);
    feed(rank_to);
    return hasher.digest();
}
// Structural equality: two RemoteRecv ops are the same when their
// as_tuple() views compare equal.
bool RemoteRecv::is_same_st(const Hashable& another) const {
    const auto& other = another.cast_final<RemoteRecv>();
    return as_tuple() == other.as_tuple();
}
// Digest over key, addr, port, rank_from, the comp node string, the dtype
// handle and the shape string, fed to XXHash in that order.
size_t RemoteRecv::hash() const {
    XXHash hasher;
    // Hash one field through HashTrait and mix the resulting value in.
    auto feed = [&hasher](auto value) {
        auto part = HashTrait<decltype(value)>::eval(value);
        hasher.update(reinterpret_cast<void*>(&part), sizeof(part));
    };
    feed(key);
    feed(addr);
    feed(port);
    feed(rank_from);
    feed(cn.to_string());
    feed(dtype.handle());
    feed(shape.to_string());
    return hasher.digest();
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteSend);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemoteRecv);
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "../op_trait.h" #include "../op_trait.h"
#include "megbrain/imperative/ops/nms.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/standalone/nms_opr.h" #include "megbrain/opr/standalone/nms_opr.h"
namespace mgb { namespace mgb {
...@@ -37,8 +37,6 @@ OP_TRAIT_REG(NMSKeep, NMSKeep, NMSKeepOpr) ...@@ -37,8 +37,6 @@ OP_TRAIT_REG(NMSKeep, NMSKeep, NMSKeepOpr)
.fallback(); .fallback();
} // anonymous namespace } // anonymous namespace
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NMSKeep);
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
......
/**
* \file imperative/src/impl/ops/autogen.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
// FIXME: split this file into separate files for each specialized op
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/rand.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "../op_trait.h"
namespace mgb::imperative {
namespace { namespace convolution {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::Convolution>();
return Convolution::make(node->param(), node->execution_policy());
}
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& conv = static_cast<const Convolution&>(def);
return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy());
}
OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // convolution
namespace { namespace convolution_backward_data {
// Builds an opr::ConvolutionBackwardData node. Two inputs select the
// two-input overload; any other count must be exactly three.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    const auto& op = static_cast<const ConvolutionBackwardData&>(def);
    cg::OperatorNodeConfig config;
    if (inputs.size() != 2) {
        mgb_assert(inputs.size() == 3);
        return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], inputs[2], op.param(), op.policy(), config);
    }
    return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], op.param(), op.policy(), config);
}

OP_TRAIT_REG(ConvolutionBackwardData, ConvolutionBackwardData)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // convolution_backward_data
namespace { namespace dimshuffle {
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
auto* node = &node_->cast_final_safe<opr::Dimshuffle>();
std::vector<int> pattern(node->param().pattern_len);
for (size_t i = 0; i < node->param().pattern_len; ++ i) {
pattern[i] = node->param().pattern[i];
}
return Dimshuffle::make(pattern);
}
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& ds = static_cast<const Dimshuffle&>(def);
return opr::Dimshuffle::make(inputs[0], ds.pattern);
}
OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // dimshuffle
namespace { namespace add_axis {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& add_axis = static_cast<const AddAxis&>(def);
using Desc = opr::AxisAddRemove::AxisDesc;
std::vector<Desc> param;
for (auto&& i : add_axis.axis) {
param.push_back(Desc::make_add(i));
}
return opr::AxisAddRemove::make(inputs[0], param);
}
OP_TRAIT_REG(AddAxis, AddAxis)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // add_axis
namespace { namespace remove_axis {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& remove_axis = static_cast<const RemoveAxis&>(def);
using Desc = opr::AxisAddRemove::AxisDesc;
std::vector<Desc> param;
for (auto&& i : remove_axis.axis) {
param.push_back(Desc::make_remove(i));
}
return opr::AxisAddRemove::make(inputs[0], param);
}
OP_TRAIT_REG(RemoveAxis, RemoveAxis)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // remove_axis
namespace { namespace top_k {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& topk = static_cast<const TopK&>(def);
return opr::TopK::make(inputs[0], inputs[1], topk.param())[0]
.node()->owner_opr();
}
OP_TRAIT_REG(TopK, TopK)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // top_k
namespace { namespace reduce {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& reduce = static_cast<const Reduce&>(def);
if (inputs.size() > 1) {
return opr::Reduce::make(inputs[0], reduce.param(), inputs[1]);
} else {
return opr::Reduce::make(inputs[0], reduce.param());
}
}
OP_TRAIT_REG(Reduce, Reduce)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // reduce
namespace { namespace adaptive_pooling {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& pool = static_cast<const AdaptivePooling&>(def);
return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param());
}
OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // adaptive_pooling
namespace { namespace conv_bias {
// Builds an opr::ConvBias node. Accepts 2, 3 or 4 inputs, each count mapping
// to the opr::ConvBias::make overload with that many inputs. The op's dtype
// is passed through the operator config.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const ConvBias&>(def);
    const size_t nr_inputs = inputs.size();
    // Validate up front so every path below ends in a return; the original
    // ended with mgb_assert(0) and no return statement after it.
    mgb_assert(nr_inputs >= 2 && nr_inputs <= 4,
               "ConvBias expects 2 to 4 inputs; got %zu actually", nr_inputs);
    cg::OperatorNodeConfig config{conv.dtype};
    if (nr_inputs == 2) {
        return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
    }
    if (nr_inputs == 3) {
        return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    }
    return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config);
}

OP_TRAIT_REG(ConvBias, ConvBias)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // conv_bias
namespace { namespace batch_conv_bias {
// Builds an opr::BatchConvBias node. Accepts 2, 3 or 4 inputs, each count
// mapping to the opr::BatchConvBias::make overload with that many inputs.
// The op's dtype is passed through the operator config.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& conv = static_cast<const BatchConvBias&>(def);
    const size_t nr_inputs = inputs.size();
    // Validate up front so every path below ends in a return; the original
    // ended with mgb_assert(0) and no return statement after it.
    mgb_assert(nr_inputs >= 2 && nr_inputs <= 4,
               "BatchConvBias expects 2 to 4 inputs; got %zu actually", nr_inputs);
    cg::OperatorNodeConfig config{conv.dtype};
    if (nr_inputs == 2) {
        return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
    }
    if (nr_inputs == 3) {
        return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config);
    }
    return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config);
}

OP_TRAIT_REG(BatchConvBias, BatchConvBias)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // batch_conv_bias
namespace { namespace pooling {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& pool = static_cast<const Pooling&>(def);
return opr::Pooling::make(inputs[0], pool.param());
}
OP_TRAIT_REG(Pooling, Pooling)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // pooling
namespace { namespace matrix_mul {
// Builds an opr::MatrixMul node; requires exactly two inputs.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 2);
    const auto& op = static_cast<const MatrixMul&>(def);
    return opr::MatrixMul::make(inputs[0], inputs[1], op.param());
}

OP_TRAIT_REG(MatrixMul, MatrixMul)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // matrix_mul
namespace { namespace batched_matrix_mul {
// Builds an opr::BatchedMatrixMul node; requires exactly two inputs.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 2);
    const auto& op = static_cast<const BatchedMatrixMul&>(def);
    return opr::BatchedMatrixMul::make(inputs[0], inputs[1], op.param());
}

OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // batched_matrix_mul
namespace { namespace dot {
// Builds an opr::Dot node; requires exactly two inputs. The op definition
// itself is not consulted.
auto apply_on_var_node(
        const OpDef&,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 2);
    auto result = opr::Dot::make(inputs[0], inputs[1]);
    return result;
}

OP_TRAIT_REG(Dot, Dot)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // dot
namespace { namespace argsort {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& argsort = static_cast<const Argsort&>(def);
return opr::Argsort::make(inputs[0], argsort.param());
}
OP_TRAIT_REG(Argsort, Argsort)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // argsort
namespace { namespace argmax {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& argmax = static_cast<const Argmax&>(def);
return opr::Argmax::make(inputs[0], argmax.param());
}
OP_TRAIT_REG(Argmax, Argmax)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // argmax
namespace { namespace argmin {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& argmin = static_cast<const Argmin&>(def);
return opr::Argmin::make(inputs[0], argmin.param());
}
OP_TRAIT_REG(Argmin, Argmin)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // argmin
namespace { namespace warp_perspective {
// Builds an opr::WarpPerspective node. Three inputs select the three-input
// overload; any other count must be exactly four.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    const auto& op = static_cast<const WarpPerspective&>(def);
    if (inputs.size() != 3) {
        mgb_assert(inputs.size() == 4);
        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], inputs[3], op.param());
    }
    return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], op.param());
}

OP_TRAIT_REG(WarpPerspective, WarpPerspective)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // warp_perspective
namespace { namespace group_local {
// Builds an opr::GroupLocal node; requires exactly two inputs.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 2);
    const auto& op = static_cast<const GroupLocal&>(def);
    return opr::GroupLocal::make(inputs[0], inputs[1], op.param());
}

OP_TRAIT_REG(GroupLocal, GroupLocal)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // group_local
namespace { namespace indexing_one_hot {
// Builds an opr::IndexingOneHot node; requires exactly two inputs.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 2);
    const auto& one_hot = static_cast<const IndexingOneHot&>(def);
    return opr::IndexingOneHot::make(inputs[0], inputs[1], one_hot.param());
}

OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // indexing_one_hot
namespace { namespace indexing_set_one_hot {
// Builds an opr::IndexingSetOneHot node; requires exactly three inputs.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 3);
    const auto& set_one_hot = static_cast<const IndexingSetOneHot&>(def);
    return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], set_one_hot.param());
}

OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // indexing_set_one_hot
namespace { namespace typecvt {
// Builds an opr::TypeCvt node converting the single input to the op's dtype.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 1);
    const auto& cvt = static_cast<const TypeCvt&>(def);
    return opr::TypeCvt::make(inputs[0], cvt.dtype);
}

OP_TRAIT_REG(TypeCvt, TypeCvt)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // typecvt
namespace { namespace concat {
// Builds an opr::Concat node over all inputs along the op's axis, with the
// op's comp_node passed through the operator config.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    const auto& cat = static_cast<const Concat&>(def);
    cg::OperatorNodeConfig config{cat.comp_node};
    return opr::Concat::make(inputs, cat.axis, config);
}

OP_TRAIT_REG(Concat, Concat)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // concat
namespace { namespace copy {
// Builds an opr::Copy node for the single input, with the op's comp_node
// passed through the operator config.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    const auto& cp = static_cast<const Copy&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{cp.comp_node};
    return opr::Copy::make(inputs[0], config);
}

OP_TRAIT_REG(Copy, Copy)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // copy
namespace { namespace identity {
// Builds an opr::Identity node; requires exactly one input. The op
// definition itself is not consulted.
auto apply_on_var_node(
        const OpDef&,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 1);
    auto result = opr::Identity::make(inputs[0]);
    return result;
}

OP_TRAIT_REG(Identity, Identity)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // identity
namespace { namespace uniform_rng {
// Builds an opr::UniformRNG node; requires exactly one input.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 1);
    const auto& rng = static_cast<const UniformRNG&>(def);
    return opr::UniformRNG::make(inputs[0], rng.param());
}

OP_TRAIT_REG(UniformRNG, UniformRNG)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // uniform_rng
namespace { namespace gaussian_rng {
// Builds an opr::GaussianRNG node; requires exactly one input.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 1);
    const auto& rng = static_cast<const GaussianRNG&>(def);
    return opr::GaussianRNG::make(inputs[0], rng.param());
}

OP_TRAIT_REG(GaussianRNG, GaussianRNG)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // gaussian_rng
namespace { namespace roi_align {
// Builds an opr::ROIAlign node; requires exactly two inputs.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 2);
    const auto& align = static_cast<const ROIAlign&>(def);
    return opr::ROIAlign::make(inputs[0], inputs[1], align.param());
}

OP_TRAIT_REG(ROIAlign, ROIAlign)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // roi_align
#if MGB_CUDA
namespace { namespace nvof {
// Builds an opr::NvOf node; requires exactly one input.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 1);
    const auto& flow = static_cast<const NvOf&>(def);
    return opr::NvOf::make(inputs[0], flow.param());
}

OP_TRAIT_REG(NvOf, NvOf)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // nvof
#endif
namespace { namespace linspace {
// Builds an opr::Linspace node; requires exactly three inputs. The op's
// comp_node is passed through the operator config.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    const auto& lin = static_cast<const Linspace&>(def);
    mgb_assert(inputs.size() == 3);
    cg::OperatorNodeConfig config{lin.comp_node};
    return opr::Linspace::make(inputs[0], inputs[1], inputs[2], lin.param(), config);
}

OP_TRAIT_REG(Linspace, Linspace)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // linspace
namespace { namespace eye {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const Eye&>(def);
mgb_assert(inputs.size() == 1);
cg::OperatorNodeConfig config{op.comp_node};
opr::Eye::Param param{op.k, op.dtype.enumv()};
return opr::Eye::make(inputs[0], param, config);
}
OP_TRAIT_REG(Eye, Eye)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // eye
namespace { namespace roi_pooling {
// Builds an opr::ROIPooling node; requires exactly three inputs.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 3);
    const auto& roi = static_cast<const ROIPooling&>(def);
    return opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], roi.param());
}

OP_TRAIT_REG(ROIPooling, ROIPooling)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // roi_pooling
namespace { namespace remap {
// Builds an opr::Remap node; requires exactly two inputs.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 2);
    const auto& remap_op = static_cast<const Remap&>(def);
    return opr::Remap::make(inputs[0], inputs[1], remap_op.param());
}

OP_TRAIT_REG(Remap, Remap)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // remap
namespace { namespace reshape {
// Builds an opr::Reshape node; requires exactly two inputs.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 2);
    const auto& reshape_op = static_cast<const Reshape&>(def);
    return opr::Reshape::make(inputs[0], inputs[1], reshape_op.param());
}

OP_TRAIT_REG(Reshape, Reshape)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // reshape
namespace {
// Builds a Subtensor::IndexDesc from a flat input list.
//
// \param inputs all inputs of the op
// \param vidx   position in \p inputs of the first index-related input
// \param mask   one entry per indexed axis: (axis, has_begin, has_end,
//               has_step, has_idx). For each entry the flagged fields are
//               taken from \p inputs in the order begin, end, step; when
//               has_idx is set, a single idx input is taken instead.
//
// Every read of \p inputs is bounds-checked before it happens, and on exit
// all inputs must have been consumed.
auto get_index(
        const VarNodeArray& inputs, size_t vidx,
        const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) {
    size_t length = mask.size();
    opr::Subtensor::IndexDesc ret(length);
    // Fetch the next index input. The original indexed first and only
    // compared vidx with inputs.size() after the loop, so a mask that asked
    // for more inputs than were supplied read past the end of the array.
    auto next_input = [&inputs, &vidx]() {
        mgb_assert(vidx < inputs.size(),
                   "index mask requires more inputs than provided; got %zu actually",
                   inputs.size());
        return inputs[vidx++];
    };
    for (size_t i = 0; i < length; ++ i) {
        auto&& [axis, begin, end, step, idx] = mask[i];
        ret[i].axis = axis;
        if (idx) {
            ret[i].idx = next_input();
        } else {
            mgb_assert(begin || end || step);
            if (begin) ret[i].begin = next_input();
            if (end) ret[i].end = next_input();
            if (step) ret[i].step = next_input();
        }
    }
    mgb_assert(vidx == inputs.size());
    return ret;
}
// Helper macros for the indexing-style ops below. IN1 / IN2 expand to the
// first one / first two entries of `inputs`; they are selected by token
// pasting (IN##NR_INPUT) inside FANCY_INDEXING_IMPL.
#define IN1 inputs[0]
#define IN2 inputs[0], inputs[1]
// FANCY_INDEXING_IMPL(NAME, NR_INPUT) defines namespace NAME_impl holding an
// apply_on_var_node that calls opr::NAME::make with the first NR_INPUT inputs
// followed by the index description that get_index builds from the inputs
// starting at position NR_INPUT and the op's `items` mask. The function is
// then registered via OP_TRAIT_REG(NAME, NAME).
#define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \
namespace NAME##_impl { \
auto apply_on_var_node( \
const OpDef& def, \
const VarNodeArray& inputs) { \
auto&& op = static_cast<const NAME&>(def); \
return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items)); \
} \
OP_TRAIT_REG(NAME, NAME) \
.apply_on_var_node(apply_on_var_node) \
.fallback(); \
}
// One instantiation per op; the second argument is the number of leading
// inputs passed to make() ahead of the index description.
FANCY_INDEXING_IMPL(Subtensor, 1)
FANCY_INDEXING_IMPL(SetSubtensor, 2)
FANCY_INDEXING_IMPL(IncrSubtensor, 2)
FANCY_INDEXING_IMPL(IndexingMultiAxisVec, 1)
FANCY_INDEXING_IMPL(IndexingSetMultiAxisVec, 2)
FANCY_INDEXING_IMPL(IndexingIncrMultiAxisVec, 2)
FANCY_INDEXING_IMPL(MeshIndexing, 1)
FANCY_INDEXING_IMPL(IncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(SetMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedMeshIndexing, 1)
FANCY_INDEXING_IMPL(BatchedIncrMeshIndexing, 2)
FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2)
// Keep the helper macros local to this section.
#undef FANCY_INDEXING_IMPL
#undef IN1
#undef IN2
} // anonymous namespace
namespace { namespace fake_quant {
// Builds an opr::FakeQuant node; requires exactly three inputs.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 3);
    const auto& fq = static_cast<const FakeQuant&>(def);
    return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], fq.param());
}

OP_TRAIT_REG(FakeQuant, FakeQuant)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // fake_quant
namespace { namespace elemwise_multi_type {
// Builds an opr::ElemwiseMultiType node over all inputs, with the op's dtype
// passed through the operator config.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    const auto& multi = static_cast<const ElemwiseMultiType&>(def);
    OperatorNodeConfig config{multi.dtype};
    return opr::ElemwiseMultiType::make(inputs, multi.param(), config);
}

OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // elemwise_multi_type
namespace { namespace svd {
// Builds an opr::SVD node; requires exactly one input.
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 1);
    const auto& svd_op = static_cast<const SVD&>(def);
    return opr::SVD::make(inputs[0], svd_op.param());
}

OP_TRAIT_REG(SVD, SVD)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // svd
} // namespace mgb::imperative
\ No newline at end of file
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "megbrain/imperative/ops/tensor_manip.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/tensor_manip.h"
#include "../op_trait.h" #include "../op_trait.h"
...@@ -140,8 +140,4 @@ OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat) ...@@ -140,8 +140,4 @@ OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
.fallback(); .fallback();
} // namespace } // namespace
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GetVarShape);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackConcat);
} // namespace mgb::imperative } // namespace mgb::imperative
...@@ -130,7 +130,7 @@ void Profiler::start(uint32_t flags) { ...@@ -130,7 +130,7 @@ void Profiler::start(uint32_t flags) {
// TODO: assign parent // TODO: assign parent
entry.parent = 0; entry.parent = 0;
// Record apply context and save to m_profile // Record apply context and save to m_profile
entry.op = def.copy(); entry.op = const_cast<OpDef&>(def).shared_from_this();
for (auto&& input : inputs) { for (auto&& input : inputs) {
entry.inputs.push_back({m_tensor_recorder.record_tensor(input), entry.inputs.push_back({m_tensor_recorder.record_tensor(input),
shape2vector(input->layout()), shape2vector(input->layout()),
...@@ -172,31 +172,31 @@ void Profiler::start(uint32_t flags) { ...@@ -172,31 +172,31 @@ void Profiler::start(uint32_t flags) {
if (flags & PROFILE_FOOTPRINT) { if (flags & PROFILE_FOOTPRINT) {
hook_apply_on_var_node->apply_hook( hook_apply_on_var_node->apply_hook(
[this](auto&& apply, const OpDef& def, [this](auto&& apply, const OpDef& def,
VarNodeArray inputs) -> cg::OperatorNodeBase* { VarNodeArray inputs) -> VarNodeArray {
auto* operator_node = apply(def, std::move(inputs)); auto vars = apply(def, std::move(inputs));
std::remove_reference_t<decltype(m_entry_stack.top())> std::remove_reference_t<decltype(m_entry_stack.top())>
top; top;
{ {
MGB_LOCK_GUARD(m_lock); MGB_LOCK_GUARD(m_lock);
if (m_entry_stack.empty()) { if (m_entry_stack.empty()) {
return operator_node; return vars;
} }
top = m_entry_stack.top(); top = m_entry_stack.top();
} }
auto [current_op, current_entry, thread_id] = top; auto [current_op, current_entry, thread_id] = top;
if (current_op != &def || if (current_op != &def ||
thread_id != std::this_thread::get_id()) { thread_id != std::this_thread::get_id()) {
return operator_node; return vars;
} }
auto&& footprint_result = auto&& footprint_result =
footprint.calc_footprint(operator_node); footprint.calc_footprint(vars[0]->owner_opr());
current_entry->memory = footprint_result.memory; current_entry->memory = footprint_result.memory;
current_entry->computation = current_entry->computation =
footprint_result.computation; footprint_result.computation;
#if MGB_ENABLE_JSON #if MGB_ENABLE_JSON
current_entry->param = footprint_result.param; current_entry->param = footprint_result.param;
#endif #endif
return operator_node; return vars;
}); });
} }
m_hooker_list.push_back(std::move(hook_apply_on_physical_tensor)); m_hooker_list.push_back(std::move(hook_apply_on_physical_tensor));
......
...@@ -590,7 +590,7 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr( ...@@ -590,7 +590,7 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
for (size_t i = 0; i < inputs.size(); ++ i) { for (size_t i = 0; i < inputs.size(); ++ i) {
vinputs[i] = InputPlaceholder::make(*m_graph, *inputs[i]).node(); vinputs[i] = InputPlaceholder::make(*m_graph, *inputs[i]).node();
} }
auto opr = OpDef::apply_on_var_node(opdef, vinputs); auto opr = OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
mgb_assert(!opr->same_type<InputPlaceholder>()); mgb_assert(!opr->same_type<InputPlaceholder>());
for (auto &&i : opr->input()) { for (auto &&i : opr->input()) {
mgb_assert(i->owner_opr()->same_type<InputPlaceholder>()); mgb_assert(i->owner_opr()->same_type<InputPlaceholder>());
...@@ -639,7 +639,7 @@ ProxyGraph::make_backward_graph( ...@@ -639,7 +639,7 @@ ProxyGraph::make_backward_graph(
return ret.first->second; return ret.first->second;
}; };
auto inputs = make_input_place_holders(input_descs); auto inputs = make_input_place_holders(input_descs);
auto fwd = OpDef::apply_on_var_node(opdef, inputs); auto fwd = OpDef::apply_on_var_node(opdef, inputs)[0]->owner_opr();
auto&& outputs = fwd->usable_output(); auto&& outputs = fwd->usable_output();
SmallVector<LogicalTensorDesc> output_descs; SmallVector<LogicalTensorDesc> output_descs;
for (auto&& i : outputs) { for (auto&& i : outputs) {
...@@ -799,7 +799,7 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(const OpDef& opdef, ...@@ -799,7 +799,7 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(const OpDef& opdef,
const SmallVector<LogicalTensorDesc>& inputs) { const SmallVector<LogicalTensorDesc>& inputs) {
mgb_assert(!m_cur_opr); mgb_assert(!m_cur_opr);
auto vinputs = make_input_place_holders(inputs); auto vinputs = make_input_place_holders(inputs);
return OpDef::apply_on_var_node(opdef, vinputs); return OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
} }
VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector<LogicalTensorDesc>& inputs) { VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector<LogicalTensorDesc>& inputs) {
......
...@@ -26,13 +26,12 @@ struct BackwardGraphResult { ...@@ -26,13 +26,12 @@ struct BackwardGraphResult {
std::vector<bool> input_has_grad; std::vector<bool> input_has_grad;
}; };
class OpDef : public Hashable { class OpDef : public Hashable,
public std::enable_shared_from_this<OpDef> {
mutable const OpTrait* m_trait = nullptr; mutable const OpTrait* m_trait = nullptr;
public: public:
virtual ~OpDef() = default; virtual ~OpDef() = default;
virtual std::shared_ptr<OpDef> copy() const = 0;
static std::shared_ptr<OpDef> make_from_op_node( static std::shared_ptr<OpDef> make_from_op_node(
cg::OperatorNodeBase* node); cg::OperatorNodeBase* node);
...@@ -40,7 +39,7 @@ public: ...@@ -40,7 +39,7 @@ public:
const OpDef& def, const OpDef& def,
const SmallVector<TensorPtr>& inputs); const SmallVector<TensorPtr>& inputs);
static cg::OperatorNodeBase* apply_on_var_node( static cg::VarNodeArray apply_on_var_node(
const OpDef& def, const OpDef& def,
const VarNodeArray& inputs); const VarNodeArray& inputs);
...@@ -56,25 +55,17 @@ public: ...@@ -56,25 +55,17 @@ public:
const OpTrait* trait() const; const OpTrait* trait() const;
virtual size_t hash() const { virtual size_t hash() const;
mgb_throw(MegBrainError, "not implemented");
}
virtual bool is_same_st(const Hashable&) const { virtual bool is_same_st(const Hashable&) const;
mgb_throw(MegBrainError, "not implemented");
}
}; };
template<typename T> template<typename T>
class OpDefImplBase : public OpDef { class OpDefImplBase : public OpDef {
public: public:
virtual std::shared_ptr<OpDef> copy() const override {
return std::shared_ptr<OpDef>(new T(this->cast_final_safe<T>()));
}
template<typename ...Args> template<typename ...Args>
static std::shared_ptr<OpDef> make(const Args& ...args) { static std::shared_ptr<OpDef> make(Args&& ...args) {
return std::shared_ptr<OpDef>(new T(args...)); return std::make_shared<T>(std::forward<Args>(args)...);
} }
}; };
......
/** /**
* \file imperative/src/include/megbrain/imperative/ops/cond_take.h * \file imperative/src/include/megbrain/imperative/ops/autogen.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...@@ -12,22 +12,15 @@ ...@@ -12,22 +12,15 @@
#pragma once #pragma once
#include "megbrain/imperative/op_def.h" #include "megbrain/imperative/op_def.h"
#include "megdnn/opr_param_defs.h"
#include "megbrain/opr/param_defs.h"
namespace mgb::imperative { #include "megbrain/utils/hash.h"
class CondTake : public OpDefImplBase<CondTake> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
CondTake() = default;
size_t hash() const override { namespace mgb::imperative {
return reinterpret_cast<std::uintptr_t>(dyn_typeinfo());
}
bool is_same_st(const Hashable& rhs) const override {
return rhs.dyn_typeinfo() == dyn_typeinfo();
}
}; // TODO: split into separate files to avoid recompiling all
// impl/ops/*.cpp on each modification of ops.td
#include "./opdef.h.inl"
} // namespace mgb::imperative } // namespace mgb::imperative
/**
* \file imperative/src/include/megbrain/imperative/ops/batch_norm.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/utils/hash.h"
namespace mgb::imperative {
// Imperative op definition for batch normalization. Stores the individual
// fields of opr::BatchNorm::Param and implements hashing and structural
// equality over them.
class BatchNorm : public OpDefImplBase<BatchNorm> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using Param = opr::BatchNorm::Param;

    Param::ParamDim param_dim;
    Param::FwdMode fwd_mode;
    double epsilon;
    double avg_factor;
    float scale;
    float bias;

    BatchNorm() = default;

    // Member-wise constructor; every field is copied from its argument.
    BatchNorm(const Param::ParamDim& param_dim_, const Param::FwdMode& fwd_mode_,
              double epsilon_, double avg_factor_, float scale_, float bias_)
            : param_dim(param_dim_),
              fwd_mode(fwd_mode_),
              epsilon(epsilon_),
              avg_factor(avg_factor_),
              scale(scale_),
              bias(bias_) {}

    // Digest over all six fields, fed to XXHash in declaration order.
    size_t hash() const override {
        XXHash hasher{};
        // Hash one field through HashTrait and mix the resulting value in.
        auto feed = [&hasher](auto value) {
            auto part = HashTrait<decltype(value)>::eval(value);
            hasher.update(reinterpret_cast<void*>(&part), sizeof(part));
        };
        feed(param_dim);
        feed(fwd_mode);
        feed(epsilon);
        feed(avg_factor);
        feed(scale);
        feed(bias);
        return hasher.digest();
    }

    // Field-by-field equality. The argument is downcast with an unchecked
    // static_cast, so it must actually refer to a BatchNorm.
    bool is_same_st(const Hashable& rhs_) const override {
        const auto& other = static_cast<const BatchNorm&>(rhs_);
        if (!(other.param_dim == param_dim)) return false;
        if (!(other.fwd_mode == fwd_mode)) return false;
        if (!(other.epsilon == epsilon)) return false;
        if (!(other.avg_factor == avg_factor)) return false;
        if (!(other.scale == scale)) return false;
        return other.bias == bias;
    }
};
} // namespace mgb::imperative
/**
* \file imperative/src/include/megbrain/imperative/ops/collective_comm.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/imperative/op_def.h"
#include "megbrain/opr/param_defs.h"
namespace mgb {
namespace imperative {
/*!
 * \brief Imperative operator definition for a collective communication op.
 *
 * All parameters are public fields. hash() and is_same_st() are only
 * declared here; their definitions live outside this header.
 */
class CollectiveComm : public OpDefImplBase<CollectiveComm> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using Mode = megdnn::param::CollectiveComm::Mode;

    CollectiveComm() = default;

    //! Construct with every field given explicitly.
    CollectiveComm(const std::string& key_, size_t nr_devices_,
                   uint32_t rank_, bool is_root_, bool local_grad_,
                   const std::string& addr_, uint32_t port_,
                   const Mode& mode_,
                   const DType& dtype_, const std::string& backend_,
                   const std::string& comp_node_)
            : key(key_),
              nr_devices(nr_devices_),
              rank(rank_),
              is_root(is_root_),
              local_grad(local_grad_),
              addr(addr_),
              port(port_),
              mode(mode_),
              dtype(dtype_),
              backend(backend_),
              comp_node(comp_node_) {}

    // Scalar and enum fields are value-initialised so that a
    // default-constructed op never carries indeterminate values into
    // as_tuple(). Declaration order is kept because the constructor's
    // member-initialiser list relies on it.
    std::string key;
    size_t nr_devices{};
    uint32_t rank{};
    bool is_root{};
    bool local_grad{};
    std::string addr;
    uint32_t port{};
    Mode mode{};
    DType dtype;
    std::string backend;
    std::string comp_node;

    size_t hash() const override;
    bool is_same_st(const Hashable& another) const override;

    //! Copy of all fields as a tuple, in declaration order.
    auto as_tuple() const {
        return std::tuple(key, nr_devices, rank, is_root,
                          local_grad, addr, port, mode, dtype,
                          backend, comp_node);
    }
};
} // namespace imperative
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file imperative/src/include/megbrain/imperative/ops/elemwise.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/opr/basic_arith.h"
#include "megbrain/imperative/op_def.h"
namespace mgb::imperative {
/*!
 * \brief Imperative operator definition for an element-wise op, identified
 * solely by its opr::Elemwise::Mode.
 */
class Elemwise : public OpDefImplBase<Elemwise> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    using Mode = opr::Elemwise::Mode;
    using ModeTrait = megdnn::Elemwise::ModeTrait;

    // Value-initialised: with the defaulted constructor `Elemwise op;`
    // would otherwise leave `mode` indeterminate while hash() and
    // is_same_st() read it.
    Mode mode{};

    Elemwise() = default;
    Elemwise(const Mode& mode_) : mode(mode_) {}

    //! Combines the hash of `mode` with the address of the dynamic type
    //! info.
    size_t hash() const override {
        return hash_pair_combine(
                mgb::hash(mode),
                reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
    }

    //! Equality on `mode`.
    //! NOTE(review): rhs_ is downcast with static_cast and no type check;
    //! presumably the caller guarantees it is an Elemwise -- confirm.
    bool is_same_st(const Hashable& rhs_) const override {
        auto&& rhs = static_cast<const Elemwise&>(rhs_);
        return rhs.mode == mode;
    }
};
} // namespace mgb::imperative
/**
* \file imperative/src/include/megbrain/imperative/ops/io_remote.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/imperative/op_def.h"
namespace mgb {
namespace imperative {
/*!
 * \brief Imperative operator definition for the sending side of a remote
 * tensor transfer.
 *
 * hash() and is_same_st() are only declared here; their definitions live
 * outside this header.
 */
class RemoteSend : public OpDefImplBase<RemoteSend> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    RemoteSend() = default;

    //! Construct with every field given explicitly.
    RemoteSend(const std::string& key_, const std::string& addr_,
               uint32_t port_, uint32_t rank_to_)
            : key(key_),
              addr(addr_),
              port(port_),
              rank_to(rank_to_) {}

    std::string key;
    std::string addr;
    // Value-initialised so a default-constructed op never exposes
    // indeterminate integers through as_tuple().
    uint32_t port{};
    uint32_t rank_to{};

    size_t hash() const override;
    bool is_same_st(const Hashable& another) const override;

    //! Copy of all fields as a tuple, in declaration order.
    auto as_tuple() const {
        return std::tuple(key, addr, port, rank_to);
    }
};
/*!
 * \brief Imperative operator definition for the receiving side of a remote
 * tensor transfer.
 *
 * hash() and is_same_st() are only declared here; their definitions live
 * outside this header.
 */
class RemoteRecv : public OpDefImplBase<RemoteRecv> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    RemoteRecv() = default;

    //! Construct with every field given explicitly. Note that the
    //! parameter order (shape_, cn_) differs from the member order
    //! (cn, shape).
    RemoteRecv(const std::string& key_, const std::string& addr_,
               uint32_t port_, uint32_t rank_from_, TensorShape shape_,
               CompNode cn_, const DType& dtype_)
            : key(key_),
              addr(addr_),
              port(port_),
              rank_from(rank_from_),
              cn(cn_),
              shape(shape_),
              dtype(dtype_) {}

    std::string key;
    std::string addr;
    // Value-initialised so a default-constructed op never exposes
    // indeterminate integers through as_tuple().
    uint32_t port{};
    uint32_t rank_from{};
    CompNode cn;
    TensorShape shape;
    DType dtype;

    size_t hash() const override;
    bool is_same_st(const Hashable& another) const override;

    //! Copy of the fields as a tuple. The shape is represented by its
    //! to_string() text and is placed last, after dtype.
    auto as_tuple() const {
        return std::tuple(key, addr, port, rank_from, cn, dtype,
                          shape.to_string());
    }
};
} // namespace imperative
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file imperative/src/include/megbrain/imperative/ops/nms.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/imperative/op_def.h"
namespace mgb::imperative {
/*!
 * \brief Imperative operator definition for NMS-keep, parameterised by an
 * IoU threshold and an output limit.
 */
class NMSKeep : public OpDefImplBase<NMSKeep> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    // Both fields are value-initialised: with the defaulted constructor
    // `NMSKeep op;` would otherwise leave them indeterminate while hash()
    // and is_same_st() read them.
    float iou_thresh{};     //!< IoU threshold for overlapping
    uint32_t max_output{};  //!< max number of output boxes per batch

    NMSKeep() = default;
    NMSKeep(float iou_thresh_, uint32_t max_output_)
            : iou_thresh(iou_thresh_), max_output(max_output_) {}

    //! Combines the hashes of both fields with the address of the dynamic
    //! type info.
    size_t hash() const override {
        return hash_pair_combine(
                hash_pair_combine(mgb::hash(iou_thresh), mgb::hash(max_output)),
                reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
    }

    //! Field-wise equality.
    //! NOTE(review): rhs_ is downcast with static_cast and no type check;
    //! presumably the caller guarantees it is an NMSKeep -- confirm.
    bool is_same_st(const Hashable& rhs_) const override {
        auto&& rhs = static_cast<const NMSKeep&>(rhs_);
        return rhs.iou_thresh == iou_thresh
            && rhs.max_output == max_output;
    }
};
} // namespace mgb::imperative
/**
* \file imperative/src/include/megbrain/imperative/ops/tensor_manip.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/imperative/op_def.h"
#include "megbrain/utils/hash.h"
namespace mgb::imperative {
/*!
 * \brief Parameter-less imperative operator definition.
 *
 * The class has no fields, so its identity is just its dynamic type: the
 * hash is the type-info address and two instances compare equal whenever
 * their type infos match.
 */
class GetVarShape : public OpDefImplBase<GetVarShape> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    GetVarShape() = default;

    //! Address of the dynamic type info, reinterpreted as an integer.
    size_t hash() const override {
        const auto type_key = reinterpret_cast<std::uintptr_t>(dyn_typeinfo());
        return type_key;
    }

    //! True when the other object reports the same dynamic type info.
    bool is_same_st(const Hashable& rhs) const override {
        return dyn_typeinfo() == rhs.dyn_typeinfo();
    }
};
/*!
 * \brief Imperative operator definition carrying a list of offsets and a
 * list of shapes.
 */
class ParamPackSplit : public OpDefImplBase<ParamPackSplit> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    ParamPackSplit() = default;

    //! Copies both vectors. They are taken by const reference because they
    //! are only read; this also lets callers pass const objects or
    //! temporaries, which the former non-const reference rejected.
    ParamPackSplit(const std::vector<dt_int32>& offsets_,
                   const std::vector<std::vector<size_t>>& shapes_)
            : offsets(offsets_), shapes(shapes_) {}

    std::vector<dt_int32> offsets;
    std::vector<std::vector<size_t>> shapes;

    //! Digest over every offset, every shape dimension and the element
    //! counts. The counts are fed in as well, so the grouping of values
    //! influences the result, not only their flat sequence.
    size_t hash() const override {
        XXHash builder;
        for (auto&& offset : offsets) {
            builder.update(&offset, sizeof(offset));
        }
        auto&& offset_cnt = offsets.size();
        builder.update(&offset_cnt, sizeof(offset_cnt));
        for (auto&& shape : shapes) {
            for (auto&& dim_len : shape) {
                builder.update(&dim_len, sizeof(dim_len));
            }
            auto&& dim_cnt = shape.size();
            builder.update(&dim_cnt, sizeof(dim_cnt));
        }
        auto&& shape_cnt = shapes.size();
        builder.update(&shape_cnt, sizeof(shape_cnt));
        return builder.digest();
    }

    //! Element-wise equality of both vectors; rhs is downcast with
    //! cast_final_safe.
    bool is_same_st(const Hashable& rhs) const override {
        auto&& pps = rhs.cast_final_safe<ParamPackSplit>();
        return offsets == pps.offsets && shapes == pps.shapes;
    }
};
/*!
 * \brief Imperative operator definition carrying a list of offsets.
 */
class ParamPackConcat : public OpDefImplBase<ParamPackConcat> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

public:
    ParamPackConcat() = default;

    //! Copies the offsets. The vector is taken by const reference because
    //! it is only read; this also lets callers pass const objects or
    //! temporaries, which the former non-const reference rejected.
    ParamPackConcat(const std::vector<dt_int32>& offsets_)
            : offsets(offsets_) {}

    std::vector<dt_int32> offsets;

    //! Digest over every offset followed by the number of offsets.
    size_t hash() const override {
        XXHash builder;
        for (auto&& offset : offsets) {
            builder.update(&offset, sizeof(offset));
        }
        auto&& offset_cnt = offsets.size();
        builder.update(&offset_cnt, sizeof(offset_cnt));
        return builder.digest();
    }

    //! Element-wise equality of the offsets; rhs is downcast with
    //! cast_final_safe.
    bool is_same_st(const Hashable& rhs) const override {
        auto&& ppc = rhs.cast_final_safe<ParamPackConcat>();
        return offsets == ppc.offsets;
    }
};
} // namespace mgb::imperative
...@@ -29,18 +29,18 @@ TEST(TestImperative, BackwardGraphBasic) { ...@@ -29,18 +29,18 @@ TEST(TestImperative, BackwardGraphBasic) {
using Param = opr::Elemwise::Param; using Param = opr::Elemwise::Param;
Param param{Param::Mode::MUL}; Param param{Param::Mode::MUL};
OprAttr attr{"Elemwise", {}, {}}; auto attr = OprAttr::make("Elemwise");
attr.param.write_pod(param); attr->cast_final_safe<OprAttr>().param.write_pod(param);
SmallVector<LogicalTensorDesc> input_descs; SmallVector<LogicalTensorDesc> input_descs;
for (auto&& i : inputs) { for (auto&& i : inputs) {
input_descs.push_back({i->layout(), i->comp_node()}); input_descs.push_back({i->layout(), i->comp_node()});
} }
auto result = OpDef::make_backward_graph(attr, input_descs, {true, true}, {true}); auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, {true});
auto&& save_for_backward = result.save_for_backward; auto&& save_for_backward = result.save_for_backward;
auto&& input_has_grad = result.input_has_grad; auto&& input_has_grad = result.input_has_grad;
auto outputs = OpDef::apply_on_physical_tensor(attr, inputs); auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
inputs.push_back(outputs[0]); inputs.push_back(outputs[0]);
hvs.push_back(*gen({42})); hvs.push_back(*gen({42}));
inputs.push_back(Tensor::make(hvs.back())); inputs.push_back(Tensor::make(hvs.back()));
...@@ -82,16 +82,16 @@ TEST(TestImperative, BackwardGraphIdentity) { ...@@ -82,16 +82,16 @@ TEST(TestImperative, BackwardGraphIdentity) {
SmallVector<TensorPtr> inputs; SmallVector<TensorPtr> inputs;
inputs.push_back(a); inputs.push_back(a);
OprAttr attr{"Identity", {}, {}}; auto attr = OprAttr::make("Identity");
attr.param.write_pod<megdnn::param::Empty>({}); attr->cast_final_safe<OprAttr>().param.write_pod<megdnn::param::Empty>({});
SmallVector<LogicalTensorDesc> input_descs; SmallVector<LogicalTensorDesc> input_descs;
input_descs.push_back({a->layout(), a->comp_node()}); input_descs.push_back({a->layout(), a->comp_node()});
auto result = OpDef::make_backward_graph(attr, input_descs, {true}, {true}); auto result = OpDef::make_backward_graph(*attr, input_descs, {true}, {true});
auto&& save_for_backward = result.save_for_backward; auto&& save_for_backward = result.save_for_backward;
auto&& input_has_grad = result.input_has_grad; auto&& input_has_grad = result.input_has_grad;
auto outputs = OpDef::apply_on_physical_tensor(attr, inputs); auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
inputs.push_back(outputs[0]); inputs.push_back(outputs[0]);
inputs.push_back(dc); inputs.push_back(dc);
mgb_assert(save_for_backward.size() == inputs.size()); mgb_assert(save_for_backward.size() == inputs.size());
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
*/ */
#include "./helper.h" #include "./helper.h"
#include "megbrain/imperative/ops/collective_comm.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/mm_handler.h" #include "megbrain/opr/mm_handler.h"
using namespace mgb; using namespace mgb;
...@@ -32,12 +32,13 @@ TEST(TestImperative, AllReduceBasic) { ...@@ -32,12 +32,13 @@ TEST(TestImperative, AllReduceBasic) {
} }
auto run = [&](std::shared_ptr<HostTensorND> hnd, uint32_t idx) { auto run = [&](std::shared_ptr<HostTensorND> hnd, uint32_t idx) {
imperative::CollectiveComm auto def =
def{"all_reduce", 2, idx, idx==0, false, server_addr, port, imperative::CollectiveComm::make(
megdnn::param::CollectiveComm::Mode::ALL_REDUCE_SUM, megdnn::param::CollectiveComm::Mode::ALL_REDUCE_SUM,
dtype::Float32(), "nccl", ""}; "all_reduce", 2, idx, idx==0, false, server_addr, port,
dtype::Float32(), "nccl", "");
auto inp = Tensor::make(*hnd); auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(def, {inp}); auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
HostTensorND host_v; HostTensorND host_v;
host_v.copy_from(oup[0]->dev_tensor()).sync(); host_v.copy_from(oup[0]->dev_tensor()).sync();
MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6); MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6);
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
*/ */
#include "./helper.h" #include "./helper.h"
#include "megbrain/imperative/ops/cond_take.h" #include "megbrain/imperative/ops/autogen.h"
using namespace mgb; using namespace mgb;
using namespace imperative; using namespace imperative;
......
...@@ -119,7 +119,7 @@ void OprChecker::run(std::vector<InputSpec> inp_keys) { ...@@ -119,7 +119,7 @@ void OprChecker::run(std::vector<InputSpec> inp_keys) {
}, inp_keys[i]); }, inp_keys[i]);
sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node(); sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node();
} }
auto sym_oup = OpDef::apply_on_var_node(*m_op, sym_inp)->usable_output(); auto sym_oup = OpDef::apply_on_var_node(*m_op, sym_inp);
size_t nr_oups = sym_oup.size(); size_t nr_oups = sym_oup.size();
ComputingGraph::OutputSpec oup_spec(nr_oups); ComputingGraph::OutputSpec oup_spec(nr_oups);
SmallVector<HostTensorND> host_sym_oup(nr_oups); SmallVector<HostTensorND> host_sym_oup(nr_oups);
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
*/ */
#include "./helper.h" #include "./helper.h"
#include "megbrain/imperative/ops/io_remote.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/mm_handler.h" #include "megbrain/opr/mm_handler.h"
using namespace mgb; using namespace mgb;
...@@ -33,24 +33,19 @@ TEST(TestImperative, IORemote) { ...@@ -33,24 +33,19 @@ TEST(TestImperative, IORemote) {
} }
auto run_send = [&](std::shared_ptr<HostTensorND> hnd) { auto run_send = [&](std::shared_ptr<HostTensorND> hnd) {
imperative::RemoteSend def{"io_remote_test", server_addr, port, 1}; auto def = imperative::RemoteSend::make(
"io_remote_test", server_addr, port, 1);
auto inp = Tensor::make(*hnd); auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(def, {inp}); auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
}; };
auto run_recv = [&](std::shared_ptr<HostTensorND> hnd) { auto run_recv = [&](std::shared_ptr<HostTensorND> hnd) {
// auto&& shape = std::initializer_list{vector_size}; auto def = imperative::RemoteRecv::make(
imperative::RemoteRecv def{"io_remote_test", "io_remote_test", server_addr, port, 0,
server_addr, CompNode::load("gpu1"), TensorShape{vector_size},
port, dtype::Float32());
0,
{
vector_size,
},
CompNode::load("gpu1"),
dtype::Float32()};
auto inp = Tensor::make(*hnd); auto inp = Tensor::make(*hnd);
auto oup = OpDef::apply_on_physical_tensor(def, {inp}); auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
HostTensorND host_v; HostTensorND host_v;
host_v.copy_from(oup[0]->dev_tensor()).sync(); host_v.copy_from(oup[0]->dev_tensor()).sync();
MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6); MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6);
......
# mgb tablegen executable: build autogen.cpp into the code generator and
# link it against the LLVM/MLIR tablegen libraries.
set(TABLE_TARGET mgb-mlir-autogen)
add_executable(${TABLE_TARGET} autogen.cpp)
target_include_directories(${TABLE_TARGET} PRIVATE ${MLIR_LLVM_INCLUDE_DIR})
target_link_libraries(${TABLE_TARGET} PRIVATE LLVMTableGen MLIRTableGen LLVMSupport)
# NOTE(review): presumably the tablegen(MGB ...) calls below locate the tool
# through this <project>_TABLEGEN_EXE variable -- confirm against the
# tablegen() macro definition in use.
set(MGB_TABLEGEN_EXE ${TABLE_TARGET})
# generate megbrain opdef c header and python bindings from ops.td; one
# output file per generator action (the flags match the options registered
# in autogen.cpp).
set(LLVM_TARGET_DEFINITIONS ${MGE_IR_DIR}/ops.td)
tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header")
tablegen(MGB opdef.cpp.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-body")
tablegen(MGB opdef.py.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-binding")
# Umbrella target, built by default (ALL), that depends on the three
# generated files and on param_defs_tblgen.
add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl param_defs_tblgen)
# Export this directory's binary dir to the parent scope so other targets
# can add it to their include path.
set(MGB_OPDEF_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR} PARENT_SCOPE)
#include <iostream>
#include <unordered_map>
#include <functional>
#include "./helper.h"
using llvm::raw_ostream;
using llvm::RecordKeeper;
// Generator action selected on the command line. `None` is the value of a
// value-initialised ActionType, i.e. "no action chosen".
enum ActionType {
    None = 0,       // no generator selected
    CppHeader = 1,  // emit the operator class declarations
    CppBody = 2,    // emit the operator class implementations
    Pybind = 3      // emit the pybind11 bindings
};
// Command-line option that selects which generator to run. Each flag
// (--gen-cpp-header, --gen-cpp-body, --gen-python-binding) maps to one
// ActionType value; main() dispatches on the parsed value.
// NOLINTNEXTLINE
llvm::cl::opt<ActionType> action(
llvm::cl::desc("Action to perform:"),
llvm::cl::values(clEnumValN(CppHeader, "gen-cpp-header",
"Generate operator cpp header"),
clEnumValN(CppBody, "gen-cpp-body",
"Generate operator cpp body"),
clEnumValN(Pybind, "gen-python-binding",
"Generate pybind11 python bindings")));
// Short local names for the tablegen wrapper and mixin classes used by the
// generators below. The generators probe records with llvm::dyn_cast /
// llvm::cast against these types.
using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase;
using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin;
using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin;
using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin;
using MgbOp = mlir::tblgen::MgbOpBase;
using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin;
// Returns the C++ type name to emit for an attribute.
//
// Enum attributes yield their enum name; every other attribute yields the
// wrapper's underlying type. Basic attribute kinds (UI32Attr, UI64Attr,
// I32Attr, F32Attr, F64Attr, StrAttr, BoolAttr) need no special casing here
// because attr wrappers are already registered for them.
llvm::StringRef attr_to_ctype(const mlir::tblgen::Attribute& attr_) {
    auto&& wrapper = llvm::cast<MgbAttrWrapper>(attr_);
    auto enumAttr = llvm::dyn_cast<MgbEnumAttr>(&wrapper);
    if (!enumAttr) {
        return wrapper.getUnderlyingType();
    }
    return enumAttr->getEnumName();
}
// Writes the C++ class declaration of one operator to `os`.
//
// Emitted in this order: class header, enum aliases, data members (with
// default values where the attribute has one), defaulted constructor,
// attribute-wise constructor, packed-param constructor, packed-param
// accessors, the op's extra declarations, closing brace.
static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
    auto&& attributes = op.getMgbAttributes();

    os << formatv(
            "class {0} : public OpDefImplBase<{0}> {{\n"
            " MGB_DYN_TYPE_OBJ_FINAL_DECL;\n\n"
            "public:\n",
            op.getCppClassName());

    // Enum attributes get a class-local alias for their underlying type.
    for (auto&& entry : attributes) {
        auto enumAttr = llvm::dyn_cast<MgbEnumAttr>(&entry.attr);
        if (!enumAttr) {
            continue;
        }
        os << formatv(
                " using {0} = {1};\n",
                enumAttr->getEnumName(), enumAttr->getUnderlyingType());
    }

    // One data member per attribute, with " = <default>" appended when the
    // attribute declares a default value.
    for (auto&& entry : attributes) {
        std::string initializer = entry.attr.getDefaultValue().str();
        if (!initializer.empty()) {
            initializer = " = " + initializer;
        }
        os << formatv(
                " {0} {1}{2};\n",
                attr_to_ctype(entry.attr), entry.name, initializer);
    }

    // Writes a single constructor line: Name(<params>)<init-list><body>.
    auto emitCtor = [&](auto&& params, auto&& memInits, auto&& body) {
        os << formatv(
                " {0}({1}){2}{3}\n",
                op.getCppClassName(), params, memInits, body);
    };

    emitCtor("", "", " = default;");

    // Constructor taking every attribute as its own parameter.
    if (!attributes.empty()) {
        std::vector<std::string> ctorParams, memberInits;
        for (auto&& entry : attributes) {
            ctorParams.push_back(formatv(
                    "{0} {1}_", attr_to_ctype(entry.attr), entry.name));
            memberInits.push_back(formatv("{0}({0}_)", entry.name));
        }
        emitCtor(llvm::join(ctorParams, ", "),
                 ": " + llvm::join(memberInits, ", "),
                 " {}");
    }

    auto packedParams = op.getPackedParams();
    if (!packedParams.empty()) {
        // Constructor taking the packed param structs followed by the extra
        // arguments. A packed param without fields becomes an unnamed
        // parameter.
        std::vector<std::string> ctorParams, memberInits;
        for (auto&& packed : packedParams) {
            auto&& fields = packed.getFields();
            auto&& typeName = packed.getFullName();
            // Numbered by the position the parameter will take in the list.
            std::string argName =
                    formatv("packed_param_{0}", ctorParams.size());
            if (fields.empty()) {
                ctorParams.push_back(typeName.str());
            } else {
                ctorParams.push_back(formatv("{0} {1}", typeName, argName));
            }
            for (auto&& field : fields) {
                memberInits.push_back(formatv(
                        "{0}({1}.{0})", field.name, argName));
            }
        }
        for (auto&& extra : op.getExtraArguments()) {
            ctorParams.push_back(formatv(
                    "{0} {1}_", attr_to_ctype(extra.attr), extra.name));
            memberInits.push_back(formatv("{0}({0}_)", extra.name));
        }
        emitCtor(llvm::join(ctorParams, ", "),
                 memberInits.empty() ? "" : ": " + llvm::join(memberInits, ", "),
                 " {}");

        // Accessors that rebuild a packed param struct from the members;
        // only emitted for packed params that name an accessor.
        for (auto&& packed : packedParams) {
            auto accessor = packed.getAccessor();
            if (accessor.empty()) {
                continue;
            }
            os << formatv(
                    " {0} {1}() const {{\n",
                    packed.getFullName(), accessor);
            std::vector<llvm::StringRef> fieldNames;
            for (auto&& field : packed.getFields()) {
                fieldNames.push_back(field.name);
            }
            os << formatv(
                    " return {{{0}};\n",
                    llvm::join(fieldNames, ", "));
            os << " }\n";
        }
    }

    // Extra declarations attached to the op are copied through verbatim.
    if (auto decl = op.getExtraOpdefDecl()) {
        os << decl.getValue();
    }

    os << formatv(
            "};\n\n");
}
// Writes the implementation part of one operator to `os`.
//
// Always emits MGB_DYN_TYPE_OBJ_FINAL_IMPL(<Class>). For hashable ops it
// additionally emits, inside an anonymous namespace, the free functions
// <Class>_hash_impl and <Class>_is_same_st_impl, whose bodies come from the
// op's hash / compare templates, followed by an OP_TRAIT_REG statement that
// registers them.
//
// The generated wrappers bind the downcast operators with `auto&&`.
// Previously they used plain `auto`, which copied the whole operator
// object(s) on every hash or comparison call.
static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
    auto&& className = op.getCppClassName();
    os << formatv(
            "MGB_DYN_TYPE_OBJ_FINAL_IMPL({0});\n\n", className);

    // Name of the generated free function implementing `meth`. Returned as
    // a std::string so the result does not hold references to the
    // arguments.
    auto formatMethImpl = [&](auto&& meth) -> std::string {
        return formatv("{0}_{1}_impl", className, meth);
    };

    std::vector<std::string> methods;
    if (auto hashable = llvm::dyn_cast<MgbHashableOp>(&op)) {
        os << "namespace {\n";

        // generate hash()
        mlir::tblgen::FmtContext ctx;
        os << formatv(
                "size_t {0}(const OpDef& def_) {{\n",
                formatMethImpl("hash"));
        os << formatv(
                " auto&& op_ = def_.cast_final_safe<{0}>();\n"
                " static_cast<void>(op_);\n",
                className);
        // The hash template refers to the operator through the context's
        // "self" placeholder, which is bound to the local `op_`.
        ctx.withSelf("op_");
        os << mlir::tblgen::tgfmt(hashable->getHashFunctionTemplate(), &ctx);
        os << "}\n";

        // generate is_same_st()
        os << formatv(
                "bool {0}(const OpDef& lhs_, const OpDef& rhs_) {{\n",
                formatMethImpl("is_same_st"));
        os << formatv(
                " auto&& a_ = lhs_.cast_final_safe<{0}>();\n"
                " auto&& b_ = rhs_.cast_final_safe<{0}>();\n"
                " static_cast<void>(a_);\n"
                " static_cast<void>(b_);\n",
                className);
        // The compare template receives the two locals as positional
        // arguments.
        os << mlir::tblgen::tgfmt(hashable->getCmpFunctionTemplate(), &ctx, "a_", "b_");
        os << "}\n";

        os << "} // anonymous namespace\n";
        methods.push_back("hash");
        methods.push_back("is_same_st");
    }

    // Register the generated functions, one chained call per method.
    if (!methods.empty()) {
        os << formatv(
                "OP_TRAIT_REG({0}, {0})", op.getCppClassName());
        for (auto&& i : methods) {
            os << formatv(
                    "\n .{0}({1})", i, formatMethImpl(i));
        }
        os << ";\n\n";
    }
}
// State shared across all operators while the pybind11 bindings are
// generated.
struct PybindContext {
    // Maps the record ID of an enum that has already been bound to the
    // expression text that refers to that binding, so later operators can
    // alias it instead of binding the enum again.
    std::unordered_map<unsigned, std::string> enumAlias;
};
// Writes the pybind11 binding code of one operator to `os`.
//
// Emitted in this order: the py::class_ instance, one enum binding (or an
// alias to an earlier binding of the same enum) per enum attribute, then
// the chained .def(...) calls for constructors and read/write fields.
// `ctx` remembers which enums were already bound by previous operators.
static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, PybindContext& ctx) {
    auto class_name = op.getCppClassName();
    os << formatv(
            "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
            class_name);

    for (auto&& entry : op.getMgbAttributes()) {
        auto enumAttr = llvm::dyn_cast<MgbEnumAttr>(&entry.attr);
        if (!enumAttr) {
            continue;
        }

        // Alias attributes are keyed by the record of the enum they alias,
        // so an alias and its base share one binding.
        unsigned int enumID;
        if (auto alias = llvm::dyn_cast<MgbAliasAttr>(enumAttr)) {
            auto&& aliasBase = alias->getAliasBase();
            enumID = llvm::cast<MgbEnumAttr>(aliasBase)
                             .getBaseRecord()->getID();
        } else {
            enumID = enumAttr->getBaseRecord()->getID();
        }

        auto&& knownEnums = ctx.enumAlias;
        auto&& known = knownEnums.find(enumID);
        if (known != knownEnums.end()) {
            // Already bound elsewhere: expose it under this class by
            // assigning the stored expression.
            os << formatv(
                    "{0}Inst.attr(\"{1}\") = {2};\n\n",
                    class_name, enumAttr->getEnumName(), known->second);
            continue;
        }

        // First occurrence: bind the enum, its members, and a constructor
        // that parses a (normalized) string.
        os << formatv(
                "py::enum_<{0}::{1}>({0}Inst, \"{1}\")",
                class_name, enumAttr->getEnumName());
        std::vector<std::string> parseBranches;
        for (auto&& member : enumAttr->getEnumMembers()) {
            os << formatv(
                    "\n .value(\"{2}\", {0}::{1}::{2})",
                    class_name, enumAttr->getEnumName(), member);
            parseBranches.push_back(formatv(
                    "if (str == \"{2}\") return {0}::{1}::{2};",
                    class_name, enumAttr->getEnumName(), member));
        }
        os << formatv(
                "\n .def(py::init([](const std::string& in) {"
                "\n auto&& str = normalize_enum(in);"
                "\n {0}"
                "\n throw py::cast_error(\"invalid enum value \" + in);"
                "\n }));\n",
                llvm::join(parseBranches, "\n "));
        os << formatv(
                "py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
                class_name, enumAttr->getEnumName());
        knownEnums.emplace(enumID, formatv(
                "{0}Inst.attr(\"{1}\")", class_name, enumAttr->getEnumName()));
    }

    // generate op class binding
    os << formatv("{0}Inst", class_name);

    // A zero-argument py::init<>() is emitted when the op has no attributes
    // at all, or when at least one attribute lacks a default value.
    bool needsDefaultCtor = op.getMgbAttributes().empty();
    if (!needsDefaultCtor) {
        os << "\n .def(py::init<";
        std::vector<llvm::StringRef> ctorArgTypes;
        for (auto&& entry : op.getMgbAttributes()) {
            ctorArgTypes.push_back(entry.attr.getReturnType());
        }
        os << llvm::join(ctorArgTypes, ", ");
        os << ">()";
        for (auto&& entry : op.getMgbAttributes()) {
            os << formatv(", py::arg(\"{0}\")", entry.name);
            auto defaultValue = entry.attr.getDefaultValue();
            if (defaultValue.empty()) {
                needsDefaultCtor = true;
            } else {
                os << formatv(" = {0}", defaultValue);
            }
        }
        os << ")";
    }
    if (needsDefaultCtor) {
        os << "\n .def(py::init<>())";
    }

    // Every attribute is exposed as a read/write field.
    for (auto&& entry : op.getMgbAttributes()) {
        os << formatv(
                "\n .def_readwrite(\"{0}\", &{1}::{0})",
                entry.name, class_name);
    }
    os << ";\n\n";
}
// Invokes `callback(os, op)` for every record in `keeper` that derives from
// the tablegen class "Op" and whose dialect name is "mgb". A progress line
// is written to stderr before each callback.
static void for_each_operator(raw_ostream &os, RecordKeeper &keeper,
        std::function<void(raw_ostream&, MgbOp&)> callback) {
    auto op_base_class = keeper.getClass("Op");
    ASSERT(op_base_class, "could not find base class Op");

    for (auto&& def : keeper.getDefs()) {
        auto&& record = def.second;
        if (!record->isSubClassOf(op_base_class)) {
            continue;
        }
        auto op = mlir::tblgen::Operator(record.get());
        if (op.getDialectName().str() != "mgb") {
            continue;
        }
        std::cerr << "\033[34;15m" << "Generating " << record->getName().str() << "\033[0m" << std::endl;
        callback(os, llvm::cast<MgbOp>(op));
    }
}
// TableGenMain callback: emits the class declaration of every mgb operator.
// Always returns false.
static bool gen_op_def_c_header(raw_ostream &os, RecordKeeper &keeper) {
    for_each_operator(os, keeper, [](raw_ostream& out, MgbOp& op) {
        gen_op_def_c_header_single(out, op);
    });
    return false;
}
// TableGenMain callback: emits the implementation part of every mgb
// operator. Always returns false.
static bool gen_op_def_c_body(raw_ostream &os, RecordKeeper &keeper) {
    for_each_operator(os, keeper, [](raw_ostream& out, MgbOp& op) {
        gen_op_def_c_body_single(out, op);
    });
    return false;
}
// TableGenMain callback: emits the pybind11 bindings of every mgb operator.
// One PybindContext is shared by all operators so an enum bound for one
// operator can be aliased by later ones. Always returns false.
static bool gen_op_def_pybind11(raw_ostream &os, RecordKeeper &keeper) {
    PybindContext ctx;
    for_each_operator(os, keeper, [&ctx](raw_ostream& out, MgbOp& op) {
        gen_op_def_pybind11_single(out, op, ctx);
    });
    return false;
}
int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv);
if (action == ActionType::CppHeader) {
return TableGenMain(argv[0], &gen_op_def_c_header);
}
if (action == ActionType::CppBody) {
return TableGenMain(argv[0], &gen_op_def_c_body);
}
if (action == ActionType::Pybind) {
return TableGenMain(argv[0], &gen_op_def_pybind11);
}
return -1;
}
\ No newline at end of file
此差异已折叠。
...@@ -11,7 +11,7 @@ endif() ...@@ -11,7 +11,7 @@ endif()
# TODO: turn python binding into a static/object library # TODO: turn python binding into a static/object library
add_executable(imperative_test ${SOURCES} ${SRCS}) add_executable(imperative_test ${SOURCES} ${SRCS})
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include) target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR})
# Python binding # Python binding
target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR})
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册