diff --git a/cmake/llvm-project.cmake b/cmake/llvm-project.cmake index 561435c59481c944fe274776a469e30922c720ba..94df38794313255e5650b9950178b174de4e14bb 100644 --- a/cmake/llvm-project.cmake +++ b/cmake/llvm-project.cmake @@ -49,6 +49,14 @@ function(external_tablegen_library) install(TARGETS ${_NAME} EXPORT ${MGE_EXPORT_TARGETS}) endfunction() +set(LLVM_LIBS LLVMCore LLVMSupport LLVMX86CodeGen LLVMOrcJIT LLVMNVPTXCodeGen LLVMNVPTXDesc LLVMNVPTXInfo) +set(MLIR_CORE_LIBS MLIRAnalysis MLIRExecutionEngine MLIRIR MLIRParser MLIRPass MLIRSideEffectInterfaces MLIRTransforms) +set(MLIR_DIALECT_LIBS MLIRAsync MLIRAVX512 MLIRGPU MLIRLLVMAVX512 MLIRNVVMIR MLIROpenACC MLIRPDL MLIRPDLInterp MLIRQuant MLIRROCDLIR MLIRSDBM MLIRShape MLIRSPIRV MLIRStandardOpsTransforms) +set(MLIR_CONVERSION_LIBS MLIRAffineToStandard MLIRAVX512ToLLVM MLIRGPUToGPURuntimeTransforms MLIRGPUToNVVMTransforms MLIRSCFToStandard) +set(MLIR_TRANSLATION_LIBS MLIRTargetLLVMIR MLIRTargetNVVMIR) +set(MLIR_LIBS ${MLIR_CORE_LIBS} ${MLIR_DIALECT_LIBS} ${MLIR_CONVERSION_LIBS} ${MLIR_TRANSLATION_LIBS}) +set(MLIR_LLVM_LIBS ${LLVM_LIBS} ${MLIR_LIBS}) + if (MGE_USE_SYSTEM_LIB) find_package(ZLIB) find_package(MLIR REQUIRED CONFIG) @@ -77,9 +85,7 @@ if (MGE_USE_SYSTEM_LIB) endif() endfunction(find_mlir_llvm_lib) - set(MLIR_COMPONENTS MLIRAnalysis;MLIRExecutionEngine;MLIRIR;MLIRParser;MLIRPass;MLIRSideEffectInterfaces;MLIRTargetLLVMIR;MLIRTransforms;MLIRAffineToStandard;MLIRSCFToStandard;MLIRAVX512ToLLVM;MLIRAVX512;MLIRLLVMAVX512;MLIRSDBM;MLIRROCDLIR;MLIRGPU;MLIRQuant;MLIRSPIRV;MLIRNVVMIR;MLIRShape;MLIRGPUToNVVMTransforms;MLIRTargetNVVMIR;MLIRGPUToGPURuntimeTransforms;MLIRStandardOpsTransforms) - - foreach(c ${MLIR_COMPONENTS}) + foreach(c ${MLIR_LIBS}) find_mlir_llvm_lib(${c}) endforeach() return() @@ -119,5 +125,3 @@ set(MLIR_LLVM_INCLUDE_DIR ${PROJECT_BINARY_DIR}/third_party/llvm-project/llvm/tools/mlir/include ) set(MLIR_TABLEGEN_EXE mlir-tblgen) - -set(MLIR_LLVM_LIBS LLVMCore;LLVMSupport;LLVMX86CodeGen;LLVMOrcJIT;LLVMNVPTXCodeGen;LLVMNVPTXDesc;LLVMNVPTXInfo;MLIRAnalysis;MLIRExecutionEngine;MLIRIR;MLIRParser;MLIRPass;MLIRSideEffectInterfaces;MLIRTargetLLVMIR;MLIRTransforms;MLIRAffineToStandard;MLIRSCFToStandard;MLIRAVX512ToLLVM;MLIRAVX512;MLIRLLVMAVX512;MLIRSDBM;MLIRROCDLIR;MLIRGPU;MLIRQuant;MLIRSPIRV;MLIRNVVMIR;MLIRGPUToNVVMTransforms;MLIRShape;MLIRTargetNVVMIR;MLIRGPUToGPURuntimeTransforms;MLIRStandardOpsTransforms) diff --git a/src/jit/impl/mlir/compiler.cpp b/src/jit/impl/mlir/compiler.cpp index 1be70411fb81e49211416b666b5fc71680c30d4e..6f4588d15438450a67504b707cd81e6581c2a355 100644 --- a/src/jit/impl/mlir/compiler.cpp +++ b/src/jit/impl/mlir/compiler.cpp @@ -67,8 +67,8 @@ mlir::OwnedBlob compile_ptx_to_cubin(const std::string ptx, mlir::Location, } std::unique_ptr translate_module_to_nvvm_ir_and_link_device( - Operation* m) { - std::unique_ptr module = mlir::translateModuleToNVVMIR(m); + Operation* m, llvm::LLVMContext& llvmContext, llvm::StringRef name) { + std::unique_ptr module = mlir::translateModuleToNVVMIR(m, llvmContext); auto get_device_path = []() -> std::string { auto cuda_path = getenv("CUDA_BIN_PATH"); std::string device_dir; @@ -223,6 +223,7 @@ void MLIRCompiler::run_lowering_pass(mlir::OwningModuleRef& module, std::unique_ptr MLIRCompiler::do_compile( const InternalGraph& graph, const JITExecutor::Args& args) { mlir::MLIRContext ctx; + ctx.getOrLoadDialect(); ctx.printStackTraceOnDiagnostic(true); ctx.printOpOnDiagnostic(true); diff --git a/src/jit/impl/mlir/ir/dialect.cpp b/src/jit/impl/mlir/ir/dialect.cpp index e877b5a649cdd7fc323ba41e9f3ed91b54b92b3b..c4a82bfd43de5af47f2afc746dace8c7218c5e46 100644 --- a/src/jit/impl/mlir/ir/dialect.cpp +++ b/src/jit/impl/mlir/ir/dialect.cpp @@ -24,7 +24,8 @@ using namespace mgb; using namespace jit; -MgbDialect::MgbDialect(mlir::MLIRContext* ctx) : mlir::Dialect("mgb", ctx) { +MgbDialect::MgbDialect(mlir::MLIRContext* ctx) + : mlir::Dialect("mgb", ctx, mlir::TypeID::get()) { addOperations< #define GET_OP_LIST #include "megbrain/jit/mlir/ir/ops.cpp.inc" diff --git a/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp b/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp index a63a556eabc69fdb14e615e185bbb4b04d74faca..32c7bcff22ace6bdbdeaabc8ac6e2699f6df837d 100644 --- a/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp +++ b/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp @@ -209,6 +209,11 @@ struct ConstantScalarOpLowering class MgbToAffineLoweringPass : public PassWrapper { public: + void getDependentDialects(mlir::DialectRegistry& registry) const override { + registry.insert(); + registry.insert(); + } + void runOnFunction() override final { ConversionTarget target(getContext()); target.addLegalDialect(); diff --git a/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp b/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp index 525f256e95fe6dd9a0d869868abb5204586e7aad..e0b35b663bb2bd122673bc5542e92aef6e7a5388 100644 --- a/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp +++ b/src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp @@ -259,6 +259,11 @@ private: class MgbToGpuLoweringPass : public PassWrapper { public: + void getDependentDialects(mlir::DialectRegistry& registry) const override { + registry.insert(); + registry.insert(); + } + void runOnFunction() override final { auto func_op = getFunction(); Location loc = func_op.getLoc(); diff --git a/src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp b/src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp index 9e17e0c1f05a8a57520b91378368fde5dc949667..edacdbcb1d5200c77a142df2b2f791aa5d65c20e 100644 --- a/src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp +++ b/src/jit/impl/mlir/ir/lower_to_llvm_pass.cpp @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include using namespace mgb; @@ -30,6 +32,12 @@ namespace { class AffineToLLVMLoweringPass : public PassWrapper> { +public: + void getDependentDialects(mlir::DialectRegistry& registry) const override { + registry.insert(); + registry.insert(); + } + void runOnOperation() final { LLVMConversionTarget target(getContext()); target.addLegalOp(); diff --git a/src/jit/impl/mlir/ir/types.h b/src/jit/impl/mlir/ir/types.h index 548b5db42b2e62b278c80d3d634fc9430cb64aa5..ee8a1f6152847a68fb213b1e2d4535f0a920e6af 100644 --- a/src/jit/impl/mlir/ir/types.h +++ b/src/jit/impl/mlir/ir/types.h @@ -21,7 +21,7 @@ namespace jit { inline bool is_elemwise_float(const mlir::Type& dt) { if (auto cast = dt.dyn_cast_or_null()) { - if (cast.getElementType().getKind() == mlir::StandardTypes::F32) { + if (cast.getElementType().isF32()) { return true; } } diff --git a/src/jit/impl/mlir/ir/utils.cpp b/src/jit/impl/mlir/ir/utils.cpp index f3553b982dce13ac24ada88792b0028462bc558a..afdb25e8b9fa3d69c93b71b79ada03fb67805112 100644 --- a/src/jit/impl/mlir/ir/utils.cpp +++ b/src/jit/impl/mlir/ir/utils.cpp @@ -82,13 +82,12 @@ megdnn::DType jit::mlir_type_to_dtype(mlir::Type type) { if (auto cast = type.dyn_cast_or_null()) { element_type = cast.getElementType(); } - switch (element_type.getKind()) { - case mlir::StandardTypes::F32: - return megdnn::dtype::Float32{}; - default: - mgb_throw(InternalError, - "Unsupport mlir type for MemRefType, got: %s\n", - mlir_type_to_string(type).c_str()); + if (element_type.isF32()) { + return megdnn::dtype::Float32{}; + } else { + mgb_throw(InternalError, + "Unsupport mlir type for MemRefType, got: %s\n", + mlir_type_to_string(type).c_str()); } return {}; } diff --git a/src/jit/include/megbrain/jit/mlir/ir/dialect.h b/src/jit/include/megbrain/jit/mlir/ir/dialect.h index f0ee5fe8de9345125d2ddd9e550d7cb0ce18f2c1..d1c93e9b0bec61fbb92e826f6c5478d7beb5f2ed 100644 --- a/src/jit/include/megbrain/jit/mlir/ir/dialect.h +++ b/src/jit/include/megbrain/jit/mlir/ir/dialect.h @@ -34,13 +34,13 @@ public: static llvm::StringRef getDialectNamespace() { return "mgb::jit"; } }; +} // namespace jit +} // namespace mgb + #define GET_OP_CLASSES using namespace mlir; #include "megbrain/jit/mlir/ir/ops.h.inc" -} // namespace jit -} // namespace mgb - #endif // MGB_JIT && MGB_JIT_MLIR // vim: syntax=cpp.doxygen