Commit 013bb14f authored by Megvii Engine Team

build(tablegen): pregen opdef tablegen targets

GitOrigin-RevId: ab7564df14099007d188af4ba7cff28a0cda526c
Parent f12b75c0
@@ -66,8 +66,8 @@ target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json)
target_include_directories(
${MODULE_NAME}
PUBLIC src/include
PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}
${CPP_REDIS_INCLUDES})
PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${CPP_REDIS_INCLUDES})
target_link_libraries(${MODULE_NAME} PRIVATE mgb_opdef_inc)
target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME})
target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter)
if(CXX_SUPPORT_WCLASS_MEMACCESS)
......
# mgb tablegen executable
set(TABLE_TARGET mgb-mlir-autogen)
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.h
     ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
set(MGB_OPDEF_OUT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/generated)
set(MGB_OPDEF_OPS_SRC ${CMAKE_SOURCE_DIR}/src/core/include/megbrain/ir/ops.td)
set(MGB_OPDEF_PARAMS_SRC ${CMAKE_SOURCE_DIR}/dnn/scripts/opr_param_defs.py)
# we set CMAKE_CONFIGURE_DEPENDS so that whenever the source files or hash.txt are
# modified, cmake configure is re-triggered to update ${MD5_MISMATCH}
execute_process(
COMMAND ${CMAKE_COMMAND} -P ${CMAKE_CURRENT_SOURCE_DIR}/checkhash.cmake
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
ERROR_QUIET
RESULT_VARIABLE MD5_MISMATCH)
if(${MD5_MISMATCH})
# mgb tablegen executable
set(TABLE_TARGET mgb-mlir-autogen)
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR}/*.h
${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
add_executable(${TABLE_TARGET} ${SRCS})
target_include_directories(${TABLE_TARGET} PRIVATE ${MLIR_LLVM_INCLUDE_DIR})
target_link_libraries(${TABLE_TARGET} PRIVATE LLVMTableGen MLIRTableGen LLVMSupport)
set(MGB_TABLEGEN_EXE ${TABLE_TARGET})
add_executable(${TABLE_TARGET} ${SRCS})
target_include_directories(${TABLE_TARGET} PRIVATE ${MLIR_LLVM_INCLUDE_DIR})
target_link_libraries(${TABLE_TARGET} PRIVATE LLVMTableGen MLIRTableGen LLVMSupport)
set(MGB_TABLEGEN_EXE ${TABLE_TARGET})
# generate megbrain opdef C header and python bindings; basically the same as
# third_party/llvm-project/llvm/cmake/modules/TableGen.cmake, but with a changed
# output folder and an extra dependency
set(LLVM_SOURCE_DIR ${CMAKE_SOURCE_DIR}/third_party/llvm-project/llvm)
set(LLVM_BINARY_DIR ${CMAKE_BINARY_DIR}/third_party/llvm-project/llvm)
set(MLIR_SOURCE_DIR ${CMAKE_SOURCE_DIR}/third_party/llvm-project/mlir)
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}/third_party/llvm-project/mlir)
set(MGB_TABLEGEN_INCLUDES
-I${LLVM_SOURCE_DIR}/include
-I${LLVM_BINARY_DIR}/include
-I${MLIR_SOURCE_DIR}/include
-I${MLIR_BINARY_DIR}/include
-I${CMAKE_SOURCE_DIR}/src/core/include/megbrain/ir
-I${CMAKE_BINARY_DIR}/src/core/include/megbrain/ir)
set(MGB_TABLEGEN_FLAGS --write-if-changed)
set(MGB_TABLEGEN_TARGETS)
# generate megbrain opdef c header and python bindings
set(LLVM_TARGET_DEFINITIONS ${MGE_IR_DIR}/ops.td)
tablegen(MGB opdef.h.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-header")
tablegen(MGB opdef.cpp.inl ${MGE_IR_INCLUDE_DIRS} "--gen-cpp-body")
tablegen(MGB opdef.py.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-binding")
tablegen(MGB opdef.cpy.inl ${MGE_IR_INCLUDE_DIRS} "--gen-python-c-extension")
tablegen(MGB enum_macro.h ${MGE_IR_INCLUDE_DIRS} "--gen-enum-list-macro")
add_custom_target(mgb_opdef ALL DEPENDS opdef.h.inl opdef.cpp.inl opdef.py.inl
opdef.cpy.inl enum_macro.h param_defs_tblgen)
set(MGB_OPDEF_OUT_DIR
    ${CMAKE_CURRENT_BINARY_DIR})
function(tablegen_opdef target output)
add_custom_target(
mgb_opdef_${target}
COMMAND
${MGB_TABLEGEN_EXE} ${MGB_TABLEGEN_INCLUDES} --gen-${target}
${MGB_OPDEF_OPS_SRC} ${MGB_TABLEGEN_FLAGS} -o ${MGB_OPDEF_OUT_DIR}/${output}
DEPENDS param_defs_tblgen)
set(MGB_TABLEGEN_TARGETS
${MGB_TABLEGEN_TARGETS} mgb_opdef_${target}
PARENT_SCOPE)
endfunction()
tablegen_opdef(cpp-header opdef.h.inl)
tablegen_opdef(cpp-body opdef.cpp.inl)
tablegen_opdef(python-binding opdef.py.inl)
tablegen_opdef(python-c-extension opdef.cpy.inl)
tablegen_opdef(enum-list-macro enum_macro.h)
add_custom_target(
mgb_opdef_genhash
${CMAKE_COMMAND} -P genhash.cmake
DEPENDS ${MGB_TABLEGEN_TARGETS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(mgb_opdef DEPENDS ${MGB_TABLEGEN_TARGETS} mgb_opdef_genhash)
else()
# add extra dependencies for auto reconfiguration
set_property(
DIRECTORY
APPEND
PROPERTY CMAKE_CONFIGURE_DEPENDS
${MGB_OPDEF_OPS_SRC}
${MGB_OPDEF_PARAMS_SRC}
generated/opdef.h.inl
generated/opdef.cpp.inl
generated/opdef.py.inl
generated/opdef.cpy.inl
generated/enum_macro.h
generated/hash.txt)
# additional check for safety
add_custom_target(
mgb_opdef_checkhash
${CMAKE_COMMAND} -P ${CMAKE_CURRENT_SOURCE_DIR}/checkhash.cmake
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(mgb_opdef DEPENDS mgb_opdef_checkhash)
endif()
add_library(mgb_opdef_inc INTERFACE)
target_include_directories(mgb_opdef_inc INTERFACE ${MGB_OPDEF_OUT_DIR})
add_dependencies(mgb_opdef_inc mgb_opdef)
@@ -3,8 +3,9 @@
#include "./targets/pybind11.h"
#include "./targets/python_c_extension.h"
using llvm::raw_ostream;
using llvm::RecordKeeper;
namespace {
using namespace mlir::tblgen;
enum ActionType { None, CppHeader, CppBody, Pybind, CPython, EnumListMacro };
@@ -24,25 +25,35 @@ llvm::cl::opt<ActionType> action(
EnumListMacro, "gen-enum-list-macro",
"Generate enum param list macro")));
using namespace mlir::tblgen;
template <llvm::TableGenMainFn* MainFn>
llvm::TableGenMainFn* WrapMain() {
return [](llvm::raw_ostream& os, llvm::RecordKeeper& records) -> bool {
os << "// clang-format off\n";
auto ret = MainFn(os, records);
os << "// clang-format on\n";
return ret;
};
}
} // namespace
int main(int argc, char** argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv);
if (action == ActionType::CppHeader) {
return TableGenMain(argv[0], &gen_op_def_c_header);
return TableGenMain(argv[0], WrapMain<&gen_op_def_c_header>());
}
if (action == ActionType::CppBody) {
return TableGenMain(argv[0], &gen_op_def_c_body);
return TableGenMain(argv[0], WrapMain<&gen_op_def_c_body>());
}
if (action == ActionType::Pybind) {
return TableGenMain(argv[0], &gen_op_def_pybind11);
return TableGenMain(argv[0], WrapMain<&gen_op_def_pybind11>());
}
if (action == ActionType::CPython) {
return TableGenMain(argv[0], &gen_op_def_python_c_extension);
return TableGenMain(argv[0], WrapMain<&gen_op_def_python_c_extension>());
}
if (action == ActionType::EnumListMacro) {
return TableGenMain(argv[0], &gen_enum_param_list_macro);
return TableGenMain(argv[0], WrapMain<&gen_enum_param_list_macro>());
}
return -1;
}
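
Note: the WrapMain change above brackets every generator's output with clang-format guards, so the pregenerated files checked in under generated/ keep their exact formatting. The same idea in isolation, as a minimal self-contained sketch (GenFn and emit_guarded are illustrative names, not part of the tool):

#include <cstdio>

using GenFn = bool (*)(std::FILE*);  // stand-in for llvm::TableGenMainFn

template <GenFn Fn>
bool emit_guarded(std::FILE* os) {
    std::fputs("// clang-format off\n", os);  // guard before the generated body
    bool failed = Fn(os);                     // TableGen convention: true on error
    std::fputs("// clang-format on\n", os);   // guard after the generated body
    return failed;
}
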
set(SOURCES
../../dnn/scripts/opr_param_defs.py
../../src/core/include/megbrain/ir/ops.td
generated/opdef.h.inl
generated/opdef.cpp.inl
generated/opdef.py.inl
generated/opdef.cpy.inl
generated/enum_macro.h)
execute_process(COMMAND ${CMAKE_COMMAND} -E md5sum ${SOURCES}
OUTPUT_VARIABLE GENERATED_HASH_CONTENT)
file(READ generated/hash.txt HASH_CONTENT)
if(NOT "${GENERATED_HASH_CONTENT}" STREQUAL "${HASH_CONTENT}")
message(FATAL_ERROR "File ops.td was changed, please rerun cmake configure")
endif()
// clang-format off
#define FOR_EACH_ENUM_PARAM(cb) \
cb(::megdnn::param::PoolingV0::Mode); \
cb(::megdnn::param::Convolution::Format); \
cb(::megdnn::param::Argsort::Order); \
cb(::megdnn::param::ConvBiasV0::NonlineMode); \
cb(::megdnn::param::ConvolutionV0::Mode); \
cb(::megdnn::param::ConvolutionV0::Sparse); \
cb(::megdnn::param::ConvolutionV1::ComputeMode); \
cb(::megdnn::param::BN::ParamDim); \
cb(::megdnn::param::BN::FwdMode); \
cb(::megdnn::param::MatrixMulV1::ComputeMode); \
cb(::megdnn::param::MatrixMul::Format); \
cb(::megdnn::param::CollectiveComm::Mode); \
cb(::megdnn::param::Convolution3D::Mode); \
cb(::megdnn::param::Convolution3D::Sparse); \
cb(::megdnn::param::Convolution3D::DataType); \
cb(::megdnn::param::Convolution3D::Format); \
cb(::megdnn::param::ConvolutionV0::Format); \
cb(::megdnn::param::CvtColor::Mode); \
cb(::megdnn::param::Elemwise::Mode); \
cb(::megdnn::param::ElemwiseMultiType::Mode); \
cb(::megdnn::param::Padding::PaddingMode); \
cb(::megdnn::param::RNNCell::NonlineMode); \
cb(::megdnn::param::ROIAlignV0::Mode); \
cb(::megdnn::param::ROIPooling::Mode); \
cb(::megdnn::param::Reduce::Mode); \
cb(::megdnn::param::Reduce::DataType); \
cb(::megdnn::param::WarpPerspectiveV1::InterpolationMode); \
cb(::megdnn::param::WarpPerspectiveV1::BorderMode); \
cb(::megdnn::param::TopK::Mode);
#define FOR_EACH_BIT_COMBINED_ENUM_PARAM(cb) \
cb(::megdnn::param::ExecutionPolicy::Strategy);
// clang-format on
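
The generated enum_macro.h above is an X-macro list: a consumer defines cb(T) and the macro stamps it out once per enum param type. A hedged consumer sketch (print_enum_param and the include paths are assumptions, not part of this commit):

#include <cstdio>
#include <typeinfo>
#include "megdnn/opr_param_defs.h"  // assumed home of the ::megdnn::param types
#include "enum_macro.h"             // the generated macro shown above

template <typename Enum>
void print_enum_param() { std::printf("%s\n", typeid(Enum).name()); }

void dump_enum_param_names() {
#define cb(Enum) print_enum_param<Enum>();
    FOR_EACH_ENUM_PARAM(cb)
    FOR_EACH_BIT_COMBINED_ENUM_PARAM(cb)
#undef cb
}
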
0df57b38e71a4d1882ed6c24f3a26b57 ../../dnn/scripts/opr_param_defs.py
759bfbf27fd3f0dd6b6edf06377e1d6b ../../src/core/include/megbrain/ir/ops.td
c613316001b5f0294ede198f5563f041 generated/opdef.h.inl
a1f7f13c909f9d4c173277f4ed28fb61 generated/opdef.cpp.inl
cf48f9ca352fabaeb6c846c11c6b1662 generated/opdef.py.inl
12365b938f564e5b3639d309f7c83414 generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
Source diff not shown for two generated files because they are too large to display; you can view the blobs instead.
// clang-format off
class AdaptivePooling : public OpDefImplBase<AdaptivePooling> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::AdaptivePooling::Mode;
using Format = ::megdnn::param::AdaptivePooling::Format;
Mode mode = ::megdnn::param::AdaptivePooling::Mode::MAX;
Format format = ::megdnn::param::AdaptivePooling::Format::NCHW;
std::vector<int32_t> shape;
AdaptivePooling() = default;
AdaptivePooling(Mode mode_, Format format_, std::vector<int32_t> shape_, std::string scope_ = {}): mode(mode_), format(format_), shape(shape_) { set_scope(scope_); }
AdaptivePooling(::megdnn::param::AdaptivePooling packed_param_0, std::vector<int32_t> shape_): mode(packed_param_0.mode), format(packed_param_0.format), shape(shape_) {}
::megdnn::param::AdaptivePooling param() const {
return {mode, format};
}
};
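
Every class in this generated header follows the pattern AdaptivePooling shows: type aliases and defaulted fields, a loose-field constructor with an optional scope, a packed-param constructor, and a param() accessor that rebuilds the megdnn param struct. A hedged usage sketch (make_pooling is an illustrative name; the mode and shape values are arbitrary):

AdaptivePooling make_pooling() {
    AdaptivePooling op(
            AdaptivePooling::Mode::AVERAGE,  // override the MAX default
            AdaptivePooling::Format::NCHW,
            {7, 7},                          // target output shape
            "example_scope");                // optional scope
    ::megdnn::param::AdaptivePooling packed = op.param();  // {mode, format}
    return AdaptivePooling(packed, {7, 7});  // round-trip via the packed ctor
}
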
class AddAxis : public OpDefImplBase<AddAxis> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<int32_t> axis;
AddAxis() = default;
AddAxis(std::vector<int32_t> axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
};
class Argmax : public OpDefImplBase<Argmax> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t axis = 0;
Argmax() = default;
Argmax(int32_t axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
Argmax(::megdnn::param::Axis packed_param_0): axis(packed_param_0.axis) {}
::megdnn::param::Axis param() const {
return {axis};
}
};
class Argmin : public OpDefImplBase<Argmin> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t axis = 0;
Argmin() = default;
Argmin(int32_t axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
Argmin(::megdnn::param::Axis packed_param_0): axis(packed_param_0.axis) {}
::megdnn::param::Axis param() const {
return {axis};
}
};
class Argsort : public OpDefImplBase<Argsort> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Order = ::megdnn::param::Argsort::Order;
Order order = ::megdnn::param::Argsort::Order::ASCENDING;
Argsort() = default;
Argsort(Order order_, std::string scope_ = {}): order(order_) { set_scope(scope_); }
Argsort(::megdnn::param::Argsort packed_param_0): order(packed_param_0.order) {}
::megdnn::param::Argsort param() const {
return {order};
}
};
class AssertEqual : public OpDefImplBase<AssertEqual> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
float maxerr = 0.0001;
bool verbose = false;
AssertEqual() = default;
AssertEqual(float maxerr_, bool verbose_, std::string scope_ = {}): maxerr(maxerr_), verbose(verbose_) { set_scope(scope_); }
AssertEqual(::megdnn::param::AssertEqual packed_param_0): maxerr(packed_param_0.maxerr), verbose(packed_param_0.verbose) {}
::megdnn::param::AssertEqual param() const {
return {maxerr, verbose};
}
};
class AtlasRuntime : public OpDefImplBase<AtlasRuntime> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::string buf;
size_t buf_size;
AtlasRuntime() = default;
AtlasRuntime(std::string buf_, size_t buf_size_, std::string scope_ = {}): buf(buf_), buf_size(buf_size_) { set_scope(scope_); }
};
class Barrier : public OpDefImplBase<Barrier> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
::mgb::CompNode comp_node;
uint32_t nr_outputs;
Barrier() = default;
Barrier(::mgb::CompNode comp_node_, uint32_t nr_outputs_, std::string scope_ = {}): comp_node(comp_node_), nr_outputs(nr_outputs_) { set_scope(scope_); }
};
class BatchConvBias : public OpDefImplBase<BatchConvBias> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using NonlineMode = ::megdnn::param::BatchConvBias::NonlineMode;
using Mode = ::megdnn::param::BatchConvBias::Mode;
using Sparse = ::megdnn::param::BatchConvBias::Sparse;
using Format = ::megdnn::param::BatchConvBias::Format;
using ComputeMode = ::megdnn::param::BatchConvBias::ComputeMode;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
NonlineMode nonlineMode = ::megdnn::param::BatchConvBias::NonlineMode::IDENTITY;
Mode mode = ::megdnn::param::BatchConvBias::Mode::CROSS_CORRELATION;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
Sparse sparse = ::megdnn::param::BatchConvBias::Sparse::DENSE;
Format format = ::megdnn::param::BatchConvBias::Format::NCHW;
ComputeMode compute_mode = ::megdnn::param::BatchConvBias::ComputeMode::DEFAULT;
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;
::megdnn::DType dtype;
BatchConvBias() = default;
BatchConvBias(NonlineMode nonlineMode_, Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, ::megdnn::DType dtype_, std::string scope_ = {}): nonlineMode(nonlineMode_), mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_), dtype(dtype_) {
set_scope(scope_);
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
BatchConvBias(::megdnn::param::BatchConvBias packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, ::megdnn::DType dtype_): nonlineMode(packed_param_0.nonlineMode), mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dtype(dtype_) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
::megdnn::param::BatchConvBias param() const {
return {nonlineMode, mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
}
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
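
Ops that carry an ExecutionPolicy (BatchConvBias, ConvBias, Convolution, ...) additionally expose policy() and assert that the strategy value stays within the valid range. A minimal sketch of the packed-param path, assuming the megdnn param structs are default-constructible with their usual field defaults:

void example_batch_conv_bias() {
    ::megdnn::param::BatchConvBias conv_param{};     // defaults from opr_param_defs
    ::megdnn::param::ExecutionPolicy exec_policy{};  // default strategy / workspace limit
    BatchConvBias op(conv_param, exec_policy,
                     ::megdnn::DType::from_enum(::megdnn::DTypeEnum::Float32));
    auto policy = op.policy();  // round-trips {strategy, workspace_limit}
    (void)policy;
}
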
class BatchNorm : public OpDefImplBase<BatchNorm> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using ParamDim = ::megdnn::param::BN::ParamDim;
using FwdMode = ::megdnn::param::BN::FwdMode;
ParamDim param_dim = ::megdnn::param::BN::ParamDim::DIM_11HW;
FwdMode fwd_mode = ::megdnn::param::BN::FwdMode::TRAINING;
double epsilon = 1e-4f;
double avg_factor = 1.f;
float scale = 1.f;
float bias = 0.f;
BatchNorm() = default;
BatchNorm(ParamDim param_dim_, FwdMode fwd_mode_, double epsilon_, double avg_factor_, float scale_, float bias_, std::string scope_ = {}): param_dim(param_dim_), fwd_mode(fwd_mode_), epsilon(epsilon_), avg_factor(avg_factor_), scale(scale_), bias(bias_) { set_scope(scope_); }
BatchNorm(::megdnn::param::BN packed_param_0): param_dim(packed_param_0.param_dim), fwd_mode(packed_param_0.fwd_mode), epsilon(packed_param_0.epsilon), avg_factor(packed_param_0.avg_factor), scale(packed_param_0.scale), bias(packed_param_0.bias) {}
::megdnn::param::BN param() const {
return {param_dim, fwd_mode, epsilon, avg_factor, scale, bias};
}
};
class BatchNormBackward : public OpDefImplBase<BatchNormBackward> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using ParamDim = ::megdnn::param::BN::ParamDim;
using FwdMode = ::megdnn::param::BN::FwdMode;
ParamDim param_dim = ::megdnn::param::BN::ParamDim::DIM_11HW;
FwdMode fwd_mode = ::megdnn::param::BN::FwdMode::TRAINING;
double epsilon = 1e-4f;
double avg_factor = 1.f;
float scale = 1.f;
float bias = 0.f;
BatchNormBackward() = default;
BatchNormBackward(ParamDim param_dim_, FwdMode fwd_mode_, double epsilon_, double avg_factor_, float scale_, float bias_, std::string scope_ = {}): param_dim(param_dim_), fwd_mode(fwd_mode_), epsilon(epsilon_), avg_factor(avg_factor_), scale(scale_), bias(bias_) { set_scope(scope_); }
BatchNormBackward(::megdnn::param::BN packed_param_0): param_dim(packed_param_0.param_dim), fwd_mode(packed_param_0.fwd_mode), epsilon(packed_param_0.epsilon), avg_factor(packed_param_0.avg_factor), scale(packed_param_0.scale), bias(packed_param_0.bias) {}
::megdnn::param::BN param() const {
return {param_dim, fwd_mode, epsilon, avg_factor, scale, bias};
}
};
class BatchedIncrMeshIndexing : public OpDefImplBase<BatchedIncrMeshIndexing> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
BatchedIncrMeshIndexing() = default;
BatchedIncrMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
class BatchedMatrixMul : public OpDefImplBase<BatchedMatrixMul> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using ComputeMode = ::megdnn::param::MatrixMul::ComputeMode;
using Format = ::megdnn::param::MatrixMul::Format;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
bool transposeA = false;
bool transposeB = false;
ComputeMode compute_mode = ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
Format format = ::megdnn::param::MatrixMul::Format::DEFAULT;
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;
uint32_t dimA;
uint32_t dimB;
BatchedMatrixMul() = default;
BatchedMatrixMul(bool transposeA_, bool transposeB_, ComputeMode compute_mode_, Format format_, Strategy strategy_, uint64_t workspace_limit_, uint32_t dimA_, uint32_t dimB_, std::string scope_ = {}): transposeA(transposeA_), transposeB(transposeB_), compute_mode(compute_mode_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_), dimA(dimA_), dimB(dimB_) {
set_scope(scope_);
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
BatchedMatrixMul(::megdnn::param::MatrixMul packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, uint32_t dimA_, uint32_t dimB_): transposeA(packed_param_0.transposeA), transposeB(packed_param_0.transposeB), compute_mode(packed_param_0.compute_mode), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dimA(dimA_), dimB(dimB_) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
::megdnn::param::MatrixMul param() const {
return {transposeA, transposeB, compute_mode, format};
}
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
class BatchedMeshIndexing : public OpDefImplBase<BatchedMeshIndexing> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
BatchedMeshIndexing() = default;
BatchedMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
class BatchedSetMeshIndexing : public OpDefImplBase<BatchedSetMeshIndexing> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
BatchedSetMeshIndexing() = default;
BatchedSetMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
class BetaRNG : public OpDefImplBase<BetaRNG> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
uint64_t seed = 0;
size_t handle;
BetaRNG() = default;
BetaRNG(uint64_t seed_, size_t handle_, std::string scope_ = {}): seed(seed_), handle(handle_) { set_scope(scope_); }
BetaRNG(::megdnn::param::BetaRNG packed_param_0, size_t handle_): seed(packed_param_0.seed), handle(handle_) {}
::megdnn::param::BetaRNG param() const {
return {seed};
}
};
class Borrow : public OpDefImplBase<Borrow> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
::mgb::CompNode comp_node;
Borrow() = default;
Borrow(::mgb::CompNode comp_node_, std::string scope_ = {}): comp_node(comp_node_) { set_scope(scope_); }
};
class Broadcast : public OpDefImplBase<Broadcast> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<int32_t> shape;
Broadcast() = default;
Broadcast(std::vector<int32_t> shape_, std::string scope_ = {}): shape(shape_) { set_scope(scope_); }
Broadcast(::megdnn::param::Empty, std::vector<int32_t> shape_): shape(shape_) {}
::megdnn::param::Empty param() const {
return {};
}
};
class CambriconRuntime : public OpDefImplBase<CambriconRuntime> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::string buf;
size_t buf_size;
std::string symbol;
bool tensor_dim_mutable;
CambriconRuntime() = default;
CambriconRuntime(std::string buf_, size_t buf_size_, std::string symbol_, bool tensor_dim_mutable_, std::string scope_ = {}): buf(buf_), buf_size(buf_size_), symbol(symbol_), tensor_dim_mutable(tensor_dim_mutable_) { set_scope(scope_); }
};
class CheckNonFinite : public OpDefImplBase<CheckNonFinite> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
float scale = 1.0;
CheckNonFinite() = default;
CheckNonFinite(float scale_, std::string scope_ = {}): scale(scale_) { set_scope(scope_); }
CheckNonFinite(::megdnn::param::CheckNonFinite packed_param_0): scale(packed_param_0.scale) {}
::megdnn::param::CheckNonFinite param() const {
return {scale};
}
};
class CollectiveComm : public OpDefImplBase<CollectiveComm> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::CollectiveComm::Mode;
Mode mode = ::megdnn::param::CollectiveComm::Mode::REDUCE_SUM;
std::string key;
uint32_t nr_devices;
uint32_t rank;
bool is_root;
bool local_grad;
std::string addr;
uint32_t port;
::megdnn::DType dtype;
std::string backend;
std::string comp_node;
CollectiveComm() = default;
CollectiveComm(Mode mode_, std::string key_, uint32_t nr_devices_, uint32_t rank_, bool is_root_, bool local_grad_, std::string addr_, uint32_t port_, ::megdnn::DType dtype_, std::string backend_, std::string comp_node_, std::string scope_ = {}): mode(mode_), key(key_), nr_devices(nr_devices_), rank(rank_), is_root(is_root_), local_grad(local_grad_), addr(addr_), port(port_), dtype(dtype_), backend(backend_), comp_node(comp_node_) { set_scope(scope_); }
CollectiveComm(::megdnn::param::CollectiveComm packed_param_0, std::string key_, uint32_t nr_devices_, uint32_t rank_, bool is_root_, bool local_grad_, std::string addr_, uint32_t port_, ::megdnn::DType dtype_, std::string backend_, std::string comp_node_): mode(packed_param_0.mode), key(key_), nr_devices(nr_devices_), rank(rank_), is_root(is_root_), local_grad(local_grad_), addr(addr_), port(port_), dtype(dtype_), backend(backend_), comp_node(comp_node_) {}
::megdnn::param::CollectiveComm param() const {
return {mode};
}
};
class Concat : public OpDefImplBase<Concat> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t axis = 0;
::mgb::CompNode comp_node;
Concat() = default;
Concat(int32_t axis_, ::mgb::CompNode comp_node_, std::string scope_ = {}): axis(axis_), comp_node(comp_node_) { set_scope(scope_); }
Concat(::megdnn::param::Axis packed_param_0, ::mgb::CompNode comp_node_): axis(packed_param_0.axis), comp_node(comp_node_) {}
::megdnn::param::Axis param() const {
return {axis};
}
};
class CondTake : public OpDefImplBase<CondTake> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
CondTake() = default;
};
class ConvBias : public OpDefImplBase<ConvBias> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using NonlineMode = ::megdnn::param::ConvBias::NonlineMode;
using Mode = ::megdnn::param::ConvBias::Mode;
using Sparse = ::megdnn::param::ConvBias::Sparse;
using Format = ::megdnn::param::ConvBias::Format;
using ComputeMode = ::megdnn::param::ConvBias::ComputeMode;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
NonlineMode nonlineMode = ::megdnn::param::ConvBias::NonlineMode::IDENTITY;
Mode mode = ::megdnn::param::ConvBias::Mode::CROSS_CORRELATION;
Sparse sparse = ::megdnn::param::ConvBias::Sparse::DENSE;
Format format = ::megdnn::param::ConvBias::Format::NCHW;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
ComputeMode compute_mode = ::megdnn::param::ConvBias::ComputeMode::DEFAULT;
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;
::megdnn::DType dtype;
ConvBias() = default;
ConvBias(NonlineMode nonlineMode_, Mode mode_, Sparse sparse_, Format format_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, ::megdnn::DType dtype_, std::string scope_ = {}): nonlineMode(nonlineMode_), mode(mode_), sparse(sparse_), format(format_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_), dtype(dtype_) {
set_scope(scope_);
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
ConvBias(::megdnn::param::ConvBias packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, ::megdnn::DType dtype_): nonlineMode(packed_param_0.nonlineMode), mode(packed_param_0.mode), sparse(packed_param_0.sparse), format(packed_param_0.format), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dtype(dtype_) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
::megdnn::param::ConvBias param() const {
return {nonlineMode, mode, sparse, format, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, compute_mode};
}
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
class Convolution : public OpDefImplBase<Convolution> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::Convolution::Mode;
using Sparse = ::megdnn::param::Convolution::Sparse;
using Format = ::megdnn::param::Convolution::Format;
using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
Format format = ::megdnn::param::Convolution::Format::NCHW;
ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;
Convolution() = default;
Convolution(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_) {
set_scope(scope_);
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
Convolution(::megdnn::param::Convolution packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
::megdnn::param::Convolution param() const {
return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
}
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
class Convolution3D : public OpDefImplBase<Convolution3D> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::Convolution3D::Mode;
using Sparse = ::megdnn::param::Convolution3D::Sparse;
using DataType = ::megdnn::param::Convolution3D::DataType;
using Format = ::megdnn::param::Convolution3D::Format;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
Mode mode = ::megdnn::param::Convolution3D::Mode::CROSS_CORRELATION;
uint32_t pad_d = 0;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_d = 1;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_d = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
Sparse sparse = ::megdnn::param::Convolution3D::Sparse::DENSE;
DataType data_type = ::megdnn::param::Convolution3D::DataType::FLOAT;
Format format = ::megdnn::param::Convolution3D::Format::NCDHW;
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;
Convolution3D() = default;
Convolution3D(Mode mode_, uint32_t pad_d_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_d_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_d_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, DataType data_type_, Format format_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_d(pad_d_), pad_h(pad_h_), pad_w(pad_w_), stride_d(stride_d_), stride_h(stride_h_), stride_w(stride_w_), dilate_d(dilate_d_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), data_type(data_type_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_) {
set_scope(scope_);
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
Convolution3D(::megdnn::param::Convolution3D packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_d(packed_param_0.pad_d), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_d(packed_param_0.stride_d), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_d(packed_param_0.dilate_d), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), data_type(packed_param_0.data_type), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
::megdnn::param::Convolution3D param() const {
return {mode, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, dilate_d, dilate_h, dilate_w, sparse, data_type, format};
}
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
class Convolution3DBackwardData : public OpDefImplBase<Convolution3DBackwardData> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::Convolution3D::Mode;
using Sparse = ::megdnn::param::Convolution3D::Sparse;
using DataType = ::megdnn::param::Convolution3D::DataType;
using Format = ::megdnn::param::Convolution3D::Format;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
Mode mode = ::megdnn::param::Convolution3D::Mode::CROSS_CORRELATION;
uint32_t pad_d = 0;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_d = 1;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_d = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
Sparse sparse = ::megdnn::param::Convolution3D::Sparse::DENSE;
DataType data_type = ::megdnn::param::Convolution3D::DataType::FLOAT;
Format format = ::megdnn::param::Convolution3D::Format::NCDHW;
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;
Convolution3DBackwardData() = default;
Convolution3DBackwardData(Mode mode_, uint32_t pad_d_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_d_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_d_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, DataType data_type_, Format format_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_d(pad_d_), pad_h(pad_h_), pad_w(pad_w_), stride_d(stride_d_), stride_h(stride_h_), stride_w(stride_w_), dilate_d(dilate_d_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), data_type(data_type_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_) {
set_scope(scope_);
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
Convolution3DBackwardData(::megdnn::param::Convolution3D packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_d(packed_param_0.pad_d), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_d(packed_param_0.stride_d), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_d(packed_param_0.dilate_d), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), data_type(packed_param_0.data_type), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
::megdnn::param::Convolution3D param() const {
return {mode, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, dilate_d, dilate_h, dilate_w, sparse, data_type, format};
}
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
class ConvolutionBackwardData : public OpDefImplBase<ConvolutionBackwardData> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::Convolution::Mode;
using Sparse = ::megdnn::param::Convolution::Sparse;
using Format = ::megdnn::param::Convolution::Format;
using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
Format format = ::megdnn::param::Convolution::Format::NCHW;
ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;
::megdnn::DType dtype;
ConvolutionBackwardData() = default;
ConvolutionBackwardData(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, ::megdnn::DType dtype_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_), dtype(dtype_) {
set_scope(scope_);
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
ConvolutionBackwardData(::megdnn::param::Convolution packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, ::megdnn::DType dtype_): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dtype(dtype_) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
::megdnn::param::Convolution param() const {
return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
}
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
class Copy : public OpDefImplBase<Copy> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
::mgb::CompNode comp_node;
Copy() = default;
Copy(::mgb::CompNode comp_node_, std::string scope_ = {}): comp_node(comp_node_) { set_scope(scope_); }
};
class Correlation : public OpDefImplBase<Correlation> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Format = ::megdnn::param::Correlation::Format;
Format format = ::megdnn::param::Correlation::Format::NCHW;
uint32_t kernel_size = 1;
uint32_t max_displacement = 1;
uint32_t stride1 = 1;
uint32_t stride2 = 1;
uint32_t pad_size = 0;
bool is_multiply = true;
Correlation() = default;
Correlation(Format format_, uint32_t kernel_size_, uint32_t max_displacement_, uint32_t stride1_, uint32_t stride2_, uint32_t pad_size_, bool is_multiply_, std::string scope_ = {}): format(format_), kernel_size(kernel_size_), max_displacement(max_displacement_), stride1(stride1_), stride2(stride2_), pad_size(pad_size_), is_multiply(is_multiply_) { set_scope(scope_); }
Correlation(::megdnn::param::Correlation packed_param_0): format(packed_param_0.format), kernel_size(packed_param_0.kernel_size), max_displacement(packed_param_0.max_displacement), stride1(packed_param_0.stride1), stride2(packed_param_0.stride2), pad_size(packed_param_0.pad_size), is_multiply(packed_param_0.is_multiply) {}
::megdnn::param::Correlation param() const {
return {format, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply};
}
};
class Cumsum : public OpDefImplBase<Cumsum> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t axis = 2147483647;
bool exclusive = true;
bool reverse = false;
Cumsum() = default;
Cumsum(int32_t axis_, bool exclusive_, bool reverse_, std::string scope_ = {}): axis(axis_), exclusive(exclusive_), reverse(reverse_) { set_scope(scope_); }
Cumsum(::megdnn::param::Cumsum packed_param_0): axis(packed_param_0.axis), exclusive(packed_param_0.exclusive), reverse(packed_param_0.reverse) {}
::megdnn::param::Cumsum param() const {
return {axis, exclusive, reverse};
}
};
class CvtColor : public OpDefImplBase<CvtColor> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::CvtColor::Mode;
Mode mode = ::megdnn::param::CvtColor::Mode::RGB2GRAY;
CvtColor() = default;
CvtColor(Mode mode_, std::string scope_ = {}): mode(mode_) { set_scope(scope_); }
CvtColor(::megdnn::param::CvtColor packed_param_0): mode(packed_param_0.mode) {}
::megdnn::param::CvtColor param() const {
return {mode};
}
};
class DeformableConv : public OpDefImplBase<DeformableConv> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::Convolution::Mode;
using Sparse = ::megdnn::param::Convolution::Sparse;
using Format = ::megdnn::param::Convolution::Format;
using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
Format format = ::megdnn::param::Convolution::Format::NCHW;
ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;
DeformableConv() = default;
DeformableConv(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_), strategy(strategy_), workspace_limit(workspace_limit_) {
set_scope(scope_);
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
DeformableConv(::megdnn::param::Convolution packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
::megdnn::param::Convolution param() const {
return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
}
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
class DeformablePSROIPooling : public OpDefImplBase<DeformablePSROIPooling> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
bool no_trans = true;
float spatial_scale = 1;
float trans_std = 1;
uint32_t pooled_h = 1;
uint32_t pooled_w = 1;
uint32_t part_size = 1;
uint32_t sample_per_part = 1;
DeformablePSROIPooling() = default;
DeformablePSROIPooling(bool no_trans_, float spatial_scale_, float trans_std_, uint32_t pooled_h_, uint32_t pooled_w_, uint32_t part_size_, uint32_t sample_per_part_, std::string scope_ = {}): no_trans(no_trans_), spatial_scale(spatial_scale_), trans_std(trans_std_), pooled_h(pooled_h_), pooled_w(pooled_w_), part_size(part_size_), sample_per_part(sample_per_part_) { set_scope(scope_); }
DeformablePSROIPooling(::megdnn::param::DeformablePSROIPooling packed_param_0): no_trans(packed_param_0.no_trans), spatial_scale(packed_param_0.spatial_scale), trans_std(packed_param_0.trans_std), pooled_h(packed_param_0.pooled_h), pooled_w(packed_param_0.pooled_w), part_size(packed_param_0.part_size), sample_per_part(packed_param_0.sample_per_part) {}
::megdnn::param::DeformablePSROIPooling param() const {
return {no_trans, spatial_scale, trans_std, pooled_h, pooled_w, part_size, sample_per_part};
}
};
class Diag : public OpDefImplBase<Diag> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t k = 0;
Diag() = default;
Diag(int32_t k_, std::string scope_ = {}): k(k_) { set_scope(scope_); }
Diag(::megdnn::param::Diag packed_param_0): k(packed_param_0.k) {}
::megdnn::param::Diag param() const {
return {k};
}
};
class Dimshuffle : public OpDefImplBase<Dimshuffle> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<int32_t> pattern;
Dimshuffle() = default;
Dimshuffle(std::vector<int32_t> pattern_, std::string scope_ = {}): pattern(pattern_) { set_scope(scope_); }
};
class Dot : public OpDefImplBase<Dot> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
Dot() = default;
Dot(::megdnn::param::Empty) {}
::megdnn::param::Empty param() const {
return {};
}
};
class Dropout : public OpDefImplBase<Dropout> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
float drop_prob = 0;
uint64_t seed = 0;
size_t handle;
Dropout() = default;
Dropout(float drop_prob_, uint64_t seed_, size_t handle_, std::string scope_ = {}): drop_prob(drop_prob_), seed(seed_), handle(handle_) { set_scope(scope_); }
Dropout(::megdnn::param::Dropout packed_param_0, size_t handle_): drop_prob(packed_param_0.drop_prob), seed(packed_param_0.seed), handle(handle_) {}
::megdnn::param::Dropout param() const {
return {drop_prob, seed};
}
};
class Elemwise : public OpDefImplBase<Elemwise> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::Elemwise::Mode;
Mode mode = ::megdnn::param::Elemwise::Mode::RELU;
Elemwise() = default;
Elemwise(Mode mode_, std::string scope_ = {}): mode(mode_) { set_scope(scope_); }
Elemwise(::megdnn::param::Elemwise packed_param_0): mode(packed_param_0.mode) {}
::megdnn::param::Elemwise param() const {
return {mode};
}
};
template <>
struct ToStringTrait<Elemwise::Mode> {
std::string operator()(Elemwise::Mode e) const {
switch (e) {
case Elemwise::Mode::RELU: return "RELU";
case Elemwise::Mode::ABS: return "ABS";
case Elemwise::Mode::ACOS: return "ACOS";
case Elemwise::Mode::ASIN: return "ASIN";
case Elemwise::Mode::CEIL: return "CEIL";
case Elemwise::Mode::COS: return "COS";
case Elemwise::Mode::EXP: return "EXP";
case Elemwise::Mode::EXPM1: return "EXPM1";
case Elemwise::Mode::FLOOR: return "FLOOR";
case Elemwise::Mode::LOG: return "LOG";
case Elemwise::Mode::LOG1P: return "LOG1P";
case Elemwise::Mode::NEGATE: return "NEGATE";
case Elemwise::Mode::SIGMOID: return "SIGMOID";
case Elemwise::Mode::SIN: return "SIN";
case Elemwise::Mode::TANH: return "TANH";
case Elemwise::Mode::ABS_GRAD: return "ABS_GRAD";
case Elemwise::Mode::ADD: return "ADD";
case Elemwise::Mode::FLOOR_DIV: return "FLOOR_DIV";
case Elemwise::Mode::MAX: return "MAX";
case Elemwise::Mode::MIN: return "MIN";
case Elemwise::Mode::MOD: return "MOD";
case Elemwise::Mode::MUL: return "MUL";
case Elemwise::Mode::POW: return "POW";
case Elemwise::Mode::SIGMOID_GRAD: return "SIGMOID_GRAD";
case Elemwise::Mode::SUB: return "SUB";
case Elemwise::Mode::SWITCH_GT0: return "SWITCH_GT0";
case Elemwise::Mode::TANH_GRAD: return "TANH_GRAD";
case Elemwise::Mode::TRUE_DIV: return "TRUE_DIV";
case Elemwise::Mode::LOG_SUM_EXP: return "LOG_SUM_EXP";
case Elemwise::Mode::LT: return "LT";
case Elemwise::Mode::LEQ: return "LEQ";
case Elemwise::Mode::EQ: return "EQ";
case Elemwise::Mode::SHL: return "SHL";
case Elemwise::Mode::SHR: return "SHR";
case Elemwise::Mode::COND_LEQ_MOV: return "COND_LEQ_MOV";
case Elemwise::Mode::FUSE_MUL_ADD3: return "FUSE_MUL_ADD3";
case Elemwise::Mode::FUSE_MUL_ADD4: return "FUSE_MUL_ADD4";
case Elemwise::Mode::FUSE_ADD_RELU: return "FUSE_ADD_RELU";
case Elemwise::Mode::FUSE_ADD_SIGMOID: return "FUSE_ADD_SIGMOID";
case Elemwise::Mode::FUSE_ADD_TANH: return "FUSE_ADD_TANH";
case Elemwise::Mode::FAST_TANH: return "FAST_TANH";
case Elemwise::Mode::FAST_TANH_GRAD: return "FAST_TANH_GRAD";
case Elemwise::Mode::ROUND: return "ROUND";
case Elemwise::Mode::RMULH: return "RMULH";
case Elemwise::Mode::ATAN2: return "ATAN2";
case Elemwise::Mode::ERF: return "ERF";
case Elemwise::Mode::ERFINV: return "ERFINV";
case Elemwise::Mode::ERFC: return "ERFC";
case Elemwise::Mode::ERFCINV: return "ERFCINV";
case Elemwise::Mode::H_SWISH: return "H_SWISH";
case Elemwise::Mode::H_SWISH_GRAD: return "H_SWISH_GRAD";
case Elemwise::Mode::FUSE_ADD_H_SWISH: return "FUSE_ADD_H_SWISH";
case Elemwise::Mode::NOT: return "NOT";
case Elemwise::Mode::AND: return "AND";
case Elemwise::Mode::OR: return "OR";
case Elemwise::Mode::XOR: return "XOR";
case Elemwise::Mode::SILU: return "SILU";
case Elemwise::Mode::SILU_GRAD: return "SILU_GRAD";
case Elemwise::Mode::GELU: return "GELU";
case Elemwise::Mode::GELU_GRAD: return "GELU_GRAD";
case Elemwise::Mode::COND_LT_MOV: return "COND_LT_MOV";
default:
return "Elemwise::Mode::Unknown";
}
}
};
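
The ToStringTrait specializations map mode values to readable names, with unlisted values falling back to "Elemwise::Mode::Unknown". A hedged call-shape sketch (log_mode is illustrative; assumes <cstdio> is available):

void log_mode(Elemwise::Mode mode) {
    std::string name = ToStringTrait<Elemwise::Mode>{}(mode);  // e.g. "ADD" for Mode::ADD
    std::printf("elemwise mode: %s\n", name.c_str());
}
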
class ElemwiseMultiType : public OpDefImplBase<ElemwiseMultiType> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::ElemwiseMultiType::Mode;
Mode mode = ::megdnn::param::ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32;
::megdnn::DType dtype;
ElemwiseMultiType() = default;
ElemwiseMultiType(Mode mode_, ::megdnn::DType dtype_, std::string scope_ = {}): mode(mode_), dtype(dtype_) { set_scope(scope_); }
ElemwiseMultiType(::megdnn::param::ElemwiseMultiType packed_param_0, ::megdnn::DType dtype_): mode(packed_param_0.mode), dtype(dtype_) {}
::megdnn::param::ElemwiseMultiType param() const {
return {mode};
}
};
template <>
struct ToStringTrait<ElemwiseMultiType::Mode> {
std::string operator()(ElemwiseMultiType::Mode e) const {
switch (e) {
case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32: return "FUSE_MUL_ADD3_INT16x32x32x32";
case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_IXxF32xF32xI8: return "FUSE_MUL_ADD3_IXxF32xF32xI8";
case ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI8: return "ROUND_SHR_SATURATE_IXxI8xI8";
case ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8: return "FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8";
case ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8: return "FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8";
case ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI16: return "ROUND_SHR_SATURATE_IXxI8xI16";
case ElemwiseMultiType::Mode::QADD: return "QADD";
case ElemwiseMultiType::Mode::QFUSE_ADD_RELU: return "QFUSE_ADD_RELU";
case ElemwiseMultiType::Mode::QMUL: return "QMUL";
case ElemwiseMultiType::Mode::QMIN: return "QMIN";
case ElemwiseMultiType::Mode::QMAX: return "QMAX";
case ElemwiseMultiType::Mode::QSUB: return "QSUB";
case ElemwiseMultiType::Mode::QTRUE_DIV: return "QTRUE_DIV";
case ElemwiseMultiType::Mode::QFUSE_ADD_SIGMOID: return "QFUSE_ADD_SIGMOID";
case ElemwiseMultiType::Mode::QFUSE_ADD_TANH: return "QFUSE_ADD_TANH";
case ElemwiseMultiType::Mode::QRELU: return "QRELU";
case ElemwiseMultiType::Mode::QABS: return "QABS";
case ElemwiseMultiType::Mode::QSIGMOID: return "QSIGMOID";
case ElemwiseMultiType::Mode::QEXP: return "QEXP";
case ElemwiseMultiType::Mode::QTANH: return "QTANH";
case ElemwiseMultiType::Mode::QFUSE_MUL_ADD3: return "QFUSE_MUL_ADD3";
case ElemwiseMultiType::Mode::QFAST_TANH: return "QFAST_TANH";
case ElemwiseMultiType::Mode::QNEGATE: return "QNEGATE";
case ElemwiseMultiType::Mode::QACOS: return "QACOS";
case ElemwiseMultiType::Mode::QASIN: return "QASIN";
case ElemwiseMultiType::Mode::QCEIL: return "QCEIL";
case ElemwiseMultiType::Mode::QCOS: return "QCOS";
case ElemwiseMultiType::Mode::QEXPM1: return "QEXPM1";
case ElemwiseMultiType::Mode::QFLOOR: return "QFLOOR";
case ElemwiseMultiType::Mode::QLOG: return "QLOG";
case ElemwiseMultiType::Mode::QLOG1P: return "QLOG1P";
case ElemwiseMultiType::Mode::QSIN: return "QSIN";
case ElemwiseMultiType::Mode::QROUND: return "QROUND";
case ElemwiseMultiType::Mode::QERF: return "QERF";
case ElemwiseMultiType::Mode::QERFINV: return "QERFINV";
case ElemwiseMultiType::Mode::QERFC: return "QERFC";
case ElemwiseMultiType::Mode::QERFCINV: return "QERFCINV";
case ElemwiseMultiType::Mode::QABS_GRAD: return "QABS_GRAD";
case ElemwiseMultiType::Mode::QFLOOR_DIV: return "QFLOOR_DIV";
case ElemwiseMultiType::Mode::QMOD: return "QMOD";
case ElemwiseMultiType::Mode::QSIGMOID_GRAD: return "QSIGMOID_GRAD";
case ElemwiseMultiType::Mode::QSWITCH_GT0: return "QSWITCH_GT0";
case ElemwiseMultiType::Mode::QTANH_GRAD: return "QTANH_GRAD";
case ElemwiseMultiType::Mode::QLT: return "QLT";
case ElemwiseMultiType::Mode::QLEQ: return "QLEQ";
case ElemwiseMultiType::Mode::QEQ: return "QEQ";
case ElemwiseMultiType::Mode::QPOW: return "QPOW";
case ElemwiseMultiType::Mode::QLOG_SUM_EXP: return "QLOG_SUM_EXP";
case ElemwiseMultiType::Mode::QFAST_TANH_GRAD: return "QFAST_TANH_GRAD";
case ElemwiseMultiType::Mode::QATAN2: return "QATAN2";
case ElemwiseMultiType::Mode::QCOND_LEQ_MOV: return "QCOND_LEQ_MOV";
case ElemwiseMultiType::Mode::QH_SWISH: return "QH_SWISH";
case ElemwiseMultiType::Mode::QFUSE_ADD_H_SWISH: return "QFUSE_ADD_H_SWISH";
case ElemwiseMultiType::Mode::QH_SWISH_GRAD: return "QH_SWISH_GRAD";
case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32: return "FUSE_MUL_ADD3_INT16xF32xF32xF32";
case ElemwiseMultiType::Mode::MUL_INT16xF32xF32: return "MUL_INT16xF32xF32";
case ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32: return "FUSE_MUL_ADD3_UINT8xF32xF32xF32";
case ElemwiseMultiType::Mode::QCOND_LT_MOV: return "QCOND_LT_MOV";
default:
return "ElemwiseMultiType::Mode::Unknown";
}
}
};
class ExternOpr : public OpDefImplBase<ExternOpr> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<std::vector<size_t>> output_shapes;
std::string name;
std::string data;
size_t data_len;
std::vector<::megdnn::DType> output_dtypes;
ExternOpr() = default;
ExternOpr(std::vector<std::vector<size_t>> output_shapes_, std::string name_, std::string data_, size_t data_len_, std::vector<::megdnn::DType> output_dtypes_, std::string scope_ = {}): output_shapes(output_shapes_), name(name_), data(data_), data_len(data_len_), output_dtypes(output_dtypes_) { set_scope(scope_); }
};
class Eye : public OpDefImplBase<Eye> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t k = 0;
::megdnn::DType dtype = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32);
::mgb::CompNode comp_node;
Eye() = default;
Eye(int32_t k_, ::megdnn::DType dtype_, ::mgb::CompNode comp_node_, std::string scope_ = {}): k(k_), dtype(dtype_), comp_node(comp_node_) { set_scope(scope_); }
};
class FakeQuant : public OpDefImplBase<FakeQuant> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t qmin = -2147483648;
int32_t qmax = 2147483647;
FakeQuant() = default;
FakeQuant(int32_t qmin_, int32_t qmax_, std::string scope_ = {}): qmin(qmin_), qmax(qmax_) { set_scope(scope_); }
FakeQuant(::megdnn::param::FakeQuant packed_param_0): qmin(packed_param_0.qmin), qmax(packed_param_0.qmax) {}
::megdnn::param::FakeQuant param() const {
return {qmin, qmax};
}
};
class FastpathCopy : public OpDefImplBase<FastpathCopy> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
FastpathCopy() = default;
};
class GammaRNG : public OpDefImplBase<GammaRNG> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
uint64_t seed = 0;
size_t handle;
GammaRNG() = default;
GammaRNG(uint64_t seed_, size_t handle_, std::string scope_ = {}): seed(seed_), handle(handle_) { set_scope(scope_); }
GammaRNG(::megdnn::param::GammaRNG packed_param_0, size_t handle_): seed(packed_param_0.seed), handle(handle_) {}
::megdnn::param::GammaRNG param() const {
return {seed};
}
};
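
The RNG-flavored ops (BetaRNG, GammaRNG, GaussianRNG, Dropout) pair a seed with an opaque handle that is not part of the megdnn param, so only the seed survives param(). A sketch, with the handle value purely illustrative:

void example_gamma_rng() {
    size_t handle = 0;                              // illustrative RNG handle id
    GammaRNG op(/*seed=*/42, handle);
    ::megdnn::param::GammaRNG packed = op.param();  // carries only {seed}
    GammaRNG same(packed, handle);                  // handle must be re-supplied
    (void)same;
}
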
class GaussianRNG : public OpDefImplBase<GaussianRNG> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
uint64_t seed = 0;
float mean = 0;
float std = 1;
::megdnn::DType dtype = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32);
size_t handle;
GaussianRNG() = default;
GaussianRNG(uint64_t seed_, float mean_, float std_, ::megdnn::DType dtype_, size_t handle_, std::string scope_ = {}): seed(seed_), mean(mean_), std(std_), dtype(dtype_), handle(handle_) { set_scope(scope_); }
};
class GetVarShape : public OpDefImplBase<GetVarShape> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t axis = ::megdnn::param::OptionalAxisV1::INVALID_AXIS;
GetVarShape() = default;
GetVarShape(int32_t axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
GetVarShape(::megdnn::param::OptionalAxisV1 packed_param_0): axis(packed_param_0.axis) {}
::megdnn::param::OptionalAxisV1 param() const {
return {axis};
}
};
class GroupLocal : public OpDefImplBase<GroupLocal> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::Convolution::Mode;
using Sparse = ::megdnn::param::Convolution::Sparse;
using Format = ::megdnn::param::Convolution::Format;
using ComputeMode = ::megdnn::param::Convolution::ComputeMode;
Mode mode = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
Sparse sparse = ::megdnn::param::Convolution::Sparse::DENSE;
Format format = ::megdnn::param::Convolution::Format::NCHW;
ComputeMode compute_mode = ::megdnn::param::Convolution::ComputeMode::DEFAULT;
GroupLocal() = default;
GroupLocal(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, Sparse sparse_, Format format_, ComputeMode compute_mode_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), sparse(sparse_), format(format_), compute_mode(compute_mode_) { set_scope(scope_); }
GroupLocal(::megdnn::param::Convolution packed_param_0): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), sparse(packed_param_0.sparse), format(packed_param_0.format), compute_mode(packed_param_0.compute_mode) {}
::megdnn::param::Convolution param() const {
return {mode, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, sparse, format, compute_mode};
}
};
class Identity : public OpDefImplBase<Identity> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
Identity() = default;
};
class Images2Neibs : public OpDefImplBase<Images2Neibs> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
uint32_t window_h = 3;
uint32_t window_w = 3;
Images2Neibs() = default;
Images2Neibs(uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, uint32_t window_h_, uint32_t window_w_, std::string scope_ = {}): pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), window_h(window_h_), window_w(window_w_) { set_scope(scope_); }
Images2Neibs(::megdnn::param::Images2Neibs packed_param_0): pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), window_h(packed_param_0.window_h), window_w(packed_param_0.window_w) {}
::megdnn::param::Images2Neibs param() const {
return {pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, window_h, window_w};
}
};
class IncrMeshIndexing : public OpDefImplBase<IncrMeshIndexing> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
IncrMeshIndexing() = default;
IncrMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
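// The indexing-style ops below all share the same descriptor list. Reading
// from surrounding usage (an assumption, not stated in this file), each tuple
// in `items` encodes (axis, begin, end, step, idx) flags marking which
// optional slice inputs accompany that axis.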
class IncrSubtensor : public OpDefImplBase<IncrSubtensor> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
IncrSubtensor() = default;
IncrSubtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
class IndexingIncrMultiAxisVec : public OpDefImplBase<IndexingIncrMultiAxisVec> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
IndexingIncrMultiAxisVec() = default;
IndexingIncrMultiAxisVec(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
class IndexingMultiAxisVec : public OpDefImplBase<IndexingMultiAxisVec> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
IndexingMultiAxisVec() = default;
IndexingMultiAxisVec(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
class IndexingOneHot : public OpDefImplBase<IndexingOneHot> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t axis = 0;
int32_t ndim;
IndexingOneHot() = default;
IndexingOneHot(int32_t axis_, int32_t ndim_, std::string scope_ = {}): axis(axis_), ndim(ndim_) { set_scope(scope_); }
IndexingOneHot(::megdnn::param::Axis packed_param_0, int32_t ndim_): axis(packed_param_0.axis), ndim(ndim_) {}
::megdnn::param::Axis param() const {
return {axis};
}
};
class IndexingSetMultiAxisVec : public OpDefImplBase<IndexingSetMultiAxisVec> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
IndexingSetMultiAxisVec() = default;
IndexingSetMultiAxisVec(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
class IndexingSetOneHot : public OpDefImplBase<IndexingSetOneHot> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t axis = 0;
int32_t ndim;
IndexingSetOneHot() = default;
IndexingSetOneHot(int32_t axis_, int32_t ndim_, std::string scope_ = {}): axis(axis_), ndim(ndim_) { set_scope(scope_); }
IndexingSetOneHot(::megdnn::param::Axis packed_param_0, int32_t ndim_): axis(packed_param_0.axis), ndim(ndim_) {}
::megdnn::param::Axis param() const {
return {axis};
}
};
class InplaceAdd : public OpDefImplBase<InplaceAdd> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
InplaceAdd() = default;
InplaceAdd(::megdnn::param::Empty) {}
::megdnn::param::Empty param() const {
return {};
}
};
class LAMBUpdate : public OpDefImplBase<LAMBUpdate> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
float beta_1 = 1.f;
float beta_2 = 1.f;
float step = 1.f;
float lr = 1.f;
float weight_decay = 1.f;
float eps = 1.f;
bool bias_correction = true;
bool always_adapt = false;
LAMBUpdate() = default;
LAMBUpdate(float beta_1_, float beta_2_, float step_, float lr_, float weight_decay_, float eps_, bool bias_correction_, bool always_adapt_, std::string scope_ = {}): beta_1(beta_1_), beta_2(beta_2_), step(step_), lr(lr_), weight_decay(weight_decay_), eps(eps_), bias_correction(bias_correction_), always_adapt(always_adapt_) { set_scope(scope_); }
LAMBUpdate(::megdnn::param::LAMBUpdate packed_param_0): beta_1(packed_param_0.beta_1), beta_2(packed_param_0.beta_2), step(packed_param_0.step), lr(packed_param_0.lr), weight_decay(packed_param_0.weight_decay), eps(packed_param_0.eps), bias_correction(packed_param_0.bias_correction), always_adapt(packed_param_0.always_adapt) {}
::megdnn::param::LAMBUpdate param() const {
return {beta_1, beta_2, step, lr, weight_decay, eps, bias_correction, always_adapt};
}
};
class LRN : public OpDefImplBase<LRN> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
uint32_t n = 5;
float k = 2.f;
float alpha = 1e-4f;
float beta = 0.75f;
LRN() = default;
LRN(uint32_t n_, float k_, float alpha_, float beta_, std::string scope_ = {}): n(n_), k(k_), alpha(alpha_), beta(beta_) { set_scope(scope_); }
LRN(::megdnn::param::LRN packed_param_0): n(packed_param_0.n), k(packed_param_0.k), alpha(packed_param_0.alpha), beta(packed_param_0.beta) {}
::megdnn::param::LRN param() const {
return {n, k, alpha, beta};
}
};
class LSQ : public OpDefImplBase<LSQ> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t qmin = -2147483648;
int32_t qmax = 2147483647;
LSQ() = default;
LSQ(int32_t qmin_, int32_t qmax_, std::string scope_ = {}): qmin(qmin_), qmax(qmax_) { set_scope(scope_); }
LSQ(::megdnn::param::LSQ packed_param_0): qmin(packed_param_0.qmin), qmax(packed_param_0.qmax) {}
::megdnn::param::LSQ param() const {
return {qmin, qmax};
}
};
class LSTM : public OpDefImplBase<LSTM> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using FwdMode = ::megdnn::param::LSTM::FwdMode;
uint32_t num_layers = 1;
bool bidirectional = false;
bool bias = true;
uint32_t hidden_size = 128;
uint32_t proj_size = 0;
float dropout = 0.f;
FwdMode fwd_mode = ::megdnn::param::LSTM::FwdMode::TRAINING;
LSTM() = default;
LSTM(uint32_t num_layers_, bool bidirectional_, bool bias_, uint32_t hidden_size_, uint32_t proj_size_, float dropout_, FwdMode fwd_mode_, std::string scope_ = {}): num_layers(num_layers_), bidirectional(bidirectional_), bias(bias_), hidden_size(hidden_size_), proj_size(proj_size_), dropout(dropout_), fwd_mode(fwd_mode_) { set_scope(scope_); }
LSTM(::megdnn::param::LSTM packed_param_0): num_layers(packed_param_0.num_layers), bidirectional(packed_param_0.bidirectional), bias(packed_param_0.bias), hidden_size(packed_param_0.hidden_size), proj_size(packed_param_0.proj_size), dropout(packed_param_0.dropout), fwd_mode(packed_param_0.fwd_mode) {}
::megdnn::param::LSTM param() const {
return {num_layers, bidirectional, bias, hidden_size, proj_size, dropout, fwd_mode};
}
};
class LSTMCell : public OpDefImplBase<LSTMCell> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
LSTMCell() = default;
LSTMCell(::megdnn::param::Empty) {}
::megdnn::param::Empty param() const {
return {};
}
};
class LayerNorm : public OpDefImplBase<LayerNorm> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
bool affine = true;
float eps = 1e-5f;
uint64_t normalized_dim = 1;
uint64_t normalized_size = 1;
LayerNorm() = default;
LayerNorm(bool affine_, float eps_, uint64_t normalized_dim_, uint64_t normalized_size_, std::string scope_ = {}): affine(affine_), eps(eps_), normalized_dim(normalized_dim_), normalized_size(normalized_size_) { set_scope(scope_); }
LayerNorm(::megdnn::param::LayerNorm packed_param_0): affine(packed_param_0.affine), eps(packed_param_0.eps), normalized_dim(packed_param_0.normalized_dim), normalized_size(packed_param_0.normalized_size) {}
::megdnn::param::LayerNorm param() const {
return {affine, eps, normalized_dim, normalized_size};
}
};
class Linspace : public OpDefImplBase<Linspace> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
bool endpoint = true;
::mgb::CompNode comp_node;
Linspace() = default;
Linspace(bool endpoint_, ::mgb::CompNode comp_node_, std::string scope_ = {}): endpoint(endpoint_), comp_node(comp_node_) { set_scope(scope_); }
Linspace(::megdnn::param::Linspace packed_param_0, ::mgb::CompNode comp_node_): endpoint(packed_param_0.endpoint), comp_node(comp_node_) {}
::megdnn::param::Linspace param() const {
return {endpoint};
}
};
class MagicMindRuntime : public OpDefImplBase<MagicMindRuntime> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::string buf;
size_t buf_size;
MagicMindRuntime() = default;
MagicMindRuntime(std::string buf_, size_t buf_size_, std::string scope_ = {}): buf(buf_), buf_size(buf_size_) { set_scope(scope_); }
};
class MatrixInverse : public OpDefImplBase<MatrixInverse> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
MatrixInverse() = default;
MatrixInverse(::megdnn::param::Empty) {}
::megdnn::param::Empty param() const {
return {};
}
};
class MatrixMul : public OpDefImplBase<MatrixMul> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using ComputeMode = ::megdnn::param::MatrixMul::ComputeMode;
using Format = ::megdnn::param::MatrixMul::Format;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
bool transposeA = false;
bool transposeB = false;
ComputeMode compute_mode = ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
Format format = ::megdnn::param::MatrixMul::Format::DEFAULT;
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;
uint32_t dimA;
uint32_t dimB;
MatrixMul() = default;
MatrixMul(bool transposeA_, bool transposeB_, ComputeMode compute_mode_, Format format_, Strategy strategy_, uint64_t workspace_limit_, uint32_t dimA_, uint32_t dimB_, std::string scope_ = {}): transposeA(transposeA_), transposeB(transposeB_), compute_mode(compute_mode_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_), dimA(dimA_), dimB(dimB_) {
set_scope(scope_);
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
MatrixMul(::megdnn::param::MatrixMul packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1, uint32_t dimA_, uint32_t dimB_): transposeA(packed_param_0.transposeA), transposeB(packed_param_0.transposeB), compute_mode(packed_param_0.compute_mode), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit), dimA(dimA_), dimB(dimB_) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
::megdnn::param::MatrixMul param() const {
return {transposeA, transposeB, compute_mode, format};
}
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
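// Ops that also carry an ExecutionPolicy (MatrixMul, Pooling, ...) expose it
// through a separate policy() accessor instead of folding it into param();
// the mgb_assert calls above sanity-check that the strategy value stays
// within the defined bit range.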
class MeshIndexing : public OpDefImplBase<MeshIndexing> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
MeshIndexing() = default;
MeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
class NMSKeep : public OpDefImplBase<NMSKeep> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
float iou_thresh;
uint32_t max_output;
NMSKeep() = default;
NMSKeep(float iou_thresh_, uint32_t max_output_, std::string scope_ = {}): iou_thresh(iou_thresh_), max_output(max_output_) { set_scope(scope_); }
};
class NvOf : public OpDefImplBase<NvOf> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
uint32_t precision = 1;
NvOf() = default;
NvOf(uint32_t precision_, std::string scope_ = {}): precision(precision_) { set_scope(scope_); }
NvOf(::megdnn::param::NvOf packed_param_0): precision(packed_param_0.precision) {}
::megdnn::param::NvOf param() const {
return {precision};
}
};
class Padding : public OpDefImplBase<Padding> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using PaddingMode = ::megdnn::param::Padding::PaddingMode;
uint32_t front_offset_dim0 = 0;
uint32_t front_offset_dim1 = 0;
uint32_t front_offset_dim2 = 0;
uint32_t front_offset_dim3 = 0;
uint32_t front_offset_dim4 = 0;
uint32_t front_offset_dim5 = 0;
uint32_t front_offset_dim6 = 0;
uint32_t back_offset_dim0 = 0;
uint32_t back_offset_dim1 = 0;
uint32_t back_offset_dim2 = 0;
uint32_t back_offset_dim3 = 0;
uint32_t back_offset_dim4 = 0;
uint32_t back_offset_dim5 = 0;
uint32_t back_offset_dim6 = 0;
float padding_val = 0;
PaddingMode padding_mode = ::megdnn::param::Padding::PaddingMode::CONSTANT;
Padding() = default;
Padding(uint32_t front_offset_dim0_, uint32_t front_offset_dim1_, uint32_t front_offset_dim2_, uint32_t front_offset_dim3_, uint32_t front_offset_dim4_, uint32_t front_offset_dim5_, uint32_t front_offset_dim6_, uint32_t back_offset_dim0_, uint32_t back_offset_dim1_, uint32_t back_offset_dim2_, uint32_t back_offset_dim3_, uint32_t back_offset_dim4_, uint32_t back_offset_dim5_, uint32_t back_offset_dim6_, float padding_val_, PaddingMode padding_mode_, std::string scope_ = {}): front_offset_dim0(front_offset_dim0_), front_offset_dim1(front_offset_dim1_), front_offset_dim2(front_offset_dim2_), front_offset_dim3(front_offset_dim3_), front_offset_dim4(front_offset_dim4_), front_offset_dim5(front_offset_dim5_), front_offset_dim6(front_offset_dim6_), back_offset_dim0(back_offset_dim0_), back_offset_dim1(back_offset_dim1_), back_offset_dim2(back_offset_dim2_), back_offset_dim3(back_offset_dim3_), back_offset_dim4(back_offset_dim4_), back_offset_dim5(back_offset_dim5_), back_offset_dim6(back_offset_dim6_), padding_val(padding_val_), padding_mode(padding_mode_) { set_scope(scope_); }
Padding(::megdnn::param::Padding packed_param_0): front_offset_dim0(packed_param_0.front_offset_dim0), front_offset_dim1(packed_param_0.front_offset_dim1), front_offset_dim2(packed_param_0.front_offset_dim2), front_offset_dim3(packed_param_0.front_offset_dim3), front_offset_dim4(packed_param_0.front_offset_dim4), front_offset_dim5(packed_param_0.front_offset_dim5), front_offset_dim6(packed_param_0.front_offset_dim6), back_offset_dim0(packed_param_0.back_offset_dim0), back_offset_dim1(packed_param_0.back_offset_dim1), back_offset_dim2(packed_param_0.back_offset_dim2), back_offset_dim3(packed_param_0.back_offset_dim3), back_offset_dim4(packed_param_0.back_offset_dim4), back_offset_dim5(packed_param_0.back_offset_dim5), back_offset_dim6(packed_param_0.back_offset_dim6), padding_val(packed_param_0.padding_val), padding_mode(packed_param_0.padding_mode) {}
::megdnn::param::Padding param() const {
return {front_offset_dim0, front_offset_dim1, front_offset_dim2, front_offset_dim3, front_offset_dim4, front_offset_dim5, front_offset_dim6, back_offset_dim0, back_offset_dim1, back_offset_dim2, back_offset_dim3, back_offset_dim4, back_offset_dim5, back_offset_dim6, padding_val, padding_mode};
}
};
class ParamPackConcat : public OpDefImplBase<ParamPackConcat> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<int32_t> offsets;
ParamPackConcat() = default;
ParamPackConcat(std::vector<int32_t> offsets_, std::string scope_ = {}): offsets(offsets_) { set_scope(scope_); }
};
class ParamPackSplit : public OpDefImplBase<ParamPackSplit> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<int32_t> offsets;
std::vector<std::vector<size_t>> shapes;
ParamPackSplit() = default;
ParamPackSplit(std::vector<int32_t> offsets_, std::vector<std::vector<size_t>> shapes_, std::string scope_ = {}): offsets(offsets_), shapes(shapes_) { set_scope(scope_); }
};
class PermutationRNG : public OpDefImplBase<PermutationRNG> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
uint64_t seed = 0;
::megdnn::DType dtype = megdnn::DType::from_enum(megdnn::DTypeEnum::Int32);
size_t handle;
PermutationRNG() = default;
PermutationRNG(uint64_t seed_, ::megdnn::DType dtype_, size_t handle_, std::string scope_ = {}): seed(seed_), dtype(dtype_), handle(handle_) { set_scope(scope_); }
};
class PixelShuffle : public OpDefImplBase<PixelShuffle> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t factor;
PixelShuffle() = default;
PixelShuffle(int32_t factor_, std::string scope_ = {}): factor(factor_) { set_scope(scope_); }
};
class PixelShuffleBackward : public OpDefImplBase<PixelShuffleBackward> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t factor;
PixelShuffleBackward() = default;
PixelShuffleBackward(int32_t factor_, std::string scope_ = {}): factor(factor_) { set_scope(scope_); }
};
class PoissonRNG : public OpDefImplBase<PoissonRNG> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
uint64_t seed = 0;
size_t handle;
PoissonRNG() = default;
PoissonRNG(uint64_t seed_, size_t handle_, std::string scope_ = {}): seed(seed_), handle(handle_) { set_scope(scope_); }
PoissonRNG(::megdnn::param::PoissonRNG packed_param_0, size_t handle_): seed(packed_param_0.seed), handle(handle_) {}
::megdnn::param::PoissonRNG param() const {
return {seed};
}
};
class Pooling : public OpDefImplBase<Pooling> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::Pooling::Mode;
using Format = ::megdnn::param::Pooling::Format;
using Strategy = ::megdnn::param::ExecutionPolicy::Strategy;
Mode mode = ::megdnn::param::Pooling::Mode::MAX;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_h = 2;
uint32_t stride_w = 2;
uint32_t window_h = 2;
uint32_t window_w = 2;
Format format = ::megdnn::param::Pooling::Format::NCHW;
Strategy strategy = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1);
uint64_t workspace_limit = 18446744073709551615ull;
Pooling() = default;
Pooling(Mode mode_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t window_h_, uint32_t window_w_, Format format_, Strategy strategy_, uint64_t workspace_limit_, std::string scope_ = {}): mode(mode_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), window_h(window_h_), window_w(window_w_), format(format_), strategy(strategy_), workspace_limit(workspace_limit_) {
set_scope(scope_);
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
Pooling(::megdnn::param::Pooling packed_param_0, ::megdnn::param::ExecutionPolicy packed_param_1): mode(packed_param_0.mode), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), window_h(packed_param_0.window_h), window_w(packed_param_0.window_w), format(packed_param_0.format), strategy(packed_param_1.strategy), workspace_limit(packed_param_1.workspace_limit) {
mgb_assert(static_cast<uint32_t>(strategy) <= uint32_t(8));
}
::megdnn::param::Pooling param() const {
return {mode, pad_h, pad_w, stride_h, stride_w, window_h, window_w, format};
}
::megdnn::param::ExecutionPolicy policy() const {
return {strategy, workspace_limit};
}
};
class RNN : public OpDefImplBase<RNN> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using NonlineMode = ::megdnn::param::RNN::NonlineMode;
using FwdMode = ::megdnn::param::RNN::FwdMode;
uint32_t num_layers = 1;
bool bidirectional = false;
bool bias = true;
uint32_t hidden_size = 128;
float dropout = 0.f;
NonlineMode nonlineMode = ::megdnn::param::RNN::NonlineMode::IDENTITY;
FwdMode fwd_mode = ::megdnn::param::RNN::FwdMode::TRAINING;
RNN() = default;
RNN(uint32_t num_layers_, bool bidirectional_, bool bias_, uint32_t hidden_size_, float dropout_, NonlineMode nonlineMode_, FwdMode fwd_mode_, std::string scope_ = {}): num_layers(num_layers_), bidirectional(bidirectional_), bias(bias_), hidden_size(hidden_size_), dropout(dropout_), nonlineMode(nonlineMode_), fwd_mode(fwd_mode_) { set_scope(scope_); }
RNN(::megdnn::param::RNN packed_param_0): num_layers(packed_param_0.num_layers), bidirectional(packed_param_0.bidirectional), bias(packed_param_0.bias), hidden_size(packed_param_0.hidden_size), dropout(packed_param_0.dropout), nonlineMode(packed_param_0.nonlineMode), fwd_mode(packed_param_0.fwd_mode) {}
::megdnn::param::RNN param() const {
return {num_layers, bidirectional, bias, hidden_size, dropout, nonlineMode, fwd_mode};
}
};
class RNNCell : public OpDefImplBase<RNNCell> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using NonlineMode = ::megdnn::param::RNNCell::NonlineMode;
NonlineMode nonlineMode = ::megdnn::param::RNNCell::NonlineMode::IDENTITY;
RNNCell() = default;
RNNCell(NonlineMode nonlineMode_, std::string scope_ = {}): nonlineMode(nonlineMode_) { set_scope(scope_); }
RNNCell(::megdnn::param::RNNCell packed_param_0): nonlineMode(packed_param_0.nonlineMode) {}
::megdnn::param::RNNCell param() const {
return {nonlineMode};
}
};
class ROIAlign : public OpDefImplBase<ROIAlign> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::ROIAlign::Mode;
using Format = ::megdnn::param::ROIAlign::Format;
Mode mode = ::megdnn::param::ROIAlign::Mode::MAX;
Format format = ::megdnn::param::ROIAlign::Format::NCHW;
float spatial_scale = 1.0f;
float offset = 0.0f;
uint32_t pooled_height = 1;
uint32_t pooled_width = 1;
uint32_t sample_height = 2;
uint32_t sample_width = 2;
ROIAlign() = default;
ROIAlign(Mode mode_, Format format_, float spatial_scale_, float offset_, uint32_t pooled_height_, uint32_t pooled_width_, uint32_t sample_height_, uint32_t sample_width_, std::string scope_ = {}): mode(mode_), format(format_), spatial_scale(spatial_scale_), offset(offset_), pooled_height(pooled_height_), pooled_width(pooled_width_), sample_height(sample_height_), sample_width(sample_width_) { set_scope(scope_); }
ROIAlign(::megdnn::param::ROIAlign packed_param_0): mode(packed_param_0.mode), format(packed_param_0.format), spatial_scale(packed_param_0.spatial_scale), offset(packed_param_0.offset), pooled_height(packed_param_0.pooled_height), pooled_width(packed_param_0.pooled_width), sample_height(packed_param_0.sample_height), sample_width(packed_param_0.sample_width) {}
::megdnn::param::ROIAlign param() const {
return {mode, format, spatial_scale, offset, pooled_height, pooled_width, sample_height, sample_width};
}
};
class ROIPooling : public OpDefImplBase<ROIPooling> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::ROIPooling::Mode;
Mode mode = ::megdnn::param::ROIPooling::Mode::MAX;
float scale = 1.f;
ROIPooling() = default;
ROIPooling(Mode mode_, float scale_, std::string scope_ = {}): mode(mode_), scale(scale_) { set_scope(scope_); }
ROIPooling(::megdnn::param::ROIPooling packed_param_0): mode(packed_param_0.mode), scale(packed_param_0.scale) {}
::megdnn::param::ROIPooling param() const {
return {mode, scale};
}
};
class Reduce : public OpDefImplBase<Reduce> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::Reduce::Mode;
using DataType = ::megdnn::param::Reduce::DataType;
Mode mode = ::megdnn::param::Reduce::Mode::SUM;
int32_t axis = 2147483647;
DataType data_type = ::megdnn::param::Reduce::DataType::DEFAULT;
bool keepdim = true;
Reduce() = default;
Reduce(Mode mode_, int32_t axis_, DataType data_type_, bool keepdim_, std::string scope_ = {}): mode(mode_), axis(axis_), data_type(data_type_), keepdim(keepdim_) { set_scope(scope_); }
Reduce(::megdnn::param::Reduce packed_param_0, bool keepdim_): mode(packed_param_0.mode), axis(packed_param_0.axis), data_type(packed_param_0.data_type), keepdim(keepdim_) {}
::megdnn::param::Reduce param() const {
return {mode, axis, data_type};
}
};
class Remap : public OpDefImplBase<Remap> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using InterpolationMode = ::megdnn::param::Remap::InterpolationMode;
using BorderMode = ::megdnn::param::Remap::BorderMode;
using Format = ::megdnn::param::Remap::Format;
InterpolationMode imode = ::megdnn::param::Remap::InterpolationMode::LINEAR;
BorderMode border_type = ::megdnn::param::Remap::BorderMode::REPLICATE;
Format format = ::megdnn::param::Remap::Format::NHWC;
float scalar = 0.f;
Remap() = default;
Remap(InterpolationMode imode_, BorderMode border_type_, Format format_, float scalar_, std::string scope_ = {}): imode(imode_), border_type(border_type_), format(format_), scalar(scalar_) { set_scope(scope_); }
Remap(::megdnn::param::Remap packed_param_0): imode(packed_param_0.imode), border_type(packed_param_0.border_type), format(packed_param_0.format), scalar(packed_param_0.scalar) {}
::megdnn::param::Remap param() const {
return {imode, border_type, format, scalar};
}
};
class RemoteRecv : public OpDefImplBase<RemoteRecv> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::string key;
std::string addr;
uint32_t port;
uint32_t rank_from;
::mgb::CompNode cn;
std::vector<int32_t> shape;
::megdnn::DType dtype;
std::string backend;
RemoteRecv() = default;
RemoteRecv(std::string key_, std::string addr_, uint32_t port_, uint32_t rank_from_, ::mgb::CompNode cn_, std::vector<int32_t> shape_, ::megdnn::DType dtype_, std::string backend_, std::string scope_ = {}): key(key_), addr(addr_), port(port_), rank_from(rank_from_), cn(cn_), shape(shape_), dtype(dtype_), backend(backend_) { set_scope(scope_); }
};
class RemoteSend : public OpDefImplBase<RemoteSend> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::string key;
std::string addr;
uint32_t port;
uint32_t rank_to;
std::string backend;
RemoteSend() = default;
RemoteSend(std::string key_, std::string addr_, uint32_t port_, uint32_t rank_to_, std::string backend_, std::string scope_ = {}): key(key_), addr(addr_), port(port_), rank_to(rank_to_), backend(backend_) { set_scope(scope_); }
};
class RemoveAxis : public OpDefImplBase<RemoveAxis> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<int32_t> axis;
RemoveAxis() = default;
RemoveAxis(std::vector<int32_t> axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
};
class Reshape : public OpDefImplBase<Reshape> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t axis = ::megdnn::param::OptionalAxisV1::INVALID_AXIS;
std::vector<int32_t> shape;
Reshape() = default;
Reshape(int32_t axis_, std::vector<int32_t> shape_, std::string scope_ = {}): axis(axis_), shape(shape_) { set_scope(scope_); }
Reshape(::megdnn::param::OptionalAxisV1 packed_param_0, std::vector<int32_t> shape_): axis(packed_param_0.axis), shape(shape_) {}
::megdnn::param::OptionalAxisV1 param() const {
return {axis};
}
};
class Resize : public OpDefImplBase<Resize> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using InterpolationMode = ::megdnn::param::Resize::InterpolationMode;
using Format = ::megdnn::param::Resize::Format;
InterpolationMode imode = ::megdnn::param::Resize::InterpolationMode::LINEAR;
Format format = ::megdnn::param::Resize::Format::NHWC;
Resize() = default;
Resize(InterpolationMode imode_, Format format_, std::string scope_ = {}): imode(imode_), format(format_) { set_scope(scope_); }
Resize(::megdnn::param::Resize packed_param_0): imode(packed_param_0.imode), format(packed_param_0.format) {}
::megdnn::param::Resize param() const {
return {imode, format};
}
};
class SVD : public OpDefImplBase<SVD> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
bool full_matrices = false;
bool compute_uv = true;
SVD() = default;
SVD(bool full_matrices_, bool compute_uv_, std::string scope_ = {}): full_matrices(full_matrices_), compute_uv(compute_uv_) { set_scope(scope_); }
SVD(::megdnn::param::SVD packed_param_0): full_matrices(packed_param_0.full_matrices), compute_uv(packed_param_0.compute_uv) {}
::megdnn::param::SVD param() const {
return {full_matrices, compute_uv};
}
};
class SetMeshIndexing : public OpDefImplBase<SetMeshIndexing> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
SetMeshIndexing() = default;
SetMeshIndexing(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
class SetSubtensor : public OpDefImplBase<SetSubtensor> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
SetSubtensor() = default;
SetSubtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
class ShuffleRNG : public OpDefImplBase<ShuffleRNG> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
uint64_t seed = 0;
size_t handle;
ShuffleRNG() = default;
ShuffleRNG(uint64_t seed_, size_t handle_, std::string scope_ = {}): seed(seed_), handle(handle_) { set_scope(scope_); }
ShuffleRNG(::megdnn::param::ShuffleRNG packed_param_0, size_t handle_): seed(packed_param_0.seed), handle(handle_) {}
::megdnn::param::ShuffleRNG param() const {
return {seed};
}
};
class SlidingWindowTranspose : public OpDefImplBase<SlidingWindowTranspose> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
uint32_t out_h = 0;
uint32_t out_w = 0;
uint32_t pad_h = 0;
uint32_t pad_w = 0;
uint32_t stride_h = 1;
uint32_t stride_w = 1;
uint32_t dilate_h = 1;
uint32_t dilate_w = 1;
uint32_t window_h = 3;
uint32_t window_w = 3;
SlidingWindowTranspose() = default;
SlidingWindowTranspose(uint32_t out_h_, uint32_t out_w_, uint32_t pad_h_, uint32_t pad_w_, uint32_t stride_h_, uint32_t stride_w_, uint32_t dilate_h_, uint32_t dilate_w_, uint32_t window_h_, uint32_t window_w_, std::string scope_ = {}): out_h(out_h_), out_w(out_w_), pad_h(pad_h_), pad_w(pad_w_), stride_h(stride_h_), stride_w(stride_w_), dilate_h(dilate_h_), dilate_w(dilate_w_), window_h(window_h_), window_w(window_w_) { set_scope(scope_); }
SlidingWindowTranspose(::megdnn::param::SlidingWindowTranspose packed_param_0): out_h(packed_param_0.out_h), out_w(packed_param_0.out_w), pad_h(packed_param_0.pad_h), pad_w(packed_param_0.pad_w), stride_h(packed_param_0.stride_h), stride_w(packed_param_0.stride_w), dilate_h(packed_param_0.dilate_h), dilate_w(packed_param_0.dilate_w), window_h(packed_param_0.window_h), window_w(packed_param_0.window_w) {}
::megdnn::param::SlidingWindowTranspose param() const {
return {out_h, out_w, pad_h, pad_w, stride_h, stride_w, dilate_h, dilate_w, window_h, window_w};
}
};
class Softmax : public OpDefImplBase<Softmax> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t axis = -1;
Softmax() = default;
Softmax(int32_t axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
Softmax(::megdnn::param::Softmax packed_param_0): axis(packed_param_0.axis) {}
::megdnn::param::Softmax param() const {
return {axis};
}
};
class Split : public OpDefImplBase<Split> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t axis;
int32_t nsections;
Split() = default;
Split(int32_t axis_, int32_t nsections_, std::string scope_ = {}): axis(axis_), nsections(nsections_) { set_scope(scope_); }
Split(::megdnn::param::Empty, int32_t axis_, int32_t nsections_): axis(axis_), nsections(nsections_) {}
::megdnn::param::Empty param() const {
return {};
}
};
class Subtensor : public OpDefImplBase<Subtensor> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items;
Subtensor() = default;
Subtensor(std::vector<std::tuple<int8_t, bool, bool, bool, bool>> items_, std::string scope_ = {}): items(items_) { set_scope(scope_); }
};
class TQT : public OpDefImplBase<TQT> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
int32_t qmin = -2147483648;
int32_t qmax = 2147483647;
TQT() = default;
TQT(int32_t qmin_, int32_t qmax_, std::string scope_ = {}): qmin(qmin_), qmax(qmax_) { set_scope(scope_); }
TQT(::megdnn::param::TQT packed_param_0): qmin(packed_param_0.qmin), qmax(packed_param_0.qmax) {}
::megdnn::param::TQT param() const {
return {qmin, qmax};
}
};
class TensorRTRuntime : public OpDefImplBase<TensorRTRuntime> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
std::string buf;
size_t buf_size;
TensorRTRuntime() = default;
TensorRTRuntime(std::string buf_, size_t buf_size_, std::string scope_ = {}): buf(buf_), buf_size(buf_size_) { set_scope(scope_); }
};
class TopK : public OpDefImplBase<TopK> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Mode = ::megdnn::param::TopK::Mode;
Mode mode = ::megdnn::param::TopK::Mode::KTH_ONLY;
TopK() = default;
TopK(Mode mode_, std::string scope_ = {}): mode(mode_) { set_scope(scope_); }
TopK(::megdnn::param::TopK packed_param_0): mode(packed_param_0.mode) {}
::megdnn::param::TopK param() const {
return {mode};
}
};
class TypeCvt : public OpDefImplBase<TypeCvt> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
::megdnn::DType dtype;
TypeCvt() = default;
TypeCvt(::megdnn::DType dtype_, std::string scope_ = {}): dtype(dtype_) { set_scope(scope_); }
};
class UniformRNG : public OpDefImplBase<UniformRNG> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
uint64_t seed = 0;
::megdnn::DType dtype = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32);
size_t handle;
UniformRNG() = default;
UniformRNG(uint64_t seed_, ::megdnn::DType dtype_, size_t handle_, std::string scope_ = {}): seed(seed_), dtype(dtype_), handle(handle_) { set_scope(scope_); }
};
class WarpAffine : public OpDefImplBase<WarpAffine> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using InterpolationMode = ::megdnn::param::WarpAffine::InterpolationMode;
using BorderMode = ::megdnn::param::WarpAffine::BorderMode;
using Format = ::megdnn::param::WarpAffine::Format;
InterpolationMode imode = ::megdnn::param::WarpAffine::InterpolationMode::LINEAR;
BorderMode border_mode = ::megdnn::param::WarpAffine::BorderMode::REPLICATE;
float border_val = .0f;
Format format = ::megdnn::param::WarpAffine::Format::NHWC;
WarpAffine() = default;
WarpAffine(InterpolationMode imode_, BorderMode border_mode_, float border_val_, Format format_, std::string scope_ = {}): imode(imode_), border_mode(border_mode_), border_val(border_val_), format(format_) { set_scope(scope_); }
WarpAffine(::megdnn::param::WarpAffine packed_param_0): imode(packed_param_0.imode), border_mode(packed_param_0.border_mode), border_val(packed_param_0.border_val), format(packed_param_0.format) {}
::megdnn::param::WarpAffine param() const {
return {imode, border_mode, border_val, format};
}
};
class WarpPerspective : public OpDefImplBase<WarpPerspective> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using InterpolationMode = ::megdnn::param::WarpPerspective::InterpolationMode;
using BorderMode = ::megdnn::param::WarpPerspective::BorderMode;
using Format = ::megdnn::param::WarpPerspective::Format;
InterpolationMode imode = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR;
BorderMode bmode = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE;
Format format = ::megdnn::param::WarpPerspective::Format::NCHW;
float border_val = .0f;
WarpPerspective() = default;
WarpPerspective(InterpolationMode imode_, BorderMode bmode_, Format format_, float border_val_, std::string scope_ = {}): imode(imode_), bmode(bmode_), format(format_), border_val(border_val_) { set_scope(scope_); }
WarpPerspective(::megdnn::param::WarpPerspective packed_param_0): imode(packed_param_0.imode), bmode(packed_param_0.bmode), format(packed_param_0.format), border_val(packed_param_0.border_val) {}
::megdnn::param::WarpPerspective param() const {
return {imode, bmode, format, border_val};
}
};
// clang-format on
// clang-format off
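// Python bindings generated from the same ops.td: each op is registered as a
// py::class_ deriving from OpDef, and each nested enum gets a py::enum_ plus
// a string constructor, so enum members can also be spelled as strings from
// Python (normalize_enum canonicalizes the spelling first).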
py::class_<AdaptivePooling, std::shared_ptr<AdaptivePooling>, OpDef> AdaptivePoolingInst(m, "AdaptivePooling");
py::enum_<AdaptivePooling::Mode>(AdaptivePoolingInst, "Mode")
.value("MAX", AdaptivePooling::Mode::MAX)
.value("AVERAGE", AdaptivePooling::Mode::AVERAGE)
.value("AVERAGE_COUNT_EXCLUDE_PADDING", AdaptivePooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "MAX") return AdaptivePooling::Mode::MAX;
if (str == "AVERAGE") return AdaptivePooling::Mode::AVERAGE;
if (str == "AVERAGE_COUNT_EXCLUDE_PADDING") return AdaptivePooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, AdaptivePooling::Mode>();
py::enum_<AdaptivePooling::Format>(AdaptivePoolingInst, "Format")
.value("NCHW", AdaptivePooling::Format::NCHW)
.value("NHWC", AdaptivePooling::Format::NHWC)
.value("NHWCD4", AdaptivePooling::Format::NHWCD4)
.value("NCHW4", AdaptivePooling::Format::NCHW4)
.value("NCHW8", AdaptivePooling::Format::NCHW8)
.value("NCHW32", AdaptivePooling::Format::NCHW32)
.value("NCHW88", AdaptivePooling::Format::NCHW88)
.value("NCHW44", AdaptivePooling::Format::NCHW44)
.value("NCHW44_DOT", AdaptivePooling::Format::NCHW44_DOT)
.value("NCHW4_NCHW32", AdaptivePooling::Format::NCHW4_NCHW32)
.value("NCHW32_NCHW4", AdaptivePooling::Format::NCHW32_NCHW4)
.value("NCHW4_NCHW", AdaptivePooling::Format::NCHW4_NCHW)
.value("NHWC_NCHW", AdaptivePooling::Format::NHWC_NCHW)
.value("NHWC_NCHW4_IC_SMALL", AdaptivePooling::Format::NHWC_NCHW4_IC_SMALL)
.value("NCHW_NCHW4_IC_SMALL", AdaptivePooling::Format::NCHW_NCHW4_IC_SMALL)
.value("CHWN4", AdaptivePooling::Format::CHWN4)
.value("NCHW64", AdaptivePooling::Format::NCHW64)
.value("NCHW4_NHWC", AdaptivePooling::Format::NCHW4_NHWC)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "NCHW") return AdaptivePooling::Format::NCHW;
if (str == "NHWC") return AdaptivePooling::Format::NHWC;
if (str == "NHWCD4") return AdaptivePooling::Format::NHWCD4;
if (str == "NCHW4") return AdaptivePooling::Format::NCHW4;
if (str == "NCHW8") return AdaptivePooling::Format::NCHW8;
if (str == "NCHW32") return AdaptivePooling::Format::NCHW32;
if (str == "NCHW88") return AdaptivePooling::Format::NCHW88;
if (str == "NCHW44") return AdaptivePooling::Format::NCHW44;
if (str == "NCHW44_DOT") return AdaptivePooling::Format::NCHW44_DOT;
if (str == "NCHW4_NCHW32") return AdaptivePooling::Format::NCHW4_NCHW32;
if (str == "NCHW32_NCHW4") return AdaptivePooling::Format::NCHW32_NCHW4;
if (str == "NCHW4_NCHW") return AdaptivePooling::Format::NCHW4_NCHW;
if (str == "NHWC_NCHW") return AdaptivePooling::Format::NHWC_NCHW;
if (str == "NHWC_NCHW4_IC_SMALL") return AdaptivePooling::Format::NHWC_NCHW4_IC_SMALL;
if (str == "NCHW_NCHW4_IC_SMALL") return AdaptivePooling::Format::NCHW_NCHW4_IC_SMALL;
if (str == "CHWN4") return AdaptivePooling::Format::CHWN4;
if (str == "NCHW64") return AdaptivePooling::Format::NCHW64;
if (str == "NCHW4_NHWC") return AdaptivePooling::Format::NCHW4_NHWC;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, AdaptivePooling::Format>();
AdaptivePoolingInst
.def(py::init<::megdnn::param::AdaptivePooling::Mode, ::megdnn::param::AdaptivePooling::Format, std::vector<int32_t>, std::string>(), py::arg("mode") = ::megdnn::param::AdaptivePooling::Mode::MAX, py::arg("format") = ::megdnn::param::AdaptivePooling::Format::NCHW, py::arg("shape"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("mode", &AdaptivePooling::mode)
.def_readwrite("format", &AdaptivePooling::format)
.def_readwrite("shape", &AdaptivePooling::shape);
py::class_<AddAxis, std::shared_ptr<AddAxis>, OpDef> AddAxisInst(m, "AddAxis");
AddAxisInst
.def(py::init<std::vector<int32_t>, std::string>(), py::arg("axis"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("axis", &AddAxis::axis);
py::class_<Argmax, std::shared_ptr<Argmax>, OpDef> ArgmaxInst(m, "Argmax");
ArgmaxInst
.def(py::init<int32_t, std::string>(), py::arg("axis") = 0, py::arg("scope") = {})
.def_readwrite("axis", &Argmax::axis);
py::class_<Argmin, std::shared_ptr<Argmin>, OpDef> ArgminInst(m, "Argmin");
ArgminInst
.def(py::init<int32_t, std::string>(), py::arg("axis") = 0, py::arg("scope") = {})
.def_readwrite("axis", &Argmin::axis);
py::class_<Argsort, std::shared_ptr<Argsort>, OpDef> ArgsortInst(m, "Argsort");
py::enum_<Argsort::Order>(ArgsortInst, "Order")
.value("ASCENDING", Argsort::Order::ASCENDING)
.value("DESCENDING", Argsort::Order::DESCENDING)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "ASCENDING") return Argsort::Order::ASCENDING;
if (str == "DESCENDING") return Argsort::Order::DESCENDING;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, Argsort::Order>();
ArgsortInst
.def(py::init<::megdnn::param::Argsort::Order, std::string>(), py::arg("order") = ::megdnn::param::Argsort::Order::ASCENDING, py::arg("scope") = {})
.def_readwrite("order", &Argsort::order);
py::class_<AssertEqual, std::shared_ptr<AssertEqual>, OpDef> AssertEqualInst(m, "AssertEqual");
AssertEqualInst
.def(py::init<float, bool, std::string>(), py::arg("maxerr") = 0.0001, py::arg("verbose") = false, py::arg("scope") = {})
.def_readwrite("maxerr", &AssertEqual::maxerr)
.def_readwrite("verbose", &AssertEqual::verbose);
py::class_<AtlasRuntime, std::shared_ptr<AtlasRuntime>, OpDef> AtlasRuntimeInst(m, "AtlasRuntime");
AtlasRuntimeInst
.def(py::init<std::string, size_t, std::string>(), py::arg("buf"), py::arg("buf_size"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("buf", &AtlasRuntime::buf)
.def_readwrite("buf_size", &AtlasRuntime::buf_size);
py::class_<Barrier, std::shared_ptr<Barrier>, OpDef> BarrierInst(m, "Barrier");
BarrierInst
.def(py::init<::mgb::CompNode, uint32_t, std::string>(), py::arg("comp_node"), py::arg("nr_outputs"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("comp_node", &Barrier::comp_node)
.def_readwrite("nr_outputs", &Barrier::nr_outputs);
py::class_<BatchConvBias, std::shared_ptr<BatchConvBias>, OpDef> BatchConvBiasInst(m, "BatchConvBias");
py::enum_<BatchConvBias::NonlineMode>(BatchConvBiasInst, "NonlineMode")
.value("IDENTITY", BatchConvBias::NonlineMode::IDENTITY)
.value("RELU", BatchConvBias::NonlineMode::RELU)
.value("SIGMOID", BatchConvBias::NonlineMode::SIGMOID)
.value("H_SWISH", BatchConvBias::NonlineMode::H_SWISH)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "IDENTITY") return BatchConvBias::NonlineMode::IDENTITY;
if (str == "RELU") return BatchConvBias::NonlineMode::RELU;
if (str == "SIGMOID") return BatchConvBias::NonlineMode::SIGMOID;
if (str == "H_SWISH") return BatchConvBias::NonlineMode::H_SWISH;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, BatchConvBias::NonlineMode>();
py::enum_<BatchConvBias::Mode>(BatchConvBiasInst, "Mode")
.value("CROSS_CORRELATION", BatchConvBias::Mode::CROSS_CORRELATION)
.value("CONVOLUTION", BatchConvBias::Mode::CONVOLUTION)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "CROSS_CORRELATION") return BatchConvBias::Mode::CROSS_CORRELATION;
if (str == "CONVOLUTION") return BatchConvBias::Mode::CONVOLUTION;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, BatchConvBias::Mode>();
py::enum_<BatchConvBias::Sparse>(BatchConvBiasInst, "Sparse")
.value("DENSE", BatchConvBias::Sparse::DENSE)
.value("GROUP", BatchConvBias::Sparse::GROUP)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "DENSE") return BatchConvBias::Sparse::DENSE;
if (str == "GROUP") return BatchConvBias::Sparse::GROUP;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, BatchConvBias::Sparse>();
BatchConvBiasInst.attr("Format") = AdaptivePoolingInst.attr("Format");
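// Enum types shared between ops are aliased onto the one Python object first
// registered (here AdaptivePooling's Format) rather than re-registered, so
// values compare equal across every op that uses the same megdnn enum.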
py::enum_<BatchConvBias::ComputeMode>(BatchConvBiasInst, "ComputeMode")
.value("DEFAULT", BatchConvBias::ComputeMode::DEFAULT)
.value("FLOAT32", BatchConvBias::ComputeMode::FLOAT32)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "DEFAULT") return BatchConvBias::ComputeMode::DEFAULT;
if (str == "FLOAT32") return BatchConvBias::ComputeMode::FLOAT32;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, BatchConvBias::ComputeMode>();
py::enum_<BatchConvBias::Strategy>(BatchConvBiasInst, "Strategy")
.value("HEURISTIC", BatchConvBias::Strategy::HEURISTIC)
.value("PROFILE", BatchConvBias::Strategy::PROFILE)
.value("REPRODUCIBLE", BatchConvBias::Strategy::REPRODUCIBLE)
.value("OPTIMIZED", BatchConvBias::Strategy::OPTIMIZED)
.def("__or__", [](BatchConvBias::Strategy s0, BatchConvBias::Strategy s1) {
return static_cast<BatchConvBias::Strategy>(uint32_t(s0) | uint32_t(s1));
})
.def("__and__", [](BatchConvBias::Strategy s0, BatchConvBias::Strategy s1) {
return static_cast<BatchConvBias::Strategy>(uint32_t(s0) & uint32_t(s1));
})
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "HEURISTIC") return BatchConvBias::Strategy::HEURISTIC;
if (str == "PROFILE") return BatchConvBias::Strategy::PROFILE;
if (str == "REPRODUCIBLE") return BatchConvBias::Strategy::REPRODUCIBLE;
if (str == "OPTIMIZED") return BatchConvBias::Strategy::OPTIMIZED;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, BatchConvBias::Strategy>();
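// Strategy is a bit-flag enum, hence the __or__/__and__ overloads above;
// e.g. combining profiling with reproducibility from Python might be written
// as:
//     strategy = BatchConvBias.Strategy.PROFILE | BatchConvBias.Strategy.REPRODUCIBLE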
BatchConvBiasInst
.def(py::init<::megdnn::param::BatchConvBias::NonlineMode, ::megdnn::param::BatchConvBias::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::BatchConvBias::Sparse, ::megdnn::param::BatchConvBias::Format, ::megdnn::param::BatchConvBias::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, ::megdnn::DType, std::string>(), py::arg("nonlineMode") = ::megdnn::param::BatchConvBias::NonlineMode::IDENTITY, py::arg("mode") = ::megdnn::param::BatchConvBias::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::BatchConvBias::Sparse::DENSE, py::arg("format") = ::megdnn::param::BatchConvBias::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::BatchConvBias::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dtype"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("nonlineMode", &BatchConvBias::nonlineMode)
.def_readwrite("mode", &BatchConvBias::mode)
.def_readwrite("pad_h", &BatchConvBias::pad_h)
.def_readwrite("pad_w", &BatchConvBias::pad_w)
.def_readwrite("stride_h", &BatchConvBias::stride_h)
.def_readwrite("stride_w", &BatchConvBias::stride_w)
.def_readwrite("dilate_h", &BatchConvBias::dilate_h)
.def_readwrite("dilate_w", &BatchConvBias::dilate_w)
.def_readwrite("sparse", &BatchConvBias::sparse)
.def_readwrite("format", &BatchConvBias::format)
.def_readwrite("compute_mode", &BatchConvBias::compute_mode)
.def_readwrite("strategy", &BatchConvBias::strategy)
.def_readwrite("workspace_limit", &BatchConvBias::workspace_limit)
.def_readwrite("dtype", &BatchConvBias::dtype);
py::class_<BatchNorm, std::shared_ptr<BatchNorm>, OpDef> BatchNormInst(m, "BatchNorm");
py::enum_<BatchNorm::ParamDim>(BatchNormInst, "ParamDim")
.value("DIM_11HW", BatchNorm::ParamDim::DIM_11HW)
.value("DIM_1CHW", BatchNorm::ParamDim::DIM_1CHW)
.value("DIM_1C11", BatchNorm::ParamDim::DIM_1C11)
.value("DIM_111C", BatchNorm::ParamDim::DIM_111C)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "DIM_11HW") return BatchNorm::ParamDim::DIM_11HW;
if (str == "DIM_1CHW") return BatchNorm::ParamDim::DIM_1CHW;
if (str == "DIM_1C11") return BatchNorm::ParamDim::DIM_1C11;
if (str == "DIM_111C") return BatchNorm::ParamDim::DIM_111C;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, BatchNorm::ParamDim>();
py::enum_<BatchNorm::FwdMode>(BatchNormInst, "FwdMode")
.value("TRAINING", BatchNorm::FwdMode::TRAINING)
.value("INFERENCE", BatchNorm::FwdMode::INFERENCE)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "TRAINING") return BatchNorm::FwdMode::TRAINING;
if (str == "INFERENCE") return BatchNorm::FwdMode::INFERENCE;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, BatchNorm::FwdMode>();
BatchNormInst
.def(py::init<::megdnn::param::BN::ParamDim, ::megdnn::param::BN::FwdMode, double, double, float, float, std::string>(), py::arg("param_dim") = ::megdnn::param::BN::ParamDim::DIM_11HW, py::arg("fwd_mode") = ::megdnn::param::BN::FwdMode::TRAINING, py::arg("epsilon") = 1e-4f, py::arg("avg_factor") = 1.f, py::arg("scale") = 1.f, py::arg("bias") = 0.f, py::arg("scope") = {})
.def_readwrite("param_dim", &BatchNorm::param_dim)
.def_readwrite("fwd_mode", &BatchNorm::fwd_mode)
.def_readwrite("epsilon", &BatchNorm::epsilon)
.def_readwrite("avg_factor", &BatchNorm::avg_factor)
.def_readwrite("scale", &BatchNorm::scale)
.def_readwrite("bias", &BatchNorm::bias);
py::class_<BatchNormBackward, std::shared_ptr<BatchNormBackward>, OpDef> BatchNormBackwardInst(m, "BatchNormBackward");
BatchNormBackwardInst.attr("ParamDim") = BatchNormInst.attr("ParamDim");
BatchNormBackwardInst.attr("FwdMode") = BatchNormInst.attr("FwdMode");
BatchNormBackwardInst
.def(py::init<::megdnn::param::BN::ParamDim, ::megdnn::param::BN::FwdMode, double, double, float, float, std::string>(), py::arg("param_dim") = ::megdnn::param::BN::ParamDim::DIM_11HW, py::arg("fwd_mode") = ::megdnn::param::BN::FwdMode::TRAINING, py::arg("epsilon") = 1e-4f, py::arg("avg_factor") = 1.f, py::arg("scale") = 1.f, py::arg("bias") = 0.f, py::arg("scope") = {})
.def_readwrite("param_dim", &BatchNormBackward::param_dim)
.def_readwrite("fwd_mode", &BatchNormBackward::fwd_mode)
.def_readwrite("epsilon", &BatchNormBackward::epsilon)
.def_readwrite("avg_factor", &BatchNormBackward::avg_factor)
.def_readwrite("scale", &BatchNormBackward::scale)
.def_readwrite("bias", &BatchNormBackward::bias);
py::class_<BatchedIncrMeshIndexing, std::shared_ptr<BatchedIncrMeshIndexing>, OpDef> BatchedIncrMeshIndexingInst(m, "BatchedIncrMeshIndexing");
BatchedIncrMeshIndexingInst
.def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("items", &BatchedIncrMeshIndexing::items);
py::class_<BatchedMatrixMul, std::shared_ptr<BatchedMatrixMul>, OpDef> BatchedMatrixMulInst(m, "BatchedMatrixMul");
py::enum_<BatchedMatrixMul::ComputeMode>(BatchedMatrixMulInst, "ComputeMode")
.value("DEFAULT", BatchedMatrixMul::ComputeMode::DEFAULT)
.value("FLOAT32", BatchedMatrixMul::ComputeMode::FLOAT32)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "DEFAULT") return BatchedMatrixMul::ComputeMode::DEFAULT;
if (str == "FLOAT32") return BatchedMatrixMul::ComputeMode::FLOAT32;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, BatchedMatrixMul::ComputeMode>();
py::enum_<BatchedMatrixMul::Format>(BatchedMatrixMulInst, "Format")
.value("DEFAULT", BatchedMatrixMul::Format::DEFAULT)
.value("MK4", BatchedMatrixMul::Format::MK4)
.value("MK8", BatchedMatrixMul::Format::MK8)
.value("MK4_DOT", BatchedMatrixMul::Format::MK4_DOT)
.value("N32K4_DOT", BatchedMatrixMul::Format::N32K4_DOT)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "DEFAULT") return BatchedMatrixMul::Format::DEFAULT;
if (str == "MK4") return BatchedMatrixMul::Format::MK4;
if (str == "MK8") return BatchedMatrixMul::Format::MK8;
if (str == "MK4_DOT") return BatchedMatrixMul::Format::MK4_DOT;
if (str == "N32K4_DOT") return BatchedMatrixMul::Format::N32K4_DOT;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, BatchedMatrixMul::Format>();
BatchedMatrixMulInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
BatchedMatrixMulInst
.def(py::init<bool, bool, ::megdnn::param::MatrixMul::ComputeMode, ::megdnn::param::MatrixMul::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, uint32_t, uint32_t, std::string>(), py::arg("transposeA") = false, py::arg("transposeB") = false, py::arg("compute_mode") = ::megdnn::param::MatrixMul::ComputeMode::DEFAULT, py::arg("format") = ::megdnn::param::MatrixMul::Format::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dimA"), py::arg("dimB"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("transposeA", &BatchedMatrixMul::transposeA)
.def_readwrite("transposeB", &BatchedMatrixMul::transposeB)
.def_readwrite("compute_mode", &BatchedMatrixMul::compute_mode)
.def_readwrite("format", &BatchedMatrixMul::format)
.def_readwrite("strategy", &BatchedMatrixMul::strategy)
.def_readwrite("workspace_limit", &BatchedMatrixMul::workspace_limit)
.def_readwrite("dimA", &BatchedMatrixMul::dimA)
.def_readwrite("dimB", &BatchedMatrixMul::dimB);
py::class_<BatchedMeshIndexing, std::shared_ptr<BatchedMeshIndexing>, OpDef> BatchedMeshIndexingInst(m, "BatchedMeshIndexing");
BatchedMeshIndexingInst
.def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("items", &BatchedMeshIndexing::items);
py::class_<BatchedSetMeshIndexing, std::shared_ptr<BatchedSetMeshIndexing>, OpDef> BatchedSetMeshIndexingInst(m, "BatchedSetMeshIndexing");
BatchedSetMeshIndexingInst
.def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("items", &BatchedSetMeshIndexing::items);
py::class_<BetaRNG, std::shared_ptr<BetaRNG>, OpDef> BetaRNGInst(m, "BetaRNG");
BetaRNGInst
.def(py::init<uint64_t, size_t, std::string>(), py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("seed", &BetaRNG::seed)
.def_readwrite("handle", &BetaRNG::handle);
py::class_<Borrow, std::shared_ptr<Borrow>, OpDef> BorrowInst(m, "Borrow");
BorrowInst
.def(py::init<::mgb::CompNode, std::string>(), py::arg("comp_node"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("comp_node", &Borrow::comp_node);
py::class_<Broadcast, std::shared_ptr<Broadcast>, OpDef> BroadcastInst(m, "Broadcast");
BroadcastInst
.def(py::init<std::vector<int32_t>, std::string>(), py::arg("shape"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("shape", &Broadcast::shape);
py::class_<CambriconRuntime, std::shared_ptr<CambriconRuntime>, OpDef> CambriconRuntimeInst(m, "CambriconRuntime");
CambriconRuntimeInst
.def(py::init<std::string, size_t, std::string, bool, std::string>(), py::arg("buf"), py::arg("buf_size"), py::arg("symbol"), py::arg("tensor_dim_mutable"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("buf", &CambriconRuntime::buf)
.def_readwrite("buf_size", &CambriconRuntime::buf_size)
.def_readwrite("symbol", &CambriconRuntime::symbol)
.def_readwrite("tensor_dim_mutable", &CambriconRuntime::tensor_dim_mutable);
py::class_<CheckNonFinite, std::shared_ptr<CheckNonFinite>, OpDef> CheckNonFiniteInst(m, "CheckNonFinite");
CheckNonFiniteInst
.def(py::init<float, std::string>(), py::arg("scale") = 1.0, py::arg("scope") = {})
.def_readwrite("scale", &CheckNonFinite::scale);
py::class_<CollectiveComm, std::shared_ptr<CollectiveComm>, OpDef> CollectiveCommInst(m, "CollectiveComm");
py::enum_<CollectiveComm::Mode>(CollectiveCommInst, "Mode")
.value("REDUCE_SUM", CollectiveComm::Mode::REDUCE_SUM)
.value("BROADCAST", CollectiveComm::Mode::BROADCAST)
.value("ALL_GATHER", CollectiveComm::Mode::ALL_GATHER)
.value("REDUCE_SCATTER_SUM", CollectiveComm::Mode::REDUCE_SCATTER_SUM)
.value("ALL_REDUCE_SUM", CollectiveComm::Mode::ALL_REDUCE_SUM)
.value("ALL_REDUCE_MAX", CollectiveComm::Mode::ALL_REDUCE_MAX)
.value("ALL_REDUCE_MIN", CollectiveComm::Mode::ALL_REDUCE_MIN)
.value("ALL_REDUCE_PROD", CollectiveComm::Mode::ALL_REDUCE_PROD)
.value("GATHER", CollectiveComm::Mode::GATHER)
.value("SCATTER", CollectiveComm::Mode::SCATTER)
.value("ALL_TO_ALL", CollectiveComm::Mode::ALL_TO_ALL)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "REDUCE_SUM") return CollectiveComm::Mode::REDUCE_SUM;
if (str == "BROADCAST") return CollectiveComm::Mode::BROADCAST;
if (str == "ALL_GATHER") return CollectiveComm::Mode::ALL_GATHER;
if (str == "REDUCE_SCATTER_SUM") return CollectiveComm::Mode::REDUCE_SCATTER_SUM;
if (str == "ALL_REDUCE_SUM") return CollectiveComm::Mode::ALL_REDUCE_SUM;
if (str == "ALL_REDUCE_MAX") return CollectiveComm::Mode::ALL_REDUCE_MAX;
if (str == "ALL_REDUCE_MIN") return CollectiveComm::Mode::ALL_REDUCE_MIN;
if (str == "ALL_REDUCE_PROD") return CollectiveComm::Mode::ALL_REDUCE_PROD;
if (str == "GATHER") return CollectiveComm::Mode::GATHER;
if (str == "SCATTER") return CollectiveComm::Mode::SCATTER;
if (str == "ALL_TO_ALL") return CollectiveComm::Mode::ALL_TO_ALL;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, CollectiveComm::Mode>();
CollectiveCommInst
.def(py::init<::megdnn::param::CollectiveComm::Mode, std::string, uint32_t, uint32_t, bool, bool, std::string, uint32_t, ::megdnn::DType, std::string, std::string, std::string>(), py::arg("mode") = ::megdnn::param::CollectiveComm::Mode::REDUCE_SUM, py::arg("key"), py::arg("nr_devices"), py::arg("rank"), py::arg("is_root"), py::arg("local_grad"), py::arg("addr"), py::arg("port"), py::arg("dtype"), py::arg("backend"), py::arg("comp_node"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("mode", &CollectiveComm::mode)
.def_readwrite("key", &CollectiveComm::key)
.def_readwrite("nr_devices", &CollectiveComm::nr_devices)
.def_readwrite("rank", &CollectiveComm::rank)
.def_readwrite("is_root", &CollectiveComm::is_root)
.def_readwrite("local_grad", &CollectiveComm::local_grad)
.def_readwrite("addr", &CollectiveComm::addr)
.def_readwrite("port", &CollectiveComm::port)
.def_readwrite("dtype", &CollectiveComm::dtype)
.def_readwrite("backend", &CollectiveComm::backend)
.def_readwrite("comp_node", &CollectiveComm::comp_node);
py::class_<Concat, std::shared_ptr<Concat>, OpDef> ConcatInst(m, "Concat");
ConcatInst
.def(py::init<int32_t, ::mgb::CompNode, std::string>(), py::arg("axis") = 0, py::arg("comp_node"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("axis", &Concat::axis)
.def_readwrite("comp_node", &Concat::comp_node);
py::class_<CondTake, std::shared_ptr<CondTake>, OpDef> CondTakeInst(m, "CondTake");
CondTakeInst
.def(py::init<>());
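// Shared enums are bound only once; later ops re-export the existing
// binding through class attributes, so e.g. ConvBias.Mode below is the
// very same Python type as BatchConvBias.Mode bound earlier in this file.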
py::class_<ConvBias, std::shared_ptr<ConvBias>, OpDef> ConvBiasInst(m, "ConvBias");
ConvBiasInst.attr("NonlineMode") = BatchConvBiasInst.attr("NonlineMode");
ConvBiasInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
ConvBiasInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
ConvBiasInst.attr("Format") = AdaptivePoolingInst.attr("Format");
ConvBiasInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
ConvBiasInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
ConvBiasInst
.def(py::init<::megdnn::param::ConvBias::NonlineMode, ::megdnn::param::ConvBias::Mode, ::megdnn::param::ConvBias::Sparse, ::megdnn::param::ConvBias::Format, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::ConvBias::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, ::megdnn::DType, std::string>(), py::arg("nonlineMode") = ::megdnn::param::ConvBias::NonlineMode::IDENTITY, py::arg("mode") = ::megdnn::param::ConvBias::Mode::CROSS_CORRELATION, py::arg("sparse") = ::megdnn::param::ConvBias::Sparse::DENSE, py::arg("format") = ::megdnn::param::ConvBias::Format::NCHW, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("compute_mode") = ::megdnn::param::ConvBias::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dtype"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("nonlineMode", &ConvBias::nonlineMode)
.def_readwrite("mode", &ConvBias::mode)
.def_readwrite("sparse", &ConvBias::sparse)
.def_readwrite("format", &ConvBias::format)
.def_readwrite("pad_h", &ConvBias::pad_h)
.def_readwrite("pad_w", &ConvBias::pad_w)
.def_readwrite("stride_h", &ConvBias::stride_h)
.def_readwrite("stride_w", &ConvBias::stride_w)
.def_readwrite("dilate_h", &ConvBias::dilate_h)
.def_readwrite("dilate_w", &ConvBias::dilate_w)
.def_readwrite("compute_mode", &ConvBias::compute_mode)
.def_readwrite("strategy", &ConvBias::strategy)
.def_readwrite("workspace_limit", &ConvBias::workspace_limit)
.def_readwrite("dtype", &ConvBias::dtype);
py::class_<Convolution, std::shared_ptr<Convolution>, OpDef> ConvolutionInst(m, "Convolution");
ConvolutionInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
ConvolutionInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
ConvolutionInst.attr("Format") = AdaptivePoolingInst.attr("Format");
ConvolutionInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
ConvolutionInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
ConvolutionInst
.def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
.def_readwrite("mode", &Convolution::mode)
.def_readwrite("pad_h", &Convolution::pad_h)
.def_readwrite("pad_w", &Convolution::pad_w)
.def_readwrite("stride_h", &Convolution::stride_h)
.def_readwrite("stride_w", &Convolution::stride_w)
.def_readwrite("dilate_h", &Convolution::dilate_h)
.def_readwrite("dilate_w", &Convolution::dilate_w)
.def_readwrite("sparse", &Convolution::sparse)
.def_readwrite("format", &Convolution::format)
.def_readwrite("compute_mode", &Convolution::compute_mode)
.def_readwrite("strategy", &Convolution::strategy)
.def_readwrite("workspace_limit", &Convolution::workspace_limit);
py::class_<Convolution3D, std::shared_ptr<Convolution3D>, OpDef> Convolution3DInst(m, "Convolution3D");
py::enum_<Convolution3D::Mode>(Convolution3DInst, "Mode")
.value("CROSS_CORRELATION", Convolution3D::Mode::CROSS_CORRELATION)
.value("CONVOLUTION", Convolution3D::Mode::CONVOLUTION)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "CROSS_CORRELATION") return Convolution3D::Mode::CROSS_CORRELATION;
if (str == "CONVOLUTION") return Convolution3D::Mode::CONVOLUTION;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, Convolution3D::Mode>();
py::enum_<Convolution3D::Sparse>(Convolution3DInst, "Sparse")
.value("DENSE", Convolution3D::Sparse::DENSE)
.value("GROUP", Convolution3D::Sparse::GROUP)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "DENSE") return Convolution3D::Sparse::DENSE;
if (str == "GROUP") return Convolution3D::Sparse::GROUP;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, Convolution3D::Sparse>();
py::enum_<Convolution3D::DataType>(Convolution3DInst, "DataType")
.value("FLOAT", Convolution3D::DataType::FLOAT)
.value("FLOAT_IO16xC32", Convolution3D::DataType::FLOAT_IO16xC32)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "FLOAT") return Convolution3D::DataType::FLOAT;
if (str == "FLOAT_IO16xC32") return Convolution3D::DataType::FLOAT_IO16xC32;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, Convolution3D::DataType>();
py::enum_<Convolution3D::Format>(Convolution3DInst, "Format")
.value("NCDHW", Convolution3D::Format::NCDHW)
.value("NDHWC", Convolution3D::Format::NDHWC)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "NCDHW") return Convolution3D::Format::NCDHW;
if (str == "NDHWC") return Convolution3D::Format::NDHWC;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, Convolution3D::Format>();
Convolution3DInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
Convolution3DInst
.def(py::init<::megdnn::param::Convolution3D::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution3D::Sparse, ::megdnn::param::Convolution3D::DataType, ::megdnn::param::Convolution3D::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Convolution3D::Mode::CROSS_CORRELATION, py::arg("pad_d") = 0, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_d") = 1, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_d") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution3D::Sparse::DENSE, py::arg("data_type") = ::megdnn::param::Convolution3D::DataType::FLOAT, py::arg("format") = ::megdnn::param::Convolution3D::Format::NCDHW, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
.def_readwrite("mode", &Convolution3D::mode)
.def_readwrite("pad_d", &Convolution3D::pad_d)
.def_readwrite("pad_h", &Convolution3D::pad_h)
.def_readwrite("pad_w", &Convolution3D::pad_w)
.def_readwrite("stride_d", &Convolution3D::stride_d)
.def_readwrite("stride_h", &Convolution3D::stride_h)
.def_readwrite("stride_w", &Convolution3D::stride_w)
.def_readwrite("dilate_d", &Convolution3D::dilate_d)
.def_readwrite("dilate_h", &Convolution3D::dilate_h)
.def_readwrite("dilate_w", &Convolution3D::dilate_w)
.def_readwrite("sparse", &Convolution3D::sparse)
.def_readwrite("data_type", &Convolution3D::data_type)
.def_readwrite("format", &Convolution3D::format)
.def_readwrite("strategy", &Convolution3D::strategy)
.def_readwrite("workspace_limit", &Convolution3D::workspace_limit);
py::class_<Convolution3DBackwardData, std::shared_ptr<Convolution3DBackwardData>, OpDef> Convolution3DBackwardDataInst(m, "Convolution3DBackwardData");
Convolution3DBackwardDataInst.attr("Mode") = Convolution3DInst.attr("Mode");
Convolution3DBackwardDataInst.attr("Sparse") = Convolution3DInst.attr("Sparse");
Convolution3DBackwardDataInst.attr("DataType") = Convolution3DInst.attr("DataType");
Convolution3DBackwardDataInst.attr("Format") = Convolution3DInst.attr("Format");
Convolution3DBackwardDataInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
Convolution3DBackwardDataInst
.def(py::init<::megdnn::param::Convolution3D::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution3D::Sparse, ::megdnn::param::Convolution3D::DataType, ::megdnn::param::Convolution3D::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Convolution3D::Mode::CROSS_CORRELATION, py::arg("pad_d") = 0, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_d") = 1, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_d") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution3D::Sparse::DENSE, py::arg("data_type") = ::megdnn::param::Convolution3D::DataType::FLOAT, py::arg("format") = ::megdnn::param::Convolution3D::Format::NCDHW, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
.def_readwrite("mode", &Convolution3DBackwardData::mode)
.def_readwrite("pad_d", &Convolution3DBackwardData::pad_d)
.def_readwrite("pad_h", &Convolution3DBackwardData::pad_h)
.def_readwrite("pad_w", &Convolution3DBackwardData::pad_w)
.def_readwrite("stride_d", &Convolution3DBackwardData::stride_d)
.def_readwrite("stride_h", &Convolution3DBackwardData::stride_h)
.def_readwrite("stride_w", &Convolution3DBackwardData::stride_w)
.def_readwrite("dilate_d", &Convolution3DBackwardData::dilate_d)
.def_readwrite("dilate_h", &Convolution3DBackwardData::dilate_h)
.def_readwrite("dilate_w", &Convolution3DBackwardData::dilate_w)
.def_readwrite("sparse", &Convolution3DBackwardData::sparse)
.def_readwrite("data_type", &Convolution3DBackwardData::data_type)
.def_readwrite("format", &Convolution3DBackwardData::format)
.def_readwrite("strategy", &Convolution3DBackwardData::strategy)
.def_readwrite("workspace_limit", &Convolution3DBackwardData::workspace_limit);
py::class_<ConvolutionBackwardData, std::shared_ptr<ConvolutionBackwardData>, OpDef> ConvolutionBackwardDataInst(m, "ConvolutionBackwardData");
ConvolutionBackwardDataInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
ConvolutionBackwardDataInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
ConvolutionBackwardDataInst.attr("Format") = AdaptivePoolingInst.attr("Format");
ConvolutionBackwardDataInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
ConvolutionBackwardDataInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
ConvolutionBackwardDataInst
.def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, ::megdnn::DType, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dtype"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("mode", &ConvolutionBackwardData::mode)
.def_readwrite("pad_h", &ConvolutionBackwardData::pad_h)
.def_readwrite("pad_w", &ConvolutionBackwardData::pad_w)
.def_readwrite("stride_h", &ConvolutionBackwardData::stride_h)
.def_readwrite("stride_w", &ConvolutionBackwardData::stride_w)
.def_readwrite("dilate_h", &ConvolutionBackwardData::dilate_h)
.def_readwrite("dilate_w", &ConvolutionBackwardData::dilate_w)
.def_readwrite("sparse", &ConvolutionBackwardData::sparse)
.def_readwrite("format", &ConvolutionBackwardData::format)
.def_readwrite("compute_mode", &ConvolutionBackwardData::compute_mode)
.def_readwrite("strategy", &ConvolutionBackwardData::strategy)
.def_readwrite("workspace_limit", &ConvolutionBackwardData::workspace_limit)
.def_readwrite("dtype", &ConvolutionBackwardData::dtype);
py::class_<Copy, std::shared_ptr<Copy>, OpDef> CopyInst(m, "Copy");
CopyInst
.def(py::init<::mgb::CompNode, std::string>(), py::arg("comp_node"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("comp_node", &Copy::comp_node);
py::class_<Correlation, std::shared_ptr<Correlation>, OpDef> CorrelationInst(m, "Correlation");
py::enum_<Correlation::Format>(CorrelationInst, "Format")
.value("NCHW", Correlation::Format::NCHW)
.value("NHWC", Correlation::Format::NHWC)
.value("NHWCD4", Correlation::Format::NHWCD4)
.value("NCHW4", Correlation::Format::NCHW4)
.value("NCHW8", Correlation::Format::NCHW8)
.value("NCHW32", Correlation::Format::NCHW32)
.value("NCHW88", Correlation::Format::NCHW88)
.value("NCHW44", Correlation::Format::NCHW44)
.value("NCHW44_DOT", Correlation::Format::NCHW44_DOT)
.value("NCHW_WINOGRAD", Correlation::Format::NCHW_WINOGRAD)
.value("NCHW88_WINOGRAD", Correlation::Format::NCHW88_WINOGRAD)
.value("NCHW44_WINOGRAD", Correlation::Format::NCHW44_WINOGRAD)
.value("NCHW4_NCHW32", Correlation::Format::NCHW4_NCHW32)
.value("NCHW32_NCHW4", Correlation::Format::NCHW32_NCHW4)
.value("NCHW4_NCHW", Correlation::Format::NCHW4_NCHW)
.value("NHWC_NCHW", Correlation::Format::NHWC_NCHW)
.value("NHWC_NCHW4_IC_SMALL", Correlation::Format::NHWC_NCHW4_IC_SMALL)
.value("NCHW_NCHW4_IC_SMALL", Correlation::Format::NCHW_NCHW4_IC_SMALL)
.value("CHWN4", Correlation::Format::CHWN4)
.value("NCHW4_NHWC", Correlation::Format::NCHW4_NHWC)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "NCHW") return Correlation::Format::NCHW;
if (str == "NHWC") return Correlation::Format::NHWC;
if (str == "NHWCD4") return Correlation::Format::NHWCD4;
if (str == "NCHW4") return Correlation::Format::NCHW4;
if (str == "NCHW8") return Correlation::Format::NCHW8;
if (str == "NCHW32") return Correlation::Format::NCHW32;
if (str == "NCHW88") return Correlation::Format::NCHW88;
if (str == "NCHW44") return Correlation::Format::NCHW44;
if (str == "NCHW44_DOT") return Correlation::Format::NCHW44_DOT;
if (str == "NCHW_WINOGRAD") return Correlation::Format::NCHW_WINOGRAD;
if (str == "NCHW88_WINOGRAD") return Correlation::Format::NCHW88_WINOGRAD;
if (str == "NCHW44_WINOGRAD") return Correlation::Format::NCHW44_WINOGRAD;
if (str == "NCHW4_NCHW32") return Correlation::Format::NCHW4_NCHW32;
if (str == "NCHW32_NCHW4") return Correlation::Format::NCHW32_NCHW4;
if (str == "NCHW4_NCHW") return Correlation::Format::NCHW4_NCHW;
if (str == "NHWC_NCHW") return Correlation::Format::NHWC_NCHW;
if (str == "NHWC_NCHW4_IC_SMALL") return Correlation::Format::NHWC_NCHW4_IC_SMALL;
if (str == "NCHW_NCHW4_IC_SMALL") return Correlation::Format::NCHW_NCHW4_IC_SMALL;
if (str == "CHWN4") return Correlation::Format::CHWN4;
if (str == "NCHW4_NHWC") return Correlation::Format::NCHW4_NHWC;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, Correlation::Format>();
CorrelationInst
.def(py::init<::megdnn::param::Correlation::Format, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, bool, std::string>(), py::arg("format") = ::megdnn::param::Correlation::Format::NCHW, py::arg("kernel_size") = 1, py::arg("max_displacement") = 1, py::arg("stride1") = 1, py::arg("stride2") = 1, py::arg("pad_size") = 0, py::arg("is_multiply") = true, py::arg("scope") = {})
.def_readwrite("format", &Correlation::format)
.def_readwrite("kernel_size", &Correlation::kernel_size)
.def_readwrite("max_displacement", &Correlation::max_displacement)
.def_readwrite("stride1", &Correlation::stride1)
.def_readwrite("stride2", &Correlation::stride2)
.def_readwrite("pad_size", &Correlation::pad_size)
.def_readwrite("is_multiply", &Correlation::is_multiply);
py::class_<Cumsum, std::shared_ptr<Cumsum>, OpDef> CumsumInst(m, "Cumsum");
CumsumInst
.def(py::init<int32_t, bool, bool, std::string>(), py::arg("axis") = 2147483647, py::arg("exclusive") = true, py::arg("reverse") = false, py::arg("scope") = {})
.def_readwrite("axis", &Cumsum::axis)
.def_readwrite("exclusive", &Cumsum::exclusive)
.def_readwrite("reverse", &Cumsum::reverse);
py::class_<CvtColor, std::shared_ptr<CvtColor>, OpDef> CvtColorInst(m, "CvtColor");
py::enum_<CvtColor::Mode>(CvtColorInst, "Mode")
.value("RGB2GRAY", CvtColor::Mode::RGB2GRAY)
.value("RGB2YUV", CvtColor::Mode::RGB2YUV)
.value("YUV2RGB", CvtColor::Mode::YUV2RGB)
.value("GRAY2RGB", CvtColor::Mode::GRAY2RGB)
.value("RGBA2RGB", CvtColor::Mode::RGBA2RGB)
.value("RGBA2BGR", CvtColor::Mode::RGBA2BGR)
.value("RGBA2GRAY", CvtColor::Mode::RGBA2GRAY)
.value("RGB2BGR", CvtColor::Mode::RGB2BGR)
.value("BGR2GRAY", CvtColor::Mode::BGR2GRAY)
.value("BGR2RGB", CvtColor::Mode::BGR2RGB)
.value("YUV2GRAY_NV21", CvtColor::Mode::YUV2GRAY_NV21)
.value("YUV2RGB_NV21", CvtColor::Mode::YUV2RGB_NV21)
.value("YUV2BGR_NV21", CvtColor::Mode::YUV2BGR_NV21)
.value("YUV2GRAY_NV12", CvtColor::Mode::YUV2GRAY_NV12)
.value("YUV2RGB_NV12", CvtColor::Mode::YUV2RGB_NV12)
.value("YUV2BGR_NV12", CvtColor::Mode::YUV2BGR_NV12)
.value("YUV2GRAY_YV12", CvtColor::Mode::YUV2GRAY_YV12)
.value("YUV2RGB_YV12", CvtColor::Mode::YUV2RGB_YV12)
.value("YUV2BGR_YV12", CvtColor::Mode::YUV2BGR_YV12)
.value("YUV2GRAY_YU12", CvtColor::Mode::YUV2GRAY_YU12)
.value("YUV2RGB_YU12", CvtColor::Mode::YUV2RGB_YU12)
.value("YUV2BGR_YU12", CvtColor::Mode::YUV2BGR_YU12)
.value("YCrCb2RGB", CvtColor::Mode::YCrCb2RGB)
.value("YCrCb2BGR", CvtColor::Mode::YCrCb2BGR)
.value("BT601_YUV2RGB_NV21", CvtColor::Mode::BT601_YUV2RGB_NV21)
.value("BT601_YUV2BGR_NV21", CvtColor::Mode::BT601_YUV2BGR_NV21)
.value("BT601_YUV2RGB_NV12", CvtColor::Mode::BT601_YUV2RGB_NV12)
.value("BT601_YUV2BGR_NV12", CvtColor::Mode::BT601_YUV2BGR_NV12)
.value("BT601_YUV2RGB_YV12", CvtColor::Mode::BT601_YUV2RGB_YV12)
.value("BT601_YUV2BGR_YV12", CvtColor::Mode::BT601_YUV2BGR_YV12)
.value("BT601_YUV2RGB_YU12", CvtColor::Mode::BT601_YUV2RGB_YU12)
.value("BT601_YUV2BGR_YU12", CvtColor::Mode::BT601_YUV2BGR_YU12)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "RGB2GRAY") return CvtColor::Mode::RGB2GRAY;
if (str == "RGB2YUV") return CvtColor::Mode::RGB2YUV;
if (str == "YUV2RGB") return CvtColor::Mode::YUV2RGB;
if (str == "GRAY2RGB") return CvtColor::Mode::GRAY2RGB;
if (str == "RGBA2RGB") return CvtColor::Mode::RGBA2RGB;
if (str == "RGBA2BGR") return CvtColor::Mode::RGBA2BGR;
if (str == "RGBA2GRAY") return CvtColor::Mode::RGBA2GRAY;
if (str == "RGB2BGR") return CvtColor::Mode::RGB2BGR;
if (str == "BGR2GRAY") return CvtColor::Mode::BGR2GRAY;
if (str == "BGR2RGB") return CvtColor::Mode::BGR2RGB;
if (str == "YUV2GRAY_NV21") return CvtColor::Mode::YUV2GRAY_NV21;
if (str == "YUV2RGB_NV21") return CvtColor::Mode::YUV2RGB_NV21;
if (str == "YUV2BGR_NV21") return CvtColor::Mode::YUV2BGR_NV21;
if (str == "YUV2GRAY_NV12") return CvtColor::Mode::YUV2GRAY_NV12;
if (str == "YUV2RGB_NV12") return CvtColor::Mode::YUV2RGB_NV12;
if (str == "YUV2BGR_NV12") return CvtColor::Mode::YUV2BGR_NV12;
if (str == "YUV2GRAY_YV12") return CvtColor::Mode::YUV2GRAY_YV12;
if (str == "YUV2RGB_YV12") return CvtColor::Mode::YUV2RGB_YV12;
if (str == "YUV2BGR_YV12") return CvtColor::Mode::YUV2BGR_YV12;
if (str == "YUV2GRAY_YU12") return CvtColor::Mode::YUV2GRAY_YU12;
if (str == "YUV2RGB_YU12") return CvtColor::Mode::YUV2RGB_YU12;
if (str == "YUV2BGR_YU12") return CvtColor::Mode::YUV2BGR_YU12;
if (str == "YCrCb2RGB") return CvtColor::Mode::YCrCb2RGB;
if (str == "YCrCb2BGR") return CvtColor::Mode::YCrCb2BGR;
if (str == "BT601_YUV2RGB_NV21") return CvtColor::Mode::BT601_YUV2RGB_NV21;
if (str == "BT601_YUV2BGR_NV21") return CvtColor::Mode::BT601_YUV2BGR_NV21;
if (str == "BT601_YUV2RGB_NV12") return CvtColor::Mode::BT601_YUV2RGB_NV12;
if (str == "BT601_YUV2BGR_NV12") return CvtColor::Mode::BT601_YUV2BGR_NV12;
if (str == "BT601_YUV2RGB_YV12") return CvtColor::Mode::BT601_YUV2RGB_YV12;
if (str == "BT601_YUV2BGR_YV12") return CvtColor::Mode::BT601_YUV2BGR_YV12;
if (str == "BT601_YUV2RGB_YU12") return CvtColor::Mode::BT601_YUV2RGB_YU12;
if (str == "BT601_YUV2BGR_YU12") return CvtColor::Mode::BT601_YUV2BGR_YU12;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, CvtColor::Mode>();
CvtColorInst
.def(py::init<::megdnn::param::CvtColor::Mode, std::string>(), py::arg("mode") = ::megdnn::param::CvtColor::Mode::RGB2GRAY, py::arg("scope") = {})
.def_readwrite("mode", &CvtColor::mode);
py::class_<DeformableConv, std::shared_ptr<DeformableConv>, OpDef> DeformableConvInst(m, "DeformableConv");
DeformableConvInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
DeformableConvInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
DeformableConvInst.attr("Format") = AdaptivePoolingInst.attr("Format");
DeformableConvInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
DeformableConvInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
DeformableConvInst
.def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
.def_readwrite("mode", &DeformableConv::mode)
.def_readwrite("pad_h", &DeformableConv::pad_h)
.def_readwrite("pad_w", &DeformableConv::pad_w)
.def_readwrite("stride_h", &DeformableConv::stride_h)
.def_readwrite("stride_w", &DeformableConv::stride_w)
.def_readwrite("dilate_h", &DeformableConv::dilate_h)
.def_readwrite("dilate_w", &DeformableConv::dilate_w)
.def_readwrite("sparse", &DeformableConv::sparse)
.def_readwrite("format", &DeformableConv::format)
.def_readwrite("compute_mode", &DeformableConv::compute_mode)
.def_readwrite("strategy", &DeformableConv::strategy)
.def_readwrite("workspace_limit", &DeformableConv::workspace_limit);
py::class_<DeformablePSROIPooling, std::shared_ptr<DeformablePSROIPooling>, OpDef> DeformablePSROIPoolingInst(m, "DeformablePSROIPooling");
DeformablePSROIPoolingInst
.def(py::init<bool, float, float, uint32_t, uint32_t, uint32_t, uint32_t, std::string>(), py::arg("no_trans") = true, py::arg("spatial_scale") = 1, py::arg("trans_std") = 1, py::arg("pooled_h") = 1, py::arg("pooled_w") = 1, py::arg("part_size") = 1, py::arg("sample_per_part") = 1, py::arg("scope") = {})
.def_readwrite("no_trans", &DeformablePSROIPooling::no_trans)
.def_readwrite("spatial_scale", &DeformablePSROIPooling::spatial_scale)
.def_readwrite("trans_std", &DeformablePSROIPooling::trans_std)
.def_readwrite("pooled_h", &DeformablePSROIPooling::pooled_h)
.def_readwrite("pooled_w", &DeformablePSROIPooling::pooled_w)
.def_readwrite("part_size", &DeformablePSROIPooling::part_size)
.def_readwrite("sample_per_part", &DeformablePSROIPooling::sample_per_part);
py::class_<Diag, std::shared_ptr<Diag>, OpDef> DiagInst(m, "Diag");
DiagInst
.def(py::init<int32_t, std::string>(), py::arg("k") = 0, py::arg("scope") = {})
.def_readwrite("k", &Diag::k);
py::class_<Dimshuffle, std::shared_ptr<Dimshuffle>, OpDef> DimshuffleInst(m, "Dimshuffle");
DimshuffleInst
.def(py::init<std::vector<int32_t>, std::string>(), py::arg("pattern"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("pattern", &Dimshuffle::pattern);
py::class_<Dot, std::shared_ptr<Dot>, OpDef> DotInst(m, "Dot");
DotInst
.def(py::init<>());
py::class_<Dropout, std::shared_ptr<Dropout>, OpDef> DropoutInst(m, "Dropout");
DropoutInst
.def(py::init<float, uint64_t, size_t, std::string>(), py::arg("drop_prob") = 0, py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("drop_prob", &Dropout::drop_prob)
.def_readwrite("seed", &Dropout::seed)
.def_readwrite("handle", &Dropout::handle);
py::class_<Elemwise, std::shared_ptr<Elemwise>, OpDef> ElemwiseInst(m, "Elemwise");
py::enum_<Elemwise::Mode>(ElemwiseInst, "Mode")
.value("RELU", Elemwise::Mode::RELU)
.value("ABS", Elemwise::Mode::ABS)
.value("ACOS", Elemwise::Mode::ACOS)
.value("ASIN", Elemwise::Mode::ASIN)
.value("CEIL", Elemwise::Mode::CEIL)
.value("COS", Elemwise::Mode::COS)
.value("EXP", Elemwise::Mode::EXP)
.value("EXPM1", Elemwise::Mode::EXPM1)
.value("FLOOR", Elemwise::Mode::FLOOR)
.value("LOG", Elemwise::Mode::LOG)
.value("LOG1P", Elemwise::Mode::LOG1P)
.value("NEGATE", Elemwise::Mode::NEGATE)
.value("SIGMOID", Elemwise::Mode::SIGMOID)
.value("SIN", Elemwise::Mode::SIN)
.value("TANH", Elemwise::Mode::TANH)
.value("ABS_GRAD", Elemwise::Mode::ABS_GRAD)
.value("ADD", Elemwise::Mode::ADD)
.value("FLOOR_DIV", Elemwise::Mode::FLOOR_DIV)
.value("MAX", Elemwise::Mode::MAX)
.value("MIN", Elemwise::Mode::MIN)
.value("MOD", Elemwise::Mode::MOD)
.value("MUL", Elemwise::Mode::MUL)
.value("POW", Elemwise::Mode::POW)
.value("SIGMOID_GRAD", Elemwise::Mode::SIGMOID_GRAD)
.value("SUB", Elemwise::Mode::SUB)
.value("SWITCH_GT0", Elemwise::Mode::SWITCH_GT0)
.value("TANH_GRAD", Elemwise::Mode::TANH_GRAD)
.value("TRUE_DIV", Elemwise::Mode::TRUE_DIV)
.value("LOG_SUM_EXP", Elemwise::Mode::LOG_SUM_EXP)
.value("LT", Elemwise::Mode::LT)
.value("LEQ", Elemwise::Mode::LEQ)
.value("EQ", Elemwise::Mode::EQ)
.value("SHL", Elemwise::Mode::SHL)
.value("SHR", Elemwise::Mode::SHR)
.value("COND_LEQ_MOV", Elemwise::Mode::COND_LEQ_MOV)
.value("FUSE_MUL_ADD3", Elemwise::Mode::FUSE_MUL_ADD3)
.value("FUSE_MUL_ADD4", Elemwise::Mode::FUSE_MUL_ADD4)
.value("FUSE_ADD_RELU", Elemwise::Mode::FUSE_ADD_RELU)
.value("FUSE_ADD_SIGMOID", Elemwise::Mode::FUSE_ADD_SIGMOID)
.value("FUSE_ADD_TANH", Elemwise::Mode::FUSE_ADD_TANH)
.value("FAST_TANH", Elemwise::Mode::FAST_TANH)
.value("FAST_TANH_GRAD", Elemwise::Mode::FAST_TANH_GRAD)
.value("ROUND", Elemwise::Mode::ROUND)
.value("RMULH", Elemwise::Mode::RMULH)
.value("ATAN2", Elemwise::Mode::ATAN2)
.value("ERF", Elemwise::Mode::ERF)
.value("ERFINV", Elemwise::Mode::ERFINV)
.value("ERFC", Elemwise::Mode::ERFC)
.value("ERFCINV", Elemwise::Mode::ERFCINV)
.value("H_SWISH", Elemwise::Mode::H_SWISH)
.value("H_SWISH_GRAD", Elemwise::Mode::H_SWISH_GRAD)
.value("FUSE_ADD_H_SWISH", Elemwise::Mode::FUSE_ADD_H_SWISH)
.value("NOT", Elemwise::Mode::NOT)
.value("AND", Elemwise::Mode::AND)
.value("OR", Elemwise::Mode::OR)
.value("XOR", Elemwise::Mode::XOR)
.value("SILU", Elemwise::Mode::SILU)
.value("SILU_GRAD", Elemwise::Mode::SILU_GRAD)
.value("GELU", Elemwise::Mode::GELU)
.value("GELU_GRAD", Elemwise::Mode::GELU_GRAD)
.value("COND_LT_MOV", Elemwise::Mode::COND_LT_MOV)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "RELU") return Elemwise::Mode::RELU;
if (str == "ABS") return Elemwise::Mode::ABS;
if (str == "ACOS") return Elemwise::Mode::ACOS;
if (str == "ASIN") return Elemwise::Mode::ASIN;
if (str == "CEIL") return Elemwise::Mode::CEIL;
if (str == "COS") return Elemwise::Mode::COS;
if (str == "EXP") return Elemwise::Mode::EXP;
if (str == "EXPM1") return Elemwise::Mode::EXPM1;
if (str == "FLOOR") return Elemwise::Mode::FLOOR;
if (str == "LOG") return Elemwise::Mode::LOG;
if (str == "LOG1P") return Elemwise::Mode::LOG1P;
if (str == "NEGATE") return Elemwise::Mode::NEGATE;
if (str == "SIGMOID") return Elemwise::Mode::SIGMOID;
if (str == "SIN") return Elemwise::Mode::SIN;
if (str == "TANH") return Elemwise::Mode::TANH;
if (str == "ABS_GRAD") return Elemwise::Mode::ABS_GRAD;
if (str == "ADD") return Elemwise::Mode::ADD;
if (str == "FLOOR_DIV") return Elemwise::Mode::FLOOR_DIV;
if (str == "MAX") return Elemwise::Mode::MAX;
if (str == "MIN") return Elemwise::Mode::MIN;
if (str == "MOD") return Elemwise::Mode::MOD;
if (str == "MUL") return Elemwise::Mode::MUL;
if (str == "POW") return Elemwise::Mode::POW;
if (str == "SIGMOID_GRAD") return Elemwise::Mode::SIGMOID_GRAD;
if (str == "SUB") return Elemwise::Mode::SUB;
if (str == "SWITCH_GT0") return Elemwise::Mode::SWITCH_GT0;
if (str == "TANH_GRAD") return Elemwise::Mode::TANH_GRAD;
if (str == "TRUE_DIV") return Elemwise::Mode::TRUE_DIV;
if (str == "LOG_SUM_EXP") return Elemwise::Mode::LOG_SUM_EXP;
if (str == "LT") return Elemwise::Mode::LT;
if (str == "LEQ") return Elemwise::Mode::LEQ;
if (str == "EQ") return Elemwise::Mode::EQ;
if (str == "SHL") return Elemwise::Mode::SHL;
if (str == "SHR") return Elemwise::Mode::SHR;
if (str == "COND_LEQ_MOV") return Elemwise::Mode::COND_LEQ_MOV;
if (str == "FUSE_MUL_ADD3") return Elemwise::Mode::FUSE_MUL_ADD3;
if (str == "FUSE_MUL_ADD4") return Elemwise::Mode::FUSE_MUL_ADD4;
if (str == "FUSE_ADD_RELU") return Elemwise::Mode::FUSE_ADD_RELU;
if (str == "FUSE_ADD_SIGMOID") return Elemwise::Mode::FUSE_ADD_SIGMOID;
if (str == "FUSE_ADD_TANH") return Elemwise::Mode::FUSE_ADD_TANH;
if (str == "FAST_TANH") return Elemwise::Mode::FAST_TANH;
if (str == "FAST_TANH_GRAD") return Elemwise::Mode::FAST_TANH_GRAD;
if (str == "ROUND") return Elemwise::Mode::ROUND;
if (str == "RMULH") return Elemwise::Mode::RMULH;
if (str == "ATAN2") return Elemwise::Mode::ATAN2;
if (str == "ERF") return Elemwise::Mode::ERF;
if (str == "ERFINV") return Elemwise::Mode::ERFINV;
if (str == "ERFC") return Elemwise::Mode::ERFC;
if (str == "ERFCINV") return Elemwise::Mode::ERFCINV;
if (str == "H_SWISH") return Elemwise::Mode::H_SWISH;
if (str == "H_SWISH_GRAD") return Elemwise::Mode::H_SWISH_GRAD;
if (str == "FUSE_ADD_H_SWISH") return Elemwise::Mode::FUSE_ADD_H_SWISH;
if (str == "NOT") return Elemwise::Mode::NOT;
if (str == "AND") return Elemwise::Mode::AND;
if (str == "OR") return Elemwise::Mode::OR;
if (str == "XOR") return Elemwise::Mode::XOR;
if (str == "SILU") return Elemwise::Mode::SILU;
if (str == "SILU_GRAD") return Elemwise::Mode::SILU_GRAD;
if (str == "GELU") return Elemwise::Mode::GELU;
if (str == "GELU_GRAD") return Elemwise::Mode::GELU_GRAD;
if (str == "COND_LT_MOV") return Elemwise::Mode::COND_LT_MOV;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, Elemwise::Mode>();
ElemwiseInst
.def(py::init<::megdnn::param::Elemwise::Mode, std::string>(), py::arg("mode") = ::megdnn::param::Elemwise::Mode::RELU, py::arg("scope") = {})
.def_readwrite("mode", &Elemwise::mode);
py::class_<ElemwiseMultiType, std::shared_ptr<ElemwiseMultiType>, OpDef> ElemwiseMultiTypeInst(m, "ElemwiseMultiType");
py::enum_<ElemwiseMultiType::Mode>(ElemwiseMultiTypeInst, "Mode")
.value("FUSE_MUL_ADD3_INT16x32x32x32", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32)
.value("FUSE_MUL_ADD3_IXxF32xF32xI8", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_IXxF32xF32xI8)
.value("ROUND_SHR_SATURATE_IXxI8xI8", ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI8)
.value("FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8", ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8)
.value("FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8", ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8)
.value("ROUND_SHR_SATURATE_IXxI8xI16", ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI16)
.value("QADD", ElemwiseMultiType::Mode::QADD)
.value("QFUSE_ADD_RELU", ElemwiseMultiType::Mode::QFUSE_ADD_RELU)
.value("QMUL", ElemwiseMultiType::Mode::QMUL)
.value("QMIN", ElemwiseMultiType::Mode::QMIN)
.value("QMAX", ElemwiseMultiType::Mode::QMAX)
.value("QSUB", ElemwiseMultiType::Mode::QSUB)
.value("QTRUE_DIV", ElemwiseMultiType::Mode::QTRUE_DIV)
.value("QFUSE_ADD_SIGMOID", ElemwiseMultiType::Mode::QFUSE_ADD_SIGMOID)
.value("QFUSE_ADD_TANH", ElemwiseMultiType::Mode::QFUSE_ADD_TANH)
.value("QRELU", ElemwiseMultiType::Mode::QRELU)
.value("QABS", ElemwiseMultiType::Mode::QABS)
.value("QSIGMOID", ElemwiseMultiType::Mode::QSIGMOID)
.value("QEXP", ElemwiseMultiType::Mode::QEXP)
.value("QTANH", ElemwiseMultiType::Mode::QTANH)
.value("QFUSE_MUL_ADD3", ElemwiseMultiType::Mode::QFUSE_MUL_ADD3)
.value("QFAST_TANH", ElemwiseMultiType::Mode::QFAST_TANH)
.value("QNEGATE", ElemwiseMultiType::Mode::QNEGATE)
.value("QACOS", ElemwiseMultiType::Mode::QACOS)
.value("QASIN", ElemwiseMultiType::Mode::QASIN)
.value("QCEIL", ElemwiseMultiType::Mode::QCEIL)
.value("QCOS", ElemwiseMultiType::Mode::QCOS)
.value("QEXPM1", ElemwiseMultiType::Mode::QEXPM1)
.value("QFLOOR", ElemwiseMultiType::Mode::QFLOOR)
.value("QLOG", ElemwiseMultiType::Mode::QLOG)
.value("QLOG1P", ElemwiseMultiType::Mode::QLOG1P)
.value("QSIN", ElemwiseMultiType::Mode::QSIN)
.value("QROUND", ElemwiseMultiType::Mode::QROUND)
.value("QERF", ElemwiseMultiType::Mode::QERF)
.value("QERFINV", ElemwiseMultiType::Mode::QERFINV)
.value("QERFC", ElemwiseMultiType::Mode::QERFC)
.value("QERFCINV", ElemwiseMultiType::Mode::QERFCINV)
.value("QABS_GRAD", ElemwiseMultiType::Mode::QABS_GRAD)
.value("QFLOOR_DIV", ElemwiseMultiType::Mode::QFLOOR_DIV)
.value("QMOD", ElemwiseMultiType::Mode::QMOD)
.value("QSIGMOID_GRAD", ElemwiseMultiType::Mode::QSIGMOID_GRAD)
.value("QSWITCH_GT0", ElemwiseMultiType::Mode::QSWITCH_GT0)
.value("QTANH_GRAD", ElemwiseMultiType::Mode::QTANH_GRAD)
.value("QLT", ElemwiseMultiType::Mode::QLT)
.value("QLEQ", ElemwiseMultiType::Mode::QLEQ)
.value("QEQ", ElemwiseMultiType::Mode::QEQ)
.value("QPOW", ElemwiseMultiType::Mode::QPOW)
.value("QLOG_SUM_EXP", ElemwiseMultiType::Mode::QLOG_SUM_EXP)
.value("QFAST_TANH_GRAD", ElemwiseMultiType::Mode::QFAST_TANH_GRAD)
.value("QATAN2", ElemwiseMultiType::Mode::QATAN2)
.value("QCOND_LEQ_MOV", ElemwiseMultiType::Mode::QCOND_LEQ_MOV)
.value("QH_SWISH", ElemwiseMultiType::Mode::QH_SWISH)
.value("QFUSE_ADD_H_SWISH", ElemwiseMultiType::Mode::QFUSE_ADD_H_SWISH)
.value("QH_SWISH_GRAD", ElemwiseMultiType::Mode::QH_SWISH_GRAD)
.value("FUSE_MUL_ADD3_INT16xF32xF32xF32", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32)
.value("MUL_INT16xF32xF32", ElemwiseMultiType::Mode::MUL_INT16xF32xF32)
.value("FUSE_MUL_ADD3_UINT8xF32xF32xF32", ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32)
.value("QCOND_LT_MOV", ElemwiseMultiType::Mode::QCOND_LT_MOV)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "FUSE_MUL_ADD3_INT16x32x32x32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32;
if (str == "FUSE_MUL_ADD3_IXxF32xF32xI8") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_IXxF32xF32xI8;
if (str == "ROUND_SHR_SATURATE_IXxI8xI8") return ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI8;
if (str == "FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8") return ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8;
if (str == "FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8") return ElemwiseMultiType::Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8;
if (str == "ROUND_SHR_SATURATE_IXxI8xI16") return ElemwiseMultiType::Mode::ROUND_SHR_SATURATE_IXxI8xI16;
if (str == "QADD") return ElemwiseMultiType::Mode::QADD;
if (str == "QFUSE_ADD_RELU") return ElemwiseMultiType::Mode::QFUSE_ADD_RELU;
if (str == "QMUL") return ElemwiseMultiType::Mode::QMUL;
if (str == "QMIN") return ElemwiseMultiType::Mode::QMIN;
if (str == "QMAX") return ElemwiseMultiType::Mode::QMAX;
if (str == "QSUB") return ElemwiseMultiType::Mode::QSUB;
if (str == "QTRUE_DIV") return ElemwiseMultiType::Mode::QTRUE_DIV;
if (str == "QFUSE_ADD_SIGMOID") return ElemwiseMultiType::Mode::QFUSE_ADD_SIGMOID;
if (str == "QFUSE_ADD_TANH") return ElemwiseMultiType::Mode::QFUSE_ADD_TANH;
if (str == "QRELU") return ElemwiseMultiType::Mode::QRELU;
if (str == "QABS") return ElemwiseMultiType::Mode::QABS;
if (str == "QSIGMOID") return ElemwiseMultiType::Mode::QSIGMOID;
if (str == "QEXP") return ElemwiseMultiType::Mode::QEXP;
if (str == "QTANH") return ElemwiseMultiType::Mode::QTANH;
if (str == "QFUSE_MUL_ADD3") return ElemwiseMultiType::Mode::QFUSE_MUL_ADD3;
if (str == "QFAST_TANH") return ElemwiseMultiType::Mode::QFAST_TANH;
if (str == "QNEGATE") return ElemwiseMultiType::Mode::QNEGATE;
if (str == "QACOS") return ElemwiseMultiType::Mode::QACOS;
if (str == "QASIN") return ElemwiseMultiType::Mode::QASIN;
if (str == "QCEIL") return ElemwiseMultiType::Mode::QCEIL;
if (str == "QCOS") return ElemwiseMultiType::Mode::QCOS;
if (str == "QEXPM1") return ElemwiseMultiType::Mode::QEXPM1;
if (str == "QFLOOR") return ElemwiseMultiType::Mode::QFLOOR;
if (str == "QLOG") return ElemwiseMultiType::Mode::QLOG;
if (str == "QLOG1P") return ElemwiseMultiType::Mode::QLOG1P;
if (str == "QSIN") return ElemwiseMultiType::Mode::QSIN;
if (str == "QROUND") return ElemwiseMultiType::Mode::QROUND;
if (str == "QERF") return ElemwiseMultiType::Mode::QERF;
if (str == "QERFINV") return ElemwiseMultiType::Mode::QERFINV;
if (str == "QERFC") return ElemwiseMultiType::Mode::QERFC;
if (str == "QERFCINV") return ElemwiseMultiType::Mode::QERFCINV;
if (str == "QABS_GRAD") return ElemwiseMultiType::Mode::QABS_GRAD;
if (str == "QFLOOR_DIV") return ElemwiseMultiType::Mode::QFLOOR_DIV;
if (str == "QMOD") return ElemwiseMultiType::Mode::QMOD;
if (str == "QSIGMOID_GRAD") return ElemwiseMultiType::Mode::QSIGMOID_GRAD;
if (str == "QSWITCH_GT0") return ElemwiseMultiType::Mode::QSWITCH_GT0;
if (str == "QTANH_GRAD") return ElemwiseMultiType::Mode::QTANH_GRAD;
if (str == "QLT") return ElemwiseMultiType::Mode::QLT;
if (str == "QLEQ") return ElemwiseMultiType::Mode::QLEQ;
if (str == "QEQ") return ElemwiseMultiType::Mode::QEQ;
if (str == "QPOW") return ElemwiseMultiType::Mode::QPOW;
if (str == "QLOG_SUM_EXP") return ElemwiseMultiType::Mode::QLOG_SUM_EXP;
if (str == "QFAST_TANH_GRAD") return ElemwiseMultiType::Mode::QFAST_TANH_GRAD;
if (str == "QATAN2") return ElemwiseMultiType::Mode::QATAN2;
if (str == "QCOND_LEQ_MOV") return ElemwiseMultiType::Mode::QCOND_LEQ_MOV;
if (str == "QH_SWISH") return ElemwiseMultiType::Mode::QH_SWISH;
if (str == "QFUSE_ADD_H_SWISH") return ElemwiseMultiType::Mode::QFUSE_ADD_H_SWISH;
if (str == "QH_SWISH_GRAD") return ElemwiseMultiType::Mode::QH_SWISH_GRAD;
if (str == "FUSE_MUL_ADD3_INT16xF32xF32xF32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16xF32xF32xF32;
if (str == "MUL_INT16xF32xF32") return ElemwiseMultiType::Mode::MUL_INT16xF32xF32;
if (str == "FUSE_MUL_ADD3_UINT8xF32xF32xF32") return ElemwiseMultiType::Mode::FUSE_MUL_ADD3_UINT8xF32xF32xF32;
if (str == "QCOND_LT_MOV") return ElemwiseMultiType::Mode::QCOND_LT_MOV;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, ElemwiseMultiType::Mode>();
ElemwiseMultiTypeInst
.def(py::init<::megdnn::param::ElemwiseMultiType::Mode, ::megdnn::DType, std::string>(), py::arg("mode") = ::megdnn::param::ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32, py::arg("dtype"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("mode", &ElemwiseMultiType::mode)
.def_readwrite("dtype", &ElemwiseMultiType::dtype);
py::class_<ExternOpr, std::shared_ptr<ExternOpr>, OpDef> ExternOprInst(m, "ExternOpr");
ExternOprInst
.def(py::init<std::vector<std::vector<size_t>>, std::string, std::string, size_t, std::vector<::megdnn::DType>, std::string>(), py::arg("output_shapes"), py::arg("name"), py::arg("data"), py::arg("data_len"), py::arg("output_dtypes"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("output_shapes", &ExternOpr::output_shapes)
.def_readwrite("name", &ExternOpr::name)
.def_readwrite("data", &ExternOpr::data)
.def_readwrite("data_len", &ExternOpr::data_len)
.def_readwrite("output_dtypes", &ExternOpr::output_dtypes);
py::class_<Eye, std::shared_ptr<Eye>, OpDef> EyeInst(m, "Eye");
EyeInst
.def(py::init<int32_t, ::megdnn::DType, ::mgb::CompNode, std::string>(), py::arg("k") = 0, py::arg("dtype") = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32), py::arg("comp_node"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("k", &Eye::k)
.def_readwrite("dtype", &Eye::dtype)
.def_readwrite("comp_node", &Eye::comp_node);
py::class_<FakeQuant, std::shared_ptr<FakeQuant>, OpDef> FakeQuantInst(m, "FakeQuant");
FakeQuantInst
.def(py::init<int32_t, int32_t, std::string>(), py::arg("qmin") = -2147483648, py::arg("qmax") = 2147483647, py::arg("scope") = {})
.def_readwrite("qmin", &FakeQuant::qmin)
.def_readwrite("qmax", &FakeQuant::qmax);
py::class_<FastpathCopy, std::shared_ptr<FastpathCopy>, OpDef> FastpathCopyInst(m, "FastpathCopy");
FastpathCopyInst
.def(py::init<>());
py::class_<GammaRNG, std::shared_ptr<GammaRNG>, OpDef> GammaRNGInst(m, "GammaRNG");
GammaRNGInst
.def(py::init<uint64_t, size_t, std::string>(), py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("seed", &GammaRNG::seed)
.def_readwrite("handle", &GammaRNG::handle);
py::class_<GaussianRNG, std::shared_ptr<GaussianRNG>, OpDef> GaussianRNGInst(m, "GaussianRNG");
GaussianRNGInst
.def(py::init<uint64_t, float, float, ::megdnn::DType, size_t, std::string>(), py::arg("seed") = 0, py::arg("mean") = 0, py::arg("std") = 1, py::arg("dtype") = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32), py::arg("handle"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("seed", &GaussianRNG::seed)
.def_readwrite("mean", &GaussianRNG::mean)
.def_readwrite("std", &GaussianRNG::std)
.def_readwrite("dtype", &GaussianRNG::dtype)
.def_readwrite("handle", &GaussianRNG::handle);
py::class_<GetVarShape, std::shared_ptr<GetVarShape>, OpDef> GetVarShapeInst(m, "GetVarShape");
GetVarShapeInst
.def(py::init<int32_t, std::string>(), py::arg("axis") = ::megdnn::param::OptionalAxisV1::INVALID_AXIS, py::arg("scope") = {})
.def_readwrite("axis", &GetVarShape::axis);
py::class_<GroupLocal, std::shared_ptr<GroupLocal>, OpDef> GroupLocalInst(m, "GroupLocal");
GroupLocalInst.attr("Mode") = BatchConvBiasInst.attr("Mode");
GroupLocalInst.attr("Sparse") = BatchConvBiasInst.attr("Sparse");
GroupLocalInst.attr("Format") = AdaptivePoolingInst.attr("Format");
GroupLocalInst.attr("ComputeMode") = BatchConvBiasInst.attr("ComputeMode");
GroupLocalInst
.def(py::init<::megdnn::param::Convolution::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Convolution::Sparse, ::megdnn::param::Convolution::Format, ::megdnn::param::Convolution::ComputeMode, std::string>(), py::arg("mode") = ::megdnn::param::Convolution::Mode::CROSS_CORRELATION, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("sparse") = ::megdnn::param::Convolution::Sparse::DENSE, py::arg("format") = ::megdnn::param::Convolution::Format::NCHW, py::arg("compute_mode") = ::megdnn::param::Convolution::ComputeMode::DEFAULT, py::arg("scope") = {})
.def_readwrite("mode", &GroupLocal::mode)
.def_readwrite("pad_h", &GroupLocal::pad_h)
.def_readwrite("pad_w", &GroupLocal::pad_w)
.def_readwrite("stride_h", &GroupLocal::stride_h)
.def_readwrite("stride_w", &GroupLocal::stride_w)
.def_readwrite("dilate_h", &GroupLocal::dilate_h)
.def_readwrite("dilate_w", &GroupLocal::dilate_w)
.def_readwrite("sparse", &GroupLocal::sparse)
.def_readwrite("format", &GroupLocal::format)
.def_readwrite("compute_mode", &GroupLocal::compute_mode);
py::class_<Identity, std::shared_ptr<Identity>, OpDef> IdentityInst(m, "Identity");
IdentityInst
.def(py::init<>());
py::class_<Images2Neibs, std::shared_ptr<Images2Neibs>, OpDef> Images2NeibsInst(m, "Images2Neibs");
Images2NeibsInst
.def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, std::string>(), py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("window_h") = 3, py::arg("window_w") = 3, py::arg("scope") = {})
.def_readwrite("pad_h", &Images2Neibs::pad_h)
.def_readwrite("pad_w", &Images2Neibs::pad_w)
.def_readwrite("stride_h", &Images2Neibs::stride_h)
.def_readwrite("stride_w", &Images2Neibs::stride_w)
.def_readwrite("dilate_h", &Images2Neibs::dilate_h)
.def_readwrite("dilate_w", &Images2Neibs::dilate_w)
.def_readwrite("window_h", &Images2Neibs::window_h)
.def_readwrite("window_w", &Images2Neibs::window_w);
py::class_<IncrMeshIndexing, std::shared_ptr<IncrMeshIndexing>, OpDef> IncrMeshIndexingInst(m, "IncrMeshIndexing");
IncrMeshIndexingInst
.def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("items", &IncrMeshIndexing::items);
py::class_<IncrSubtensor, std::shared_ptr<IncrSubtensor>, OpDef> IncrSubtensorInst(m, "IncrSubtensor");
IncrSubtensorInst
.def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("items", &IncrSubtensor::items);
py::class_<IndexingIncrMultiAxisVec, std::shared_ptr<IndexingIncrMultiAxisVec>, OpDef> IndexingIncrMultiAxisVecInst(m, "IndexingIncrMultiAxisVec");
IndexingIncrMultiAxisVecInst
.def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("items", &IndexingIncrMultiAxisVec::items);
py::class_<IndexingMultiAxisVec, std::shared_ptr<IndexingMultiAxisVec>, OpDef> IndexingMultiAxisVecInst(m, "IndexingMultiAxisVec");
IndexingMultiAxisVecInst
.def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("items", &IndexingMultiAxisVec::items);
py::class_<IndexingOneHot, std::shared_ptr<IndexingOneHot>, OpDef> IndexingOneHotInst(m, "IndexingOneHot");
IndexingOneHotInst
.def(py::init<int32_t, int32_t, std::string>(), py::arg("axis") = 0, py::arg("ndim"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("axis", &IndexingOneHot::axis)
.def_readwrite("ndim", &IndexingOneHot::ndim);
py::class_<IndexingSetMultiAxisVec, std::shared_ptr<IndexingSetMultiAxisVec>, OpDef> IndexingSetMultiAxisVecInst(m, "IndexingSetMultiAxisVec");
IndexingSetMultiAxisVecInst
.def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("items", &IndexingSetMultiAxisVec::items);
py::class_<IndexingSetOneHot, std::shared_ptr<IndexingSetOneHot>, OpDef> IndexingSetOneHotInst(m, "IndexingSetOneHot");
IndexingSetOneHotInst
.def(py::init<int32_t, int32_t, std::string>(), py::arg("axis") = 0, py::arg("ndim"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("axis", &IndexingSetOneHot::axis)
.def_readwrite("ndim", &IndexingSetOneHot::ndim);
py::class_<InplaceAdd, std::shared_ptr<InplaceAdd>, OpDef> InplaceAddInst(m, "InplaceAdd");
InplaceAddInst
.def(py::init<>());
py::class_<LAMBUpdate, std::shared_ptr<LAMBUpdate>, OpDef> LAMBUpdateInst(m, "LAMBUpdate");
LAMBUpdateInst
.def(py::init<float, float, float, float, float, float, bool, bool, std::string>(), py::arg("beta_1") = 1.f, py::arg("beta_2") = 1.f, py::arg("step") = 1.f, py::arg("lr") = 1.f, py::arg("weight_decay") = 1.f, py::arg("eps") = 1.f, py::arg("bias_correction") = true, py::arg("always_adapt") = false, py::arg("scope") = {})
.def_readwrite("beta_1", &LAMBUpdate::beta_1)
.def_readwrite("beta_2", &LAMBUpdate::beta_2)
.def_readwrite("step", &LAMBUpdate::step)
.def_readwrite("lr", &LAMBUpdate::lr)
.def_readwrite("weight_decay", &LAMBUpdate::weight_decay)
.def_readwrite("eps", &LAMBUpdate::eps)
.def_readwrite("bias_correction", &LAMBUpdate::bias_correction)
.def_readwrite("always_adapt", &LAMBUpdate::always_adapt);
py::class_<LRN, std::shared_ptr<LRN>, OpDef> LRNInst(m, "LRN");
LRNInst
.def(py::init<uint32_t, float, float, float, std::string>(), py::arg("n") = 5, py::arg("k") = 2.f, py::arg("alpha") = 1e-4f, py::arg("beta") = 0.75f, py::arg("scope") = {})
.def_readwrite("n", &LRN::n)
.def_readwrite("k", &LRN::k)
.def_readwrite("alpha", &LRN::alpha)
.def_readwrite("beta", &LRN::beta);
py::class_<LSQ, std::shared_ptr<LSQ>, OpDef> LSQInst(m, "LSQ");
LSQInst
.def(py::init<int32_t, int32_t, std::string>(), py::arg("qmin") = -2147483648, py::arg("qmax") = 2147483647, py::arg("scope") = {})
.def_readwrite("qmin", &LSQ::qmin)
.def_readwrite("qmax", &LSQ::qmax);
py::class_<LSTM, std::shared_ptr<LSTM>, OpDef> LSTMInst(m, "LSTM");
LSTMInst.attr("FwdMode") = BatchNormInst.attr("FwdMode");
LSTMInst
.def(py::init<uint32_t, bool, bool, uint32_t, uint32_t, float, ::megdnn::param::LSTM::FwdMode, std::string>(), py::arg("num_layers") = 1, py::arg("bidirectional") = false, py::arg("bias") = true, py::arg("hidden_size") = 128, py::arg("proj_size") = 0, py::arg("dropout") = 0.f, py::arg("fwd_mode") = ::megdnn::param::LSTM::FwdMode::TRAINING, py::arg("scope") = {})
.def_readwrite("num_layers", &LSTM::num_layers)
.def_readwrite("bidirectional", &LSTM::bidirectional)
.def_readwrite("bias", &LSTM::bias)
.def_readwrite("hidden_size", &LSTM::hidden_size)
.def_readwrite("proj_size", &LSTM::proj_size)
.def_readwrite("dropout", &LSTM::dropout)
.def_readwrite("fwd_mode", &LSTM::fwd_mode);
py::class_<LSTMCell, std::shared_ptr<LSTMCell>, OpDef> LSTMCellInst(m, "LSTMCell");
LSTMCellInst
.def(py::init<>());
py::class_<LayerNorm, std::shared_ptr<LayerNorm>, OpDef> LayerNormInst(m, "LayerNorm");
LayerNormInst
.def(py::init<bool, float, uint64_t, uint64_t, std::string>(), py::arg("affine") = true, py::arg("eps") = 1e-5f, py::arg("normalized_dim") = 1, py::arg("normalized_size") = 1, py::arg("scope") = {})
.def_readwrite("affine", &LayerNorm::affine)
.def_readwrite("eps", &LayerNorm::eps)
.def_readwrite("normalized_dim", &LayerNorm::normalized_dim)
.def_readwrite("normalized_size", &LayerNorm::normalized_size);
py::class_<Linspace, std::shared_ptr<Linspace>, OpDef> LinspaceInst(m, "Linspace");
LinspaceInst
.def(py::init<bool, ::mgb::CompNode, std::string>(), py::arg("endpoint") = true, py::arg("comp_node"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("endpoint", &Linspace::endpoint)
.def_readwrite("comp_node", &Linspace::comp_node);
py::class_<MagicMindRuntime, std::shared_ptr<MagicMindRuntime>, OpDef> MagicMindRuntimeInst(m, "MagicMindRuntime");
MagicMindRuntimeInst
.def(py::init<std::string, size_t, std::string>(), py::arg("buf"), py::arg("buf_size"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("buf", &MagicMindRuntime::buf)
.def_readwrite("buf_size", &MagicMindRuntime::buf_size);
py::class_<MatrixInverse, std::shared_ptr<MatrixInverse>, OpDef> MatrixInverseInst(m, "MatrixInverse");
MatrixInverseInst
.def(py::init<>());
py::class_<MatrixMul, std::shared_ptr<MatrixMul>, OpDef> MatrixMulInst(m, "MatrixMul");
MatrixMulInst.attr("ComputeMode") = BatchedMatrixMulInst.attr("ComputeMode");
MatrixMulInst.attr("Format") = BatchedMatrixMulInst.attr("Format");
MatrixMulInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
MatrixMulInst
.def(py::init<bool, bool, ::megdnn::param::MatrixMul::ComputeMode, ::megdnn::param::MatrixMul::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, uint32_t, uint32_t, std::string>(), py::arg("transposeA") = false, py::arg("transposeB") = false, py::arg("compute_mode") = ::megdnn::param::MatrixMul::ComputeMode::DEFAULT, py::arg("format") = ::megdnn::param::MatrixMul::Format::DEFAULT, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("dimA"), py::arg("dimB"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("transposeA", &MatrixMul::transposeA)
.def_readwrite("transposeB", &MatrixMul::transposeB)
.def_readwrite("compute_mode", &MatrixMul::compute_mode)
.def_readwrite("format", &MatrixMul::format)
.def_readwrite("strategy", &MatrixMul::strategy)
.def_readwrite("workspace_limit", &MatrixMul::workspace_limit)
.def_readwrite("dimA", &MatrixMul::dimA)
.def_readwrite("dimB", &MatrixMul::dimB);
py::class_<MeshIndexing, std::shared_ptr<MeshIndexing>, OpDef> MeshIndexingInst(m, "MeshIndexing");
MeshIndexingInst
.def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("items", &MeshIndexing::items);
py::class_<NMSKeep, std::shared_ptr<NMSKeep>, OpDef> NMSKeepInst(m, "NMSKeep");
NMSKeepInst
.def(py::init<float, uint32_t, std::string>(), py::arg("iou_thresh"), py::arg("max_output"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("iou_thresh", &NMSKeep::iou_thresh)
.def_readwrite("max_output", &NMSKeep::max_output);
py::class_<NvOf, std::shared_ptr<NvOf>, OpDef> NvOfInst(m, "NvOf");
NvOfInst
.def(py::init<uint32_t, std::string>(), py::arg("precision") = 1, py::arg("scope") = {})
.def_readwrite("precision", &NvOf::precision);
py::class_<Padding, std::shared_ptr<Padding>, OpDef> PaddingInst(m, "Padding");
py::enum_<Padding::PaddingMode>(PaddingInst, "PaddingMode")
.value("REPLICATE", Padding::PaddingMode::REPLICATE)
.value("REFLECT", Padding::PaddingMode::REFLECT)
.value("CONSTANT", Padding::PaddingMode::CONSTANT)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "REPLICATE") return Padding::PaddingMode::REPLICATE;
if (str == "REFLECT") return Padding::PaddingMode::REFLECT;
if (str == "CONSTANT") return Padding::PaddingMode::CONSTANT;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, Padding::PaddingMode>();
PaddingInst
.def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, float, ::megdnn::param::Padding::PaddingMode, std::string>(), py::arg("front_offset_dim0") = 0, py::arg("front_offset_dim1") = 0, py::arg("front_offset_dim2") = 0, py::arg("front_offset_dim3") = 0, py::arg("front_offset_dim4") = 0, py::arg("front_offset_dim5") = 0, py::arg("front_offset_dim6") = 0, py::arg("back_offset_dim0") = 0, py::arg("back_offset_dim1") = 0, py::arg("back_offset_dim2") = 0, py::arg("back_offset_dim3") = 0, py::arg("back_offset_dim4") = 0, py::arg("back_offset_dim5") = 0, py::arg("back_offset_dim6") = 0, py::arg("padding_val") = 0, py::arg("padding_mode") = ::megdnn::param::Padding::PaddingMode::CONSTANT, py::arg("scope") = {})
.def_readwrite("front_offset_dim0", &Padding::front_offset_dim0)
.def_readwrite("front_offset_dim1", &Padding::front_offset_dim1)
.def_readwrite("front_offset_dim2", &Padding::front_offset_dim2)
.def_readwrite("front_offset_dim3", &Padding::front_offset_dim3)
.def_readwrite("front_offset_dim4", &Padding::front_offset_dim4)
.def_readwrite("front_offset_dim5", &Padding::front_offset_dim5)
.def_readwrite("front_offset_dim6", &Padding::front_offset_dim6)
.def_readwrite("back_offset_dim0", &Padding::back_offset_dim0)
.def_readwrite("back_offset_dim1", &Padding::back_offset_dim1)
.def_readwrite("back_offset_dim2", &Padding::back_offset_dim2)
.def_readwrite("back_offset_dim3", &Padding::back_offset_dim3)
.def_readwrite("back_offset_dim4", &Padding::back_offset_dim4)
.def_readwrite("back_offset_dim5", &Padding::back_offset_dim5)
.def_readwrite("back_offset_dim6", &Padding::back_offset_dim6)
.def_readwrite("padding_val", &Padding::padding_val)
.def_readwrite("padding_mode", &Padding::padding_mode);
py::class_<ParamPackConcat, std::shared_ptr<ParamPackConcat>, OpDef> ParamPackConcatInst(m, "ParamPackConcat");
ParamPackConcatInst
.def(py::init<std::vector<int32_t>, std::string>(), py::arg("offsets"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("offsets", &ParamPackConcat::offsets);
py::class_<ParamPackSplit, std::shared_ptr<ParamPackSplit>, OpDef> ParamPackSplitInst(m, "ParamPackSplit");
ParamPackSplitInst
.def(py::init<std::vector<int32_t>, std::vector<std::vector<size_t>>, std::string>(), py::arg("offsets"), py::arg("shapes"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("offsets", &ParamPackSplit::offsets)
.def_readwrite("shapes", &ParamPackSplit::shapes);
py::class_<PermutationRNG, std::shared_ptr<PermutationRNG>, OpDef> PermutationRNGInst(m, "PermutationRNG");
PermutationRNGInst
.def(py::init<uint64_t, ::megdnn::DType, size_t, std::string>(), py::arg("seed") = 0, py::arg("dtype") = megdnn::DType::from_enum(megdnn::DTypeEnum::Int32), py::arg("handle"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("seed", &PermutationRNG::seed)
.def_readwrite("dtype", &PermutationRNG::dtype)
.def_readwrite("handle", &PermutationRNG::handle);
py::class_<PixelShuffle, std::shared_ptr<PixelShuffle>, OpDef> PixelShuffleInst(m, "PixelShuffle");
PixelShuffleInst
.def(py::init<int32_t, std::string>(), py::arg("factor"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("factor", &PixelShuffle::factor);
py::class_<PixelShuffleBackward, std::shared_ptr<PixelShuffleBackward>, OpDef> PixelShuffleBackwardInst(m, "PixelShuffleBackward");
PixelShuffleBackwardInst
.def(py::init<int32_t, std::string>(), py::arg("factor"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("factor", &PixelShuffleBackward::factor);
py::class_<PoissonRNG, std::shared_ptr<PoissonRNG>, OpDef> PoissonRNGInst(m, "PoissonRNG");
PoissonRNGInst
.def(py::init<uint64_t, size_t, std::string>(), py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("seed", &PoissonRNG::seed)
.def_readwrite("handle", &PoissonRNG::handle);
py::class_<Pooling, std::shared_ptr<Pooling>, OpDef> PoolingInst(m, "Pooling");
PoolingInst.attr("Mode") = AdaptivePoolingInst.attr("Mode");
PoolingInst.attr("Format") = AdaptivePoolingInst.attr("Format");
PoolingInst.attr("Strategy") = BatchConvBiasInst.attr("Strategy");
PoolingInst
.def(py::init<::megdnn::param::Pooling::Mode, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, ::megdnn::param::Pooling::Format, ::megdnn::param::ExecutionPolicy::Strategy, uint64_t, std::string>(), py::arg("mode") = ::megdnn::param::Pooling::Mode::MAX, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 2, py::arg("stride_w") = 2, py::arg("window_h") = 2, py::arg("window_w") = 2, py::arg("format") = ::megdnn::param::Pooling::Format::NCHW, py::arg("strategy") = static_cast<::megdnn::param::ExecutionPolicy::Strategy>(1), py::arg("workspace_limit") = 18446744073709551615ull, py::arg("scope") = {})
.def_readwrite("mode", &Pooling::mode)
.def_readwrite("pad_h", &Pooling::pad_h)
.def_readwrite("pad_w", &Pooling::pad_w)
.def_readwrite("stride_h", &Pooling::stride_h)
.def_readwrite("stride_w", &Pooling::stride_w)
.def_readwrite("window_h", &Pooling::window_h)
.def_readwrite("window_w", &Pooling::window_w)
.def_readwrite("format", &Pooling::format)
.def_readwrite("strategy", &Pooling::strategy)
.def_readwrite("workspace_limit", &Pooling::workspace_limit);
py::class_<RNN, std::shared_ptr<RNN>, OpDef> RNNInst(m, "RNN");
py::enum_<RNN::NonlineMode>(RNNInst, "NonlineMode")
.value("IDENTITY", RNN::NonlineMode::IDENTITY)
.value("RELU", RNN::NonlineMode::RELU)
.value("TANH", RNN::NonlineMode::TANH)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "IDENTITY") return RNN::NonlineMode::IDENTITY;
if (str == "RELU") return RNN::NonlineMode::RELU;
if (str == "TANH") return RNN::NonlineMode::TANH;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, RNN::NonlineMode>();
RNNInst.attr("FwdMode") = BatchNormInst.attr("FwdMode");
RNNInst
.def(py::init<uint32_t, bool, bool, uint32_t, float, ::megdnn::param::RNN::NonlineMode, ::megdnn::param::RNN::FwdMode, std::string>(), py::arg("num_layers") = 1, py::arg("bidirectional") = false, py::arg("bias") = true, py::arg("hidden_size") = 128, py::arg("dropout") = 0.f, py::arg("nonlineMode") = ::megdnn::param::RNN::NonlineMode::IDENTITY, py::arg("fwd_mode") = ::megdnn::param::RNN::FwdMode::TRAINING, py::arg("scope") = {})
.def_readwrite("num_layers", &RNN::num_layers)
.def_readwrite("bidirectional", &RNN::bidirectional)
.def_readwrite("bias", &RNN::bias)
.def_readwrite("hidden_size", &RNN::hidden_size)
.def_readwrite("dropout", &RNN::dropout)
.def_readwrite("nonlineMode", &RNN::nonlineMode)
.def_readwrite("fwd_mode", &RNN::fwd_mode);
py::class_<RNNCell, std::shared_ptr<RNNCell>, OpDef> RNNCellInst(m, "RNNCell");
RNNCellInst.attr("NonlineMode") = RNNInst.attr("NonlineMode");
RNNCellInst
.def(py::init<::megdnn::param::RNNCell::NonlineMode, std::string>(), py::arg("nonlineMode") = ::megdnn::param::RNNCell::NonlineMode::IDENTITY, py::arg("scope") = {})
.def_readwrite("nonlineMode", &RNNCell::nonlineMode);
py::class_<ROIAlign, std::shared_ptr<ROIAlign>, OpDef> ROIAlignInst(m, "ROIAlign");
py::enum_<ROIAlign::Mode>(ROIAlignInst, "Mode")
.value("MAX", ROIAlign::Mode::MAX)
.value("AVERAGE", ROIAlign::Mode::AVERAGE)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "MAX") return ROIAlign::Mode::MAX;
if (str == "AVERAGE") return ROIAlign::Mode::AVERAGE;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, ROIAlign::Mode>();
ROIAlignInst.attr("Format") = AdaptivePoolingInst.attr("Format");
ROIAlignInst
.def(py::init<::megdnn::param::ROIAlign::Mode, ::megdnn::param::ROIAlign::Format, float, float, uint32_t, uint32_t, uint32_t, uint32_t, std::string>(), py::arg("mode") = ::megdnn::param::ROIAlign::Mode::MAX, py::arg("format") = ::megdnn::param::ROIAlign::Format::NCHW, py::arg("spatial_scale") = 1.0, py::arg("offset") = 0.0, py::arg("pooled_height") = 1, py::arg("pooled_width") = 1, py::arg("sample_height") = 2, py::arg("sample_width") = 2, py::arg("scope") = {})
.def_readwrite("mode", &ROIAlign::mode)
.def_readwrite("format", &ROIAlign::format)
.def_readwrite("spatial_scale", &ROIAlign::spatial_scale)
.def_readwrite("offset", &ROIAlign::offset)
.def_readwrite("pooled_height", &ROIAlign::pooled_height)
.def_readwrite("pooled_width", &ROIAlign::pooled_width)
.def_readwrite("sample_height", &ROIAlign::sample_height)
.def_readwrite("sample_width", &ROIAlign::sample_width);
py::class_<ROIPooling, std::shared_ptr<ROIPooling>, OpDef> ROIPoolingInst(m, "ROIPooling");
py::enum_<ROIPooling::Mode>(ROIPoolingInst, "Mode")
.value("MAX", ROIPooling::Mode::MAX)
.value("AVERAGE", ROIPooling::Mode::AVERAGE)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "MAX") return ROIPooling::Mode::MAX;
if (str == "AVERAGE") return ROIPooling::Mode::AVERAGE;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, ROIPooling::Mode>();
ROIPoolingInst
.def(py::init<::megdnn::param::ROIPooling::Mode, float, std::string>(), py::arg("mode") = ::megdnn::param::ROIPooling::Mode::MAX, py::arg("scale") = 1.f, py::arg("scope") = {})
.def_readwrite("mode", &ROIPooling::mode)
.def_readwrite("scale", &ROIPooling::scale);
py::class_<Reduce, std::shared_ptr<Reduce>, OpDef> ReduceInst(m, "Reduce");
py::enum_<Reduce::Mode>(ReduceInst, "Mode")
.value("SUM", Reduce::Mode::SUM)
.value("SUM_SQR", Reduce::Mode::SUM_SQR)
.value("PRODUCT", Reduce::Mode::PRODUCT)
.value("MIN", Reduce::Mode::MIN)
.value("MAX", Reduce::Mode::MAX)
.value("MEAN", Reduce::Mode::MEAN)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "SUM") return Reduce::Mode::SUM;
if (str == "SUM_SQR") return Reduce::Mode::SUM_SQR;
if (str == "PRODUCT") return Reduce::Mode::PRODUCT;
if (str == "MIN") return Reduce::Mode::MIN;
if (str == "MAX") return Reduce::Mode::MAX;
if (str == "MEAN") return Reduce::Mode::MEAN;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, Reduce::Mode>();
py::enum_<Reduce::DataType>(ReduceInst, "DataType")
.value("DEFAULT", Reduce::DataType::DEFAULT)
.value("FLOAT_IO16xC32", Reduce::DataType::FLOAT_IO16xC32)
.value("FLOAT_O32xC32", Reduce::DataType::FLOAT_O32xC32)
.value("FLOAT_O16xC32", Reduce::DataType::FLOAT_O16xC32)
.value("QUINT_I8xO32", Reduce::DataType::QUINT_I8xO32)
.value("QINT_I8xO32", Reduce::DataType::QINT_I8xO32)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "DEFAULT") return Reduce::DataType::DEFAULT;
if (str == "FLOAT_IO16xC32") return Reduce::DataType::FLOAT_IO16xC32;
if (str == "FLOAT_O32xC32") return Reduce::DataType::FLOAT_O32xC32;
if (str == "FLOAT_O16xC32") return Reduce::DataType::FLOAT_O16xC32;
if (str == "QUINT_I8xO32") return Reduce::DataType::QUINT_I8xO32;
if (str == "QINT_I8xO32") return Reduce::DataType::QINT_I8xO32;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, Reduce::DataType>();
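// "axis" defaults to 2147483647 (INT32_MAX), which presumably serves as the
// "no axis given / reduce over all elements" sentinel on the C++ side.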
ReduceInst
.def(py::init<::megdnn::param::Reduce::Mode, int32_t, ::megdnn::param::Reduce::DataType, bool, std::string>(), py::arg("mode") = ::megdnn::param::Reduce::Mode::SUM, py::arg("axis") = 2147483647, py::arg("data_type") = ::megdnn::param::Reduce::DataType::DEFAULT, py::arg("keepdim") = true, py::arg("scope") = {})
.def_readwrite("mode", &Reduce::mode)
.def_readwrite("axis", &Reduce::axis)
.def_readwrite("data_type", &Reduce::data_type)
.def_readwrite("keepdim", &Reduce::keepdim);
py::class_<Remap, std::shared_ptr<Remap>, OpDef> RemapInst(m, "Remap");
py::enum_<Remap::InterpolationMode>(RemapInst, "InterpolationMode")
.value("NEAREST", Remap::InterpolationMode::NEAREST)
.value("LINEAR", Remap::InterpolationMode::LINEAR)
.value("AREA", Remap::InterpolationMode::AREA)
.value("CUBIC", Remap::InterpolationMode::CUBIC)
.value("LANCZOS4", Remap::InterpolationMode::LANCZOS4)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "NEAREST") return Remap::InterpolationMode::NEAREST;
if (str == "LINEAR") return Remap::InterpolationMode::LINEAR;
if (str == "AREA") return Remap::InterpolationMode::AREA;
if (str == "CUBIC") return Remap::InterpolationMode::CUBIC;
if (str == "LANCZOS4") return Remap::InterpolationMode::LANCZOS4;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, Remap::InterpolationMode>();
py::enum_<Remap::BorderMode>(RemapInst, "BorderMode")
.value("REPLICATE", Remap::BorderMode::REPLICATE)
.value("REFLECT", Remap::BorderMode::REFLECT)
.value("REFLECT_101", Remap::BorderMode::REFLECT_101)
.value("WRAP", Remap::BorderMode::WRAP)
.value("CONSTANT", Remap::BorderMode::CONSTANT)
.value("TRANSPARENT", Remap::BorderMode::TRANSPARENT)
.value("ISOLATED", Remap::BorderMode::ISOLATED)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "REPLICATE") return Remap::BorderMode::REPLICATE;
if (str == "REFLECT") return Remap::BorderMode::REFLECT;
if (str == "REFLECT_101") return Remap::BorderMode::REFLECT_101;
if (str == "WRAP") return Remap::BorderMode::WRAP;
if (str == "CONSTANT") return Remap::BorderMode::CONSTANT;
if (str == "TRANSPARENT") return Remap::BorderMode::TRANSPARENT;
if (str == "ISOLATED") return Remap::BorderMode::ISOLATED;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, Remap::BorderMode>();
RemapInst.attr("Format") = AdaptivePoolingInst.attr("Format");
RemapInst
.def(py::init<::megdnn::param::Remap::InterpolationMode, ::megdnn::param::Remap::BorderMode, ::megdnn::param::Remap::Format, float, std::string>(), py::arg("imode") = ::megdnn::param::Remap::InterpolationMode::LINEAR, py::arg("border_type") = ::megdnn::param::Remap::BorderMode::REPLICATE, py::arg("format") = ::megdnn::param::Remap::Format::NHWC, py::arg("scalar") = 0.f, py::arg("scope") = {})
.def_readwrite("imode", &Remap::imode)
.def_readwrite("border_type", &Remap::border_type)
.def_readwrite("format", &Remap::format)
.def_readwrite("scalar", &Remap::scalar);
py::class_<RemoteRecv, std::shared_ptr<RemoteRecv>, OpDef> RemoteRecvInst(m, "RemoteRecv");
RemoteRecvInst
.def(py::init<std::string, std::string, uint32_t, uint32_t, ::mgb::CompNode, std::vector<int32_t>, ::megdnn::DType, std::string, std::string>(), py::arg("key"), py::arg("addr"), py::arg("port"), py::arg("rank_from"), py::arg("cn"), py::arg("shape"), py::arg("dtype"), py::arg("backend"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("key", &RemoteRecv::key)
.def_readwrite("addr", &RemoteRecv::addr)
.def_readwrite("port", &RemoteRecv::port)
.def_readwrite("rank_from", &RemoteRecv::rank_from)
.def_readwrite("cn", &RemoteRecv::cn)
.def_readwrite("shape", &RemoteRecv::shape)
.def_readwrite("dtype", &RemoteRecv::dtype)
.def_readwrite("backend", &RemoteRecv::backend);
py::class_<RemoteSend, std::shared_ptr<RemoteSend>, OpDef> RemoteSendInst(m, "RemoteSend");
RemoteSendInst
.def(py::init<std::string, std::string, uint32_t, uint32_t, std::string, std::string>(), py::arg("key"), py::arg("addr"), py::arg("port"), py::arg("rank_to"), py::arg("backend"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("key", &RemoteSend::key)
.def_readwrite("addr", &RemoteSend::addr)
.def_readwrite("port", &RemoteSend::port)
.def_readwrite("rank_to", &RemoteSend::rank_to)
.def_readwrite("backend", &RemoteSend::backend);
py::class_<RemoveAxis, std::shared_ptr<RemoveAxis>, OpDef> RemoveAxisInst(m, "RemoveAxis");
RemoveAxisInst
.def(py::init<std::vector<int32_t>, std::string>(), py::arg("axis"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("axis", &RemoveAxis::axis);
py::class_<Reshape, std::shared_ptr<Reshape>, OpDef> ReshapeInst(m, "Reshape");
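// INVALID_AXIS marks "no axis supplied"; the target shape then presumably
// comes entirely from the "shape" field (or a shape tensor input).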
ReshapeInst
.def(py::init<int32_t, std::vector<int32_t>, std::string>(), py::arg("axis") = ::megdnn::param::OptionalAxisV1::INVALID_AXIS, py::arg("shape"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("axis", &Reshape::axis)
.def_readwrite("shape", &Reshape::shape);
py::class_<Resize, std::shared_ptr<Resize>, OpDef> ResizeInst(m, "Resize");
ResizeInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
ResizeInst.attr("Format") = AdaptivePoolingInst.attr("Format");
ResizeInst
.def(py::init<::megdnn::param::Resize::InterpolationMode, ::megdnn::param::Resize::Format, std::string>(), py::arg("imode") = ::megdnn::param::Resize::InterpolationMode::LINEAR, py::arg("format") = ::megdnn::param::Resize::Format::NHWC, py::arg("scope") = {})
.def_readwrite("imode", &Resize::imode)
.def_readwrite("format", &Resize::format);
py::class_<SVD, std::shared_ptr<SVD>, OpDef> SVDInst(m, "SVD");
SVDInst
.def(py::init<bool, bool, std::string>(), py::arg("full_matrices") = false, py::arg("compute_uv") = true, py::arg("scope") = {})
.def_readwrite("full_matrices", &SVD::full_matrices)
.def_readwrite("compute_uv", &SVD::compute_uv);
py::class_<SetMeshIndexing, std::shared_ptr<SetMeshIndexing>, OpDef> SetMeshIndexingInst(m, "SetMeshIndexing");
SetMeshIndexingInst
.def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("items", &SetMeshIndexing::items);
py::class_<SetSubtensor, std::shared_ptr<SetSubtensor>, OpDef> SetSubtensorInst(m, "SetSubtensor");
SetSubtensorInst
.def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("items", &SetSubtensor::items);
py::class_<ShuffleRNG, std::shared_ptr<ShuffleRNG>, OpDef> ShuffleRNGInst(m, "ShuffleRNG");
ShuffleRNGInst
.def(py::init<uint64_t, size_t, std::string>(), py::arg("seed") = 0, py::arg("handle"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("seed", &ShuffleRNG::seed)
.def_readwrite("handle", &ShuffleRNG::handle);
py::class_<SlidingWindowTranspose, std::shared_ptr<SlidingWindowTranspose>, OpDef> SlidingWindowTransposeInst(m, "SlidingWindowTranspose");
SlidingWindowTransposeInst
.def(py::init<uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, std::string>(), py::arg("out_h") = 0, py::arg("out_w") = 0, py::arg("pad_h") = 0, py::arg("pad_w") = 0, py::arg("stride_h") = 1, py::arg("stride_w") = 1, py::arg("dilate_h") = 1, py::arg("dilate_w") = 1, py::arg("window_h") = 3, py::arg("window_w") = 3, py::arg("scope") = {})
.def_readwrite("out_h", &SlidingWindowTranspose::out_h)
.def_readwrite("out_w", &SlidingWindowTranspose::out_w)
.def_readwrite("pad_h", &SlidingWindowTranspose::pad_h)
.def_readwrite("pad_w", &SlidingWindowTranspose::pad_w)
.def_readwrite("stride_h", &SlidingWindowTranspose::stride_h)
.def_readwrite("stride_w", &SlidingWindowTranspose::stride_w)
.def_readwrite("dilate_h", &SlidingWindowTranspose::dilate_h)
.def_readwrite("dilate_w", &SlidingWindowTranspose::dilate_w)
.def_readwrite("window_h", &SlidingWindowTranspose::window_h)
.def_readwrite("window_w", &SlidingWindowTranspose::window_w);
py::class_<Softmax, std::shared_ptr<Softmax>, OpDef> SoftmaxInst(m, "Softmax");
SoftmaxInst
.def(py::init<int32_t, std::string>(), py::arg("axis") = -1, py::arg("scope") = {})
.def_readwrite("axis", &Softmax::axis);
py::class_<Split, std::shared_ptr<Split>, OpDef> SplitInst(m, "Split");
SplitInst
.def(py::init<int32_t, int32_t, std::string>(), py::arg("axis"), py::arg("nsections"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("axis", &Split::axis)
.def_readwrite("nsections", &Split::nsections);
py::class_<Subtensor, std::shared_ptr<Subtensor>, OpDef> SubtensorInst(m, "Subtensor");
SubtensorInst
.def(py::init<std::vector<std::tuple<int8_t, bool, bool, bool, bool>>, std::string>(), py::arg("items"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("items", &Subtensor::items);
py::class_<TQT, std::shared_ptr<TQT>, OpDef> TQTInst(m, "TQT");
TQTInst
.def(py::init<int32_t, int32_t, std::string>(), py::arg("qmin") = -2147483648, py::arg("qmax") = 2147483647, py::arg("scope") = {})
.def_readwrite("qmin", &TQT::qmin)
.def_readwrite("qmax", &TQT::qmax);
py::class_<TensorRTRuntime, std::shared_ptr<TensorRTRuntime>, OpDef> TensorRTRuntimeInst(m, "TensorRTRuntime");
TensorRTRuntimeInst
.def(py::init<std::string, size_t, std::string>(), py::arg("buf"), py::arg("buf_size"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("buf", &TensorRTRuntime::buf)
.def_readwrite("buf_size", &TensorRTRuntime::buf_size);
py::class_<TopK, std::shared_ptr<TopK>, OpDef> TopKInst(m, "TopK");
py::enum_<TopK::Mode>(TopKInst, "Mode")
.value("KTH_ONLY", TopK::Mode::KTH_ONLY)
.value("VALUE_IDX_NOSORT", TopK::Mode::VALUE_IDX_NOSORT)
.value("VALUE_IDX_SORTED", TopK::Mode::VALUE_IDX_SORTED)
.def(py::init([](const std::string& in) {
auto&& str = normalize_enum(in);
if (str == "KTH_ONLY") return TopK::Mode::KTH_ONLY;
if (str == "VALUE_IDX_NOSORT") return TopK::Mode::VALUE_IDX_NOSORT;
if (str == "VALUE_IDX_SORTED") return TopK::Mode::VALUE_IDX_SORTED;
throw py::cast_error("invalid enum value " + in);
}));
py::implicitly_convertible<std::string, TopK::Mode>();
TopKInst
.def(py::init<::megdnn::param::TopK::Mode, std::string>(), py::arg("mode") = ::megdnn::param::TopK::Mode::KTH_ONLY, py::arg("scope") = {})
.def_readwrite("mode", &TopK::mode);
py::class_<TypeCvt, std::shared_ptr<TypeCvt>, OpDef> TypeCvtInst(m, "TypeCvt");
TypeCvtInst
.def(py::init<::megdnn::DType, std::string>(), py::arg("dtype"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("dtype", &TypeCvt::dtype);
py::class_<UniformRNG, std::shared_ptr<UniformRNG>, OpDef> UniformRNGInst(m, "UniformRNG");
UniformRNGInst
.def(py::init<uint64_t, ::megdnn::DType, size_t, std::string>(), py::arg("seed") = 0, py::arg("dtype") = megdnn::DType::from_enum(megdnn::DTypeEnum::Float32), py::arg("handle"), py::arg("scope") = {})
.def(py::init<>())
.def_readwrite("seed", &UniformRNG::seed)
.def_readwrite("dtype", &UniformRNG::dtype)
.def_readwrite("handle", &UniformRNG::handle);
py::class_<WarpAffine, std::shared_ptr<WarpAffine>, OpDef> WarpAffineInst(m, "WarpAffine");
WarpAffineInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
WarpAffineInst.attr("BorderMode") = RemapInst.attr("BorderMode");
WarpAffineInst.attr("Format") = AdaptivePoolingInst.attr("Format");
WarpAffineInst
.def(py::init<::megdnn::param::WarpAffine::InterpolationMode, ::megdnn::param::WarpAffine::BorderMode, float, ::megdnn::param::WarpAffine::Format, std::string>(), py::arg("imode") = ::megdnn::param::WarpAffine::InterpolationMode::LINEAR, py::arg("border_mode") = ::megdnn::param::WarpAffine::BorderMode::REPLICATE, py::arg("border_val") = .0f, py::arg("format") = ::megdnn::param::WarpAffine::Format::NHWC, py::arg("scope") = {})
.def_readwrite("imode", &WarpAffine::imode)
.def_readwrite("border_mode", &WarpAffine::border_mode)
.def_readwrite("border_val", &WarpAffine::border_val)
.def_readwrite("format", &WarpAffine::format);
py::class_<WarpPerspective, std::shared_ptr<WarpPerspective>, OpDef> WarpPerspectiveInst(m, "WarpPerspective");
WarpPerspectiveInst.attr("InterpolationMode") = RemapInst.attr("InterpolationMode");
WarpPerspectiveInst.attr("BorderMode") = RemapInst.attr("BorderMode");
WarpPerspectiveInst.attr("Format") = AdaptivePoolingInst.attr("Format");
WarpPerspectiveInst
.def(py::init<::megdnn::param::WarpPerspective::InterpolationMode, ::megdnn::param::WarpPerspective::BorderMode, ::megdnn::param::WarpPerspective::Format, float, std::string>(), py::arg("imode") = ::megdnn::param::WarpPerspective::InterpolationMode::LINEAR, py::arg("bmode") = ::megdnn::param::WarpPerspective::BorderMode::REPLICATE, py::arg("format") = ::megdnn::param::WarpPerspective::Format::NCHW, py::arg("border_val") = .0f, py::arg("scope") = {})
.def_readwrite("imode", &WarpPerspective::imode)
.def_readwrite("bmode", &WarpPerspective::bmode)
.def_readwrite("format", &WarpPerspective::format)
.def_readwrite("border_val", &WarpPerspective::border_val);
// clang-format on
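# The block below records an md5 digest of the tablegen inputs and the
# pregenerated outputs; checkhash.cmake presumably compares against this
# hash.txt at configure time to decide whether the generated files are stale.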
set(SOURCES
../../dnn/scripts/opr_param_defs.py
../../src/core/include/megbrain/ir/ops.td
generated/opdef.h.inl
generated/opdef.cpp.inl
generated/opdef.py.inl
generated/opdef.cpy.inl
generated/enum_macro.h)
execute_process(COMMAND ${CMAKE_COMMAND} -E md5sum ${SOURCES}
OUTPUT_VARIABLE HASH_CONTENT)
message(STATUS "Generating hash.txt for opdefs")
file(WRITE generated/hash.txt "${HASH_CONTENT}")
......@@ -12,9 +12,9 @@ endif()
# TODO: turn python binding into a static/object library
add_executable(imperative_test ${SOURCES} ${SRCS})
add_dependencies(imperative_test mgb_opdef)
target_include_directories(
imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR}
${CPP_REDIS_INCLUDES})
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include
../src/include ${CPP_REDIS_INCLUDES})
target_link_libraries(imperative_test mgb_opdef_inc)
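# mgb_opdef_inc replaces the direct ${MGB_OPDEF_OUT_DIR} include path here;
# presumably it is an INTERFACE target exporting the include directory of the
# pregenerated opdef headers.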
# Python binding
target_include_directories(
......