From aa7f28b8359e0f568c01fd3143afa1ba03a6160a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 27 Aug 2020 00:31:17 +0800 Subject: [PATCH] fix(mgb/jit): fix gpu kernel args order GitOrigin-RevId: a7f0e56747d43812b71cc8a859b52465e3832c61 --- src/jit/impl/mlir/compiler.cpp | 2 +- src/jit/impl/mlir/executable_cuda.cpp | 2 +- .../ir/create_gpu_kernel_outlining_pass.cpp | 339 ++++++++++++++++++ src/jit/impl/mlir/ir/lower_to_affine_pass.cpp | 2 +- src/jit/impl/nvrtc/compiler_cuda.cpp | 6 - src/jit/impl/nvrtc/compiler_cuda.h | 15 - src/jit/include/megbrain/jit/mlir/ir/passes.h | 8 + 7 files changed, 350 insertions(+), 24 deletions(-) create mode 100644 src/jit/impl/mlir/ir/create_gpu_kernel_outlining_pass.cpp diff --git a/src/jit/impl/mlir/compiler.cpp b/src/jit/impl/mlir/compiler.cpp index 0ebb2fba7..d22d46baf 100644 --- a/src/jit/impl/mlir/compiler.cpp +++ b/src/jit/impl/mlir/compiler.cpp @@ -94,7 +94,7 @@ void add_cuda_lowering_pass(mlir::PassManager& manager, CompNode cn) { opt_pm.addPass(mlir::createLoopFusionPass()); opt_pm.addPass(mlir::createMemRefDataFlowOptPass()); } - manager.addPass(mlir::createGpuKernelOutliningPass()); + manager.addPass(create_gpu_kernel_outlining_pass()); { auto& kernel_pm = manager.nest(); kernel_pm.addPass(mlir::createLowerGpuOpsToNVVMOpsPass()); diff --git a/src/jit/impl/mlir/executable_cuda.cpp b/src/jit/impl/mlir/executable_cuda.cpp index 2921631ec..6448ca6b7 100644 --- a/src/jit/impl/mlir/executable_cuda.cpp +++ b/src/jit/impl/mlir/executable_cuda.cpp @@ -79,7 +79,7 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func, CompNodeEnv::from_comp_node(fusion_opr->comp_node()); int64_t num_block = (nr_elements - 1) / block_size + 1; - params.insert(params.begin(), &nr_elements); + params.push_back(&nr_elements); MGB_CUDA_CU_CHECK(cuLaunchKernel(func, num_block, 1, 1, block_size, 1, 1, 0, env.cuda_env().stream, params.data(), 0)); } diff --git a/src/jit/impl/mlir/ir/create_gpu_kernel_outlining_pass.cpp b/src/jit/impl/mlir/ir/create_gpu_kernel_outlining_pass.cpp new file mode 100644 index 000000000..ec3666d88 --- /dev/null +++ b/src/jit/impl/mlir/ir/create_gpu_kernel_outlining_pass.cpp @@ -0,0 +1,339 @@ +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the GPU dialect kernel outlining pass. +// +//===----------------------------------------------------------------------===// +/** + * \file src/jit/impl/mlir/ir/create_gpu_kernel_outlining_pass.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + * + * This file has been modified by Megvii ("Megvii Modifications"). + * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights + * reserved. + * + */ + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "megbrain/jit/mlir/ir/passes.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" + +using namespace mlir; + +template +static void createForAllDimensions(OpBuilder& builder, Location loc, + SmallVectorImpl& values) { + for (StringRef dim : {"x", "y", "z"}) { + Value v = builder.create(loc, builder.getIndexType(), + builder.getStringAttr(dim)); + values.push_back(v); + } +} + +// Add operations generating block/thread ids and grid/block dimensions at the +// beginning of the `launchFuncOpBody` region. Add mapping from argument in +// entry block of `launchOpBody`, to the corresponding result value of the added +// operations. +static void injectGpuIndexOperations(Location loc, Region& launchFuncOpBody, + Region& launchOpBody, + BlockAndValueMapping& map) { + OpBuilder builder(loc->getContext()); + Block& firstBlock = launchOpBody.front(); + builder.setInsertionPointToStart(&launchFuncOpBody.front()); + SmallVector indexOps; + createForAllDimensions(builder, loc, indexOps); + createForAllDimensions(builder, loc, indexOps); + createForAllDimensions(builder, loc, indexOps); + createForAllDimensions(builder, loc, indexOps); + // Replace the leading 12 function args with the respective thread/block + // index operations. Iterate backwards since args are erased and indices + // change. + for (auto indexOp : enumerate(indexOps)) + map.map(firstBlock.getArgument(indexOp.index()), indexOp.value()); +} + +static bool isSinkingBeneficiary(Operation* op) { + return isa(op); +} + +LogicalResult mlir::sinkOperationsIntoLaunchOp(gpu::LaunchOp launchOp) { + Region& launchOpBody = launchOp.body(); + + // Identify uses from values defined outside of the scope of the launch + // operation. + llvm::SetVector sinkCandidates; + getUsedValuesDefinedAbove(launchOpBody, sinkCandidates); + + llvm::SetVector sunkValues; + llvm::SetVector sunkOperations; + for (Value operand : sinkCandidates) { + Operation* operandOp = operand.getDefiningOp(); + if (!operandOp || !isSinkingBeneficiary(operandOp)) + continue; + // Only sink operations that do not create new sinkCandidates. + if (!llvm::all_of(operandOp->getOperands(), + [&sinkCandidates](Value value) { + return sinkCandidates.count(value); + })) + continue; + sunkValues.insert(operand); + sunkOperations.insert(operandOp); + } + + // Insert operations so that the defs get cloned before uses. + BlockAndValueMapping map; + OpBuilder builder(launchOpBody); + DenseSet processed; + SmallVector clonedOps; + while (processed.size() != sunkOperations.size()) { + auto startSize = processed.size(); + for (Operation* sunkOperation : sunkOperations) { + if (processed.count(sunkOperation)) + continue; + + // Operation cant be cloned yet if any of its operands is also being + // sunk, but isnt cloned yet. + if (llvm::any_of(sunkOperation->getOperands(), [&sunkValues, + &map](Value value) { + return sunkValues.count(value) && !map.lookupOrNull(value); + })) + continue; + + Operation* clonedOp = builder.clone(*sunkOperation, map); + // Only replace uses within the launch op. + for (auto result : llvm::enumerate(sunkOperation->getResults())) { + auto replacement = clonedOp->getResult(result.index()); + for (auto& use : + llvm::make_early_inc_range(result.value().getUses())) + if (use.getOwner()->getParentOfType() == + launchOp) + use.set(replacement); + } + processed.insert(sunkOperation); + } + if (startSize == processed.size()) + return launchOp.emitError( + "found illegal cyclic dependency between operations while " + "sinking"); + } + return success(); +} + +// Outline the `gpu.launch` operation body into a kernel function. Replace +// `gpu.terminator` operations by `gpu.return` in the generated function. +static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp, + StringRef kernelFnName, + SmallVector& operands) { + Location loc = launchOp.getLoc(); + // Create a builder with no insertion point, insertion will happen + // separately due to symbol table manipulation. + OpBuilder builder(launchOp.getContext()); + Region& launchOpBody = launchOp.body(); + + llvm::SetVector operandsSet; + // Identify uses from values defined outside of the scope of the launch + // operation. + getUsedValuesDefinedAbove(launchOpBody, operandsSet); + + // reorder the operands which match the input order + llvm::SetVector insertedOperands; + for (auto& item : launchOp.getParentOfType().getArguments()) { + if (operandsSet.contains(item)) { + operands.push_back(item); + insertedOperands.insert(item); + } + } + for (Value operand : operandsSet) { + if (!insertedOperands.contains(operand)) { + operands.push_back(operand); + } + } + + // Create the gpu.func operation. + SmallVector kernelOperandTypes; + kernelOperandTypes.reserve(operands.size()); + for (Value operand : operands) { + kernelOperandTypes.push_back(operand.getType()); + } + FunctionType type = + FunctionType::get(kernelOperandTypes, {}, launchOp.getContext()); + auto outlinedFunc = builder.create(loc, kernelFnName, type); + outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(), + builder.getUnitAttr()); + BlockAndValueMapping map; + + // Map the arguments corresponding to the launch parameters like blockIdx, + // threadIdx, etc. + Region& outlinedFuncBody = outlinedFunc.body(); + injectGpuIndexOperations(loc, outlinedFuncBody, launchOpBody, map); + + // Map arguments from gpu.launch region to the arguments of the gpu.func + // operation. + Block& entryBlock = outlinedFuncBody.front(); + for (auto operand : enumerate(operands)) + map.map(operand.value(), entryBlock.getArgument(operand.index())); + + // Clone the region of the gpu.launch operation into the gpu.func operation. + // TODO: If cloneInto can be modified such that if a mapping for + // a block exists, that block will be used to clone operations into (at the + // end of the block), instead of creating a new block, this would be much + // cleaner. + launchOpBody.cloneInto(&outlinedFuncBody, map); + + // Branch from entry of the gpu.func operation to the block that is cloned + // from the entry block of the gpu.launch operation. + Block& launchOpEntry = launchOpBody.front(); + Block* clonedLaunchOpEntry = map.lookup(&launchOpEntry); + builder.setInsertionPointToEnd(&entryBlock); + builder.create(loc, clonedLaunchOpEntry); + + outlinedFunc.walk([](gpu::TerminatorOp op) { + OpBuilder replacer(op); + replacer.create(op.getLoc()); + op.erase(); + }); + return outlinedFunc; +} + +// Replace `gpu.launch` operations with an `gpu.launch_func` operation launching +// `kernelFunc`. The kernel func contains the body of the `gpu.launch` with +// constant region arguments inlined. +static void convertToLaunchFuncOp(gpu::LaunchOp launchOp, + gpu::GPUFuncOp kernelFunc, + ValueRange operands) { + OpBuilder builder(launchOp); + builder.create( + launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(), + launchOp.getBlockSizeOperandValues(), operands); + launchOp.erase(); +} + +namespace { +/// Pass that moves the kernel of each LaunchOp into its separate nested module. +/// +/// This pass moves the kernel code of each LaunchOp into a function created +/// inside a nested module. It also creates an external function of the same +/// name in the parent module. +/// +/// The gpu.modules are intended to be compiled to a cubin blob independently in +/// a separate pass. The external functions can then be annotated with the +/// symbol of the cubin accessor function. +class GpuKernelOutliningPass + : public PassWrapper> { +public: + void runOnOperation() override { + SymbolTable symbolTable(getOperation()); + bool modified = false; + for (auto func : getOperation().getOps()) { + // Insert just after the function. + Block::iterator insertPt(func.getOperation()->getNextNode()); + auto funcWalkResult = func.walk([&](gpu::LaunchOp op) { + SmallVector operands; + std::string kernelFnName = + Twine(op.getParentOfType().getName(), "_kernel") + .str(); + + // Pull in instructions that can be sunk + if (failed(sinkOperationsIntoLaunchOp(op))) + return WalkResult::interrupt(); + gpu::GPUFuncOp outlinedFunc = + outlineKernelFuncImpl(op, kernelFnName, operands); + + // Create nested module and insert outlinedFunc. The module will + // originally get the same name as the function, but may be + // renamed on insertion into the parent module. + auto kernelModule = + createKernelModule(outlinedFunc, symbolTable); + symbolTable.insert(kernelModule, insertPt); + + // Potentially changes signature, pulling in constants. + convertToLaunchFuncOp(op, outlinedFunc, operands); + modified = true; + return WalkResult::advance(); + }); + if (funcWalkResult.wasInterrupted()) + return signalPassFailure(); + } + + // If any new module was inserted in this module, annotate this module + // as a container module. + if (modified) + getOperation().setAttr( + gpu::GPUDialect::getContainerModuleAttrName(), + UnitAttr::get(&getContext())); + } + +private: + // Returns a gpu.module containing kernelFunc and all callees (recursive). + gpu::GPUModuleOp createKernelModule(gpu::GPUFuncOp kernelFunc, + const SymbolTable& parentSymbolTable) { + // TODO: This code cannot use an OpBuilder because it must be inserted + // into a SymbolTable by the caller. SymbolTable needs to be refactored + // to prevent manual building of Ops with symbols in code using + // SymbolTables and then this needs to use the OpBuilder. + auto context = getOperation().getContext(); + OpBuilder builder(context); + OperationState state(kernelFunc.getLoc(), + gpu::GPUModuleOp::getOperationName()); + gpu::GPUModuleOp::build(builder, state, kernelFunc.getName()); + auto kernelModule = cast(Operation::create(state)); + SymbolTable symbolTable(kernelModule); + symbolTable.insert(kernelFunc); + + SmallVector symbolDefWorklist = {kernelFunc}; + while (!symbolDefWorklist.empty()) { + if (Optional symbolUses = + SymbolTable::getSymbolUses( + symbolDefWorklist.pop_back_val())) { + for (SymbolTable::SymbolUse symbolUse : *symbolUses) { + StringRef symbolName = symbolUse.getSymbolRef() + .cast() + .getValue(); + if (symbolTable.lookup(symbolName)) + continue; + + Operation* symbolDefClone = + parentSymbolTable.lookup(symbolName)->clone(); + symbolDefWorklist.push_back(symbolDefClone); + symbolTable.insert(symbolDefClone); + } + } + } + + return kernelModule; + } +}; + +} // namespace + +std::unique_ptr mgb::jit::create_gpu_kernel_outlining_pass() { + return std::make_unique(); +} + +#endif // MGB_JIT && MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp b/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp index 684a03f20..201516407 100644 --- a/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp +++ b/src/jit/impl/mlir/ir/lower_to_affine_pass.cpp @@ -23,6 +23,7 @@ #include #include #include +#include "mlir/IR/StandardTypes.h" #include @@ -155,7 +156,6 @@ struct TernaryOpLowering : public ConversionPattern { MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) #undef cb - struct AssignOpLowering : public ConversionPattern { AssignOpLowering(MLIRContext* ctx) : ConversionPattern(jit::AssignOp::getOperationName(), 1, ctx) {} diff --git a/src/jit/impl/nvrtc/compiler_cuda.cpp b/src/jit/impl/nvrtc/compiler_cuda.cpp index a891cc184..a3f71cd78 100644 --- a/src/jit/impl/nvrtc/compiler_cuda.cpp +++ b/src/jit/impl/nvrtc/compiler_cuda.cpp @@ -189,12 +189,6 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func, env.cuda_env().stream, exec_args, 0)); } } // namespace -void mgb::jit::_on_cuda_cu_error(const char* expr, CUresult cu_res, - const char* msg, const char* file, - const char* func, int line) { - mgb_throw(CudaError, "cuda error %d: %s (%s at %s:%s:%d)", int(cu_res), msg, - expr, file, func, line); -} void mgb::jit::_on_nvrtc_error(const char* expr, nvrtcResult nvrtc_res, const char* file, const char* func, int line) { diff --git a/src/jit/impl/nvrtc/compiler_cuda.h b/src/jit/impl/nvrtc/compiler_cuda.h index 97c290dae..c90ed856f 100644 --- a/src/jit/impl/nvrtc/compiler_cuda.h +++ b/src/jit/impl/nvrtc/compiler_cuda.h @@ -19,17 +19,6 @@ #include #include "megbrain/jit/compiler.h" -#define MGB_CUDA_CU_CHECK(expr) \ - do { \ - CUresult __cuda_result = (expr); \ - if (!mgb_likely(__cuda_result == CUDA_SUCCESS)) { \ - const char* __msg; \ - cuGetErrorName(__cuda_result, &__msg); \ - ::mgb::jit::_on_cuda_cu_error(#expr, __cuda_result, __msg, \ - __FILE__, __func__, __LINE__); \ - } \ - } while (0) - #define MGB_NVRTC_CHECK(expr) \ do { \ nvrtcResult __nvrtc_result = (expr); \ @@ -42,10 +31,6 @@ namespace mgb { namespace jit { -[[noreturn]] void _on_cuda_cu_error(const char* expr, CUresult cu_res, - const char* msg, const char* file, - const char* func, int line); - [[noreturn]] void _on_nvrtc_error(const char* expr, nvrtcResult nvrtc_res, const char* file, const char* func, int line); diff --git a/src/jit/include/megbrain/jit/mlir/ir/passes.h b/src/jit/include/megbrain/jit/mlir/ir/passes.h index 1cb2d3410..630554bda 100644 --- a/src/jit/include/megbrain/jit/mlir/ir/passes.h +++ b/src/jit/include/megbrain/jit/mlir/ir/passes.h @@ -32,6 +32,14 @@ std::unique_ptr create_lower_to_llvm_pass(); std::unique_ptr create_lower_to_gpu_pass(); +/** + * \brief Outline gpu.launch bodies to kernel functions + * + * \warning Modified from lib/Dialect/GPU/Transforms/KernelOutlining.cpp, it + * will reorder gpu function args with the args of the emit c interface. + */ +std::unique_ptr create_gpu_kernel_outlining_pass(); + } // namespace jit } // namespace mgb -- GitLab