提交 88e918e2 编写于 作者: M Megvii Engine Team

feat(mgb/jit): add scf.ForOp in MgbToGpuLoweringPass

GitOrigin-RevId: 3cdae27c378f7f76c7dc59ecb80b08d6dd5c35fe
上级 7aa54b0e
......@@ -26,6 +26,7 @@
#include <mlir/Conversion/GPUCommon/GPUCommonPass.h>
#include <mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h>
#include <mlir/Conversion/SCFToStandard/SCFToStandard.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h>
#include <mlir/Dialect/GPU/Passes.h>
#include <mlir/IR/Dialect.h>
......@@ -152,6 +153,7 @@ void add_cuda_lowering_pass(mlir::PassManager& manager,
{
mlir::OpPassManager& opt_pm = manager.nest<mlir::FuncOp>();
opt_pm.addPass(create_lower_to_gpu_pass());
opt_pm.addPass(mlir::createLowerToCFGPass());
opt_pm.addPass(mlir::createCanonicalizerPass());
opt_pm.addPass(mlir::createCSEPass());
opt_pm.addPass(mlir::createLoopFusionPass());
......
......@@ -32,6 +32,14 @@ using namespace mgb;
using namespace jit;
namespace {
int64_t get_grid_size(int64_t nr_elements, int64_t block_size) {
// unroll three times in the kernel
int64_t a = nr_elements / (block_size * 2);
int64_t b = (nr_elements - 1) / (block_size * 3) + 1;
return std::max(a, b);
}
template <int out_dim, typename ctype>
void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
int block_size) {
......@@ -87,9 +95,18 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
const CompNodeEnv& env =
CompNodeEnv::from_comp_node(fusion_opr->comp_node());
int64_t num_block = (nr_elements - 1) / block_size + 1;
int64_t grid_size;
if (nr_elements <= block_size) {
block_size = nr_elements;
grid_size = 1;
} else {
grid_size = get_grid_size(nr_elements, block_size);
}
int64_t nr_threads = grid_size * block_size;
params.push_back(&nr_elements);
MGB_CUDA_CU_CHECK(cuLaunchKernel(func, num_block, 1, 1, block_size, 1, 1, 0,
params.push_back(&nr_threads);
MGB_CUDA_CU_CHECK(cuLaunchKernel(func, grid_size, 1, 1, block_size, 1, 1, 0,
env.cuda_env().stream, params.data(), 0));
}
......
......@@ -21,11 +21,6 @@
#include "megbrain/jit/mlir/ir/passes.h"
#include "megbrain/jit/mlir/ir/utils.h"
#include <llvm/ADT/PointerUnion.h>
#include <llvm/ADT/Sequence.h>
#include <llvm/ADT/SetVector.h>
#include <llvm/ADT/Twine.h>
#include <llvm/IR/Type.h>
#include <mlir/Dialect/GPU/GPUDialect.h>
#include <mlir/Dialect/SCF/SCF.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
......@@ -39,124 +34,98 @@ using namespace jit;
namespace {
mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
auto thread_idx = rewriter.create<gpu::ThreadIdOp>(
loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
auto block_idx = rewriter.create<gpu::BlockIdOp>(
loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
auto group_size = rewriter.create<gpu::BlockDimOp>(
loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
Value index = rewriter.create<AddIOp>(
loc, thread_idx,
rewriter.create<MulIOp>(loc, block_idx, group_size));
return index;
}
using Rewriter = ConversionPatternRewriter;
using Layout = megdnn::TensorLayout;
megdnn::TensorLayout output_layout(gpu::LaunchOp& launch_op) {
auto func_op = launch_op.getParentOfType<mlir::FuncOp>();
mgb_assert(func_op, "Unexpexted launch op.");
for (auto block_iter = func_op.rbegin(); block_iter != func_op.rend();
block_iter++) {
for (auto op_iter = block_iter->rbegin(); op_iter != block_iter->rend();
op_iter++) {
auto op = llvm::dyn_cast_or_null<dialect::AssignOp>(&(*op_iter));
if (op && op.getNumOperands() > 0) {
return mlir_type_to_layout(*(op.operand_type_begin()));
}
}
/* ===================== GpuLoweringHelper ===================== */
struct GpuLoweringHelper {
GpuLoweringHelper(scf::ForOp* for_op, Value index, const Layout& dest)
: m_for_op(for_op), m_index(index), m_dest(dest) {}
void set_insertion_point(OpBuilder& builder) const {
// insert before the last operation (scf.yield) in the loop body
builder.setInsertionPoint(&(m_for_op->getLoopBody().front().back()));
}
mgb_throw(MegBrainError, "Unexpexted launch op.");
}
std::vector<mlir::Value> get_multidim_tid(ConversionPatternRewriter& rewriter,
const Location& loc,
const mlir::Value& val,
const megdnn::TensorLayout& dst) {
Value index = get_tid(rewriter, loc);
auto type = val.getType().dyn_cast_or_null<mlir::MemRefType>();
if (type) {
ValueBuilderHelper helper(rewriter, loc);
std::vector<mlir::Value> idxs;
idxs.resize(dst.ndim);
mlir::Value dim_index = index;
for (int i = dst.ndim - 1; i >= 0; i--) {
auto cur_index = helper.modI(dim_index, helper.const_i32(dst[i]));
idxs[i] = cur_index;
dim_index = helper.divI(dim_index, helper.const_i32(dst[i]));
std::vector<Value> map_indices(OpBuilder& builder, Location loc,
Value value) const {
auto type = value.getType().dyn_cast_or_null<MemRefType>();
if (!type) {
return {m_index};
}
megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
std::vector<Value> indices(m_dest.ndim);
ValueBuilderHelper helper(builder, loc);
// map global index to multi-dimensional indices
Value dim_index = m_index;
for (int i = m_dest.ndim - 1; i >= 0; i--) {
indices[i] = helper.modI(dim_index, helper.const_i32(m_dest[i]));
dim_index = helper.divI(dim_index, helper.const_i32(m_dest[i]));
}
// allow broadcasting
Layout src_layout = mlir_type_to_layout(type);
src_layout.init_contiguous_stride();
for (int i = 0; i < type.getRank(); ++i) {
if (src_layout[i] == 1) {
idxs[i] = helper.const_i32(0);
indices[i] = helper.const_i32(0);
}
}
return idxs;
} else {
return {index};
return indices;
}
}
struct ElemwiseLowering : public ConversionPattern {
ElemwiseLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(dialect::Elemwise::getOperationName(), 1, ctx),
m_launch_op{launch_op} {}
private:
scf::ForOp* m_for_op;
Value m_index;
Layout m_dest;
};
LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc();
/* ===================== conversion patterns ===================== */
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
struct AssignOpLowering : public ConversionPattern, public GpuLoweringHelper {
AssignOpLowering(MLIRContext* ctx, scf::ForOp* for_op, mlir::Value index,
const Layout& dest)
: ConversionPattern(dialect::AssignOp::getOperationName(), 2, ctx),
GpuLoweringHelper(for_op, index, dest) {}
auto dst_layout = output_layout(m_launch_op);
auto inputs = llvm::to_vector<4>(
llvm::map_range(operands, [&](mlir::Value val) {
auto index =
get_multidim_tid(rewriter, loc, val, dst_layout);
return get_operand<LoadOp>(rewriter, loc, val, index);
}));
LogicalResult matchAndRewrite(Operation* op, ArrayRef<Value> operands,
Rewriter& rewriter) const final {
auto loc = op->getLoc();
set_insertion_point(rewriter);
rewriter.replaceOp(op,
lower_elemwise_to_std(op, rewriter, loc, inputs));
auto index = map_indices(rewriter, loc, operands[1]);
auto input = get_operand<LoadOp>(rewriter, loc, operands[0], index);
rewriter.create<StoreOp>(loc, input, operands[1], index);
rewriter.eraseOp(op);
return success();
}
private:
gpu::LaunchOp& m_launch_op;
};
struct TypeCvtLowering : public ConversionPattern {
TypeCvtLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(dialect::TypeCvt::getOperationName(), 1, ctx),
m_launch_op{launch_op} {}
LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc();
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto dst_layout = output_layout(m_launch_op);
auto index = get_multidim_tid(rewriter, loc, operands[0], dst_layout);
auto input = get_operand<LoadOp>(rewriter, loc, operands[0], index);
struct ConstantScalarOpLowering
: public OpRewritePattern<dialect::ConstantScalarOp>,
public GpuLoweringHelper {
ConstantScalarOpLowering(MLIRContext* ctx, scf::ForOp* for_op, Value index,
const Layout& dest)
: OpRewritePattern<dialect::ConstantScalarOp>(ctx),
GpuLoweringHelper(for_op, index, dest) {}
rewriter.replaceOp(op, lower_typecvt_to_std(op, rewriter, loc, input));
LogicalResult matchAndRewrite(dialect::ConstantScalarOp op,
PatternRewriter& rewriter) const final {
set_insertion_point(rewriter);
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op, op.value());
return success();
}
private:
gpu::LaunchOp& m_launch_op;
};
struct DimshuffleLowering : public ConversionPattern {
DimshuffleLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
struct DimshuffleLowering : public ConversionPattern, public GpuLoweringHelper {
DimshuffleLowering(MLIRContext* ctx, scf::ForOp* for_op, Value index,
const Layout& dest)
: ConversionPattern(dialect::Dimshuffle::getOperationName(), 1,
ctx),
m_launch_op{launch_op} {}
GpuLoweringHelper(for_op, index, dest) {}
static std::vector<mlir::Value> get_index_from_pattern(
const std::vector<int32_t>& pattern,
......@@ -172,163 +141,162 @@ struct DimshuffleLowering : public ConversionPattern {
return res;
}
LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
LogicalResult matchAndRewrite(Operation* op, ArrayRef<Value> operands,
Rewriter& rewriter) const final {
auto loc = op->getLoc();
set_insertion_point(rewriter);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto dst_layout = output_layout(m_launch_op);
auto index = get_multidim_tid(rewriter, loc, operands[0], dst_layout);
auto pattern = llvm::dyn_cast<dialect::Dimshuffle>(op).pattern();
auto index = map_indices(rewriter, loc, operands[0]);
auto shuffled_index = get_index_from_pattern(pattern, index);
rewriter.replaceOp(op, get_operand<LoadOp>(rewriter, loc, operands[0],
shuffled_index));
return success();
}
private:
gpu::LaunchOp& m_launch_op;
};
struct ReturnOpLowering : public ConversionPattern {
ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(dialect::ReturnOp::getOperationName(), 1, ctx),
m_launch_op{launch_op} {}
struct ElemwiseLowering : public ConversionPattern, public GpuLoweringHelper {
ElemwiseLowering(MLIRContext* ctx, scf::ForOp* for_op, Value index,
const Layout& dest)
: ConversionPattern(dialect::Elemwise::getOperationName(), 1, ctx),
GpuLoweringHelper(for_op, index, dest) {}
LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value>,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op);
LogicalResult matchAndRewrite(Operation* op, ArrayRef<Value> operands,
Rewriter& rewriter) const final {
auto loc = op->getLoc();
set_insertion_point(rewriter);
//! remove the first gpu.terminator
m_launch_op.body().front().front().erase();
//! if (tid >= nr_tid) {return;} in the begin of the block
rewriter.setInsertionPointToStart(&(m_launch_op.body().front()));
Block* cond_block = rewriter.getInsertionBlock();
Block::iterator op_position = rewriter.getInsertionPoint();
Block* remaining_ops_block =
rewriter.splitBlock(cond_block, op_position);
rewriter.setInsertionPointToEnd(cond_block);
auto index = get_tid(rewriter, loc);
auto comparison = rewriter.create<mlir::CmpIOp>(
loc, CmpIPredicate::sge, index,
m_launch_op.getParentOfType<mlir::FuncOp>()
.getArguments()
.back());
Block* then_block =
rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
rewriter.setInsertionPointToEnd(then_block);
rewriter.create<gpu::TerminatorOp>(loc);
rewriter.setInsertionPointToEnd(cond_block);
rewriter.create<mlir::CondBranchOp>(
loc, comparison, then_block, ArrayRef<Value>(),
remaining_ops_block, ArrayRef<Value>());
rewriter.setInsertionPointToEnd(remaining_ops_block);
rewriter.create<gpu::TerminatorOp>(loc);
// currently Elemwise handles at most three operands
auto inputs = llvm::to_vector<4>(
llvm::map_range(operands, [&](mlir::Value val) {
auto index = map_indices(rewriter, loc, val);
return get_operand<LoadOp>(rewriter, loc, val, index);
}));
rewriter.replaceOp(op,
lower_elemwise_to_std(op, rewriter, loc, inputs));
return success();
}
private:
gpu::LaunchOp& m_launch_op;
};
struct ConstantScalarOpLowering
: public OpRewritePattern<dialect::ConstantScalarOp> {
ConstantScalarOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: OpRewritePattern<dialect::ConstantScalarOp>(ctx),
m_launch_op{launch_op} {}
LogicalResult matchAndRewrite(dialect::ConstantScalarOp op,
PatternRewriter& rewriter) const final {
dialect::ConstantScalarOpAdaptor constant_scalar_adaptor(op);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
struct ReturnOpLowering : public ConversionPattern {
ReturnOpLowering(MLIRContext* ctx, scf::ForOp*, Value, const Layout&)
: ConversionPattern(dialect::ReturnOp::getOperationName(), 1, ctx) {
}
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
op, constant_scalar_adaptor.value());
LogicalResult matchAndRewrite(Operation* op, ArrayRef<Value>,
Rewriter& rewriter) const final {
rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op);
return success();
}
private:
gpu::LaunchOp& m_launch_op;
};
struct AssignOpLowering : public ConversionPattern {
AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(dialect::AssignOp::getOperationName(), 2, ctx),
m_launch_op{launch_op} {}
struct TypeCvtLowering : public ConversionPattern, public GpuLoweringHelper {
TypeCvtLowering(MLIRContext* ctx, scf::ForOp* for_op, Value index,
const Layout& dest)
: ConversionPattern(dialect::TypeCvt::getOperationName(), 1, ctx),
GpuLoweringHelper(for_op, index, dest) {}
LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
LogicalResult matchAndRewrite(Operation* op, ArrayRef<Value> operands,
Rewriter& rewriter) const final {
auto loc = op->getLoc();
set_insertion_point(rewriter);
dialect::AssignOpAdaptor assign_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto dst_layout = output_layout(m_launch_op);
auto index = get_multidim_tid(rewriter, loc, assign_adaptor.rhs(),
dst_layout);
auto loaded_lhs =
get_operand<LoadOp>(rewriter, loc, assign_adaptor.lhs(), index);
rewriter.create<StoreOp>(loc, loaded_lhs, assign_adaptor.rhs(), index);
auto index = map_indices(rewriter, loc, operands[0]);
auto input = get_operand<LoadOp>(rewriter, loc, operands[0], index);
rewriter.eraseOp(op);
rewriter.replaceOp(op, lower_typecvt_to_std(op, rewriter, loc, input));
return success();
}
private:
gpu::LaunchOp& m_launch_op;
};
/* ===================== MgbToGpuLoweringPass ===================== */
class MgbToGpuLoweringPass
: public PassWrapper<MgbToGpuLoweringPass, FunctionPass> {
public:
void getDependentDialects(mlir::DialectRegistry& registry) const override {
registry.insert<mlir::gpu::GPUDialect>();
registry.insert<mlir::StandardOpsDialect>();
}
void getDependentDialects(DialectRegistry& registry) const override;
void runOnFunction() final;
void runOnFunction() override final {
auto func_op = getFunction();
Location loc = func_op.getLoc();
OpBuilder builder(&func_op.getBody());
Value constantOne = builder.create<ConstantIndexOp>(loc, 1);
gpu::LaunchOp launch_op = builder.create<gpu::LaunchOp>(
loc, constantOne, constantOne, constantOne, constantOne,
constantOne, constantOne);
builder.setInsertionPointToEnd(&(launch_op.body().front()));
builder.create<gpu::TerminatorOp>(loc);
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<gpu::GPUDialect>();
target.addIllegalDialect<MgbDialect>();
patterns.insert<ElemwiseLowering, TypeCvtLowering, DimshuffleLowering,
ReturnOpLowering, ConstantScalarOpLowering,
AssignOpLowering>(&getContext(), launch_op);
if (failed(applyPartialConversion(func_op, target,
std::move(patterns)))) {
signalPassFailure();
}
}
private:
Value get_idx(OpBuilder& builder, Location loc);
Layout get_dest_layout(FuncOp func_op);
};
void MgbToGpuLoweringPass::getDependentDialects(
DialectRegistry& registry) const {
registry.insert<gpu::GPUDialect, scf::SCFDialect, StandardOpsDialect>();
}
void MgbToGpuLoweringPass::runOnFunction() {
FuncOp func_op = getFunction();
Location loc = func_op.getLoc();
OpBuilder builder(func_op.getBody());
// create gpu::LaunchOp
Value one = builder.create<ConstantIndexOp>(loc, 1);
gpu::LaunchOp launch_op =
builder.create<gpu::LaunchOp>(loc, one, one, one, one, one, one);
builder.setInsertionPointToEnd(&(launch_op.body().front()));
// create scf::ForOp
auto it = func_op.getArguments().end();
Value nr_threads = *(--it);
Value nr_elements = *(--it);
Value idx = get_idx(builder, loc);
auto for_op = builder.create<scf::ForOp>(loc, idx, nr_elements, nr_threads);
builder.create<gpu::TerminatorOp>(loc);
Layout dest = get_dest_layout(func_op);
Value for_idx = for_op.getLoopBody().getArgument(0);
OwningRewritePatternList patterns;
patterns.insert<AssignOpLowering, ConstantScalarOpLowering,
DimshuffleLowering, ElemwiseLowering, ReturnOpLowering,
TypeCvtLowering>(&getContext(), &for_op, for_idx, dest);
ConversionTarget target(getContext());
target.addLegalDialect<gpu::GPUDialect, scf::SCFDialect,
StandardOpsDialect>();
target.addIllegalDialect<MgbDialect>();
if (failed(applyPartialConversion(func_op, target, std::move(patterns)))) {
signalPassFailure();
}
}
//! block_dim * block_idx + thread_idx
Value MgbToGpuLoweringPass::get_idx(OpBuilder& builder, Location loc) {
IndexType idx_type = builder.getIndexType();
StringAttr x = builder.getStringAttr("x");
Value block_dim = builder.create<gpu::BlockDimOp>(loc, idx_type, x);
Value block_idx = builder.create<gpu::BlockIdOp>(loc, idx_type, x);
Value thread_idx = builder.create<gpu::ThreadIdOp>(loc, idx_type, x);
Value prod = builder.create<MulIOp>(loc, block_dim, block_idx);
return builder.create<AddIOp>(loc, prod, thread_idx);
}
//! traverse the body of func_op and get dest_layout from AssignOp
Layout MgbToGpuLoweringPass::get_dest_layout(FuncOp func_op) {
Layout dest_layout;
bool found = false;
func_op.walk([&](dialect::AssignOp assign_op) {
dest_layout = mlir_type_to_layout(assign_op.lhs().getType());
found = true;
return WalkResult::interrupt();
});
mgb_assert(found, "AssignOp not found in the body of FuncOp");
return dest_layout;
}
} // namespace
/* ===================== create_lower_to_gpu_pass ===================== */
std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_gpu_pass() {
return std::make_unique<MgbToGpuLoweringPass>();
}
......
......@@ -80,7 +80,9 @@ private:
for (auto&& arg : args.outputs) {
func_args.push_back(get_type(arg.from->layout()));
}
//! the last arg is nr_elements
//! nr_elements
func_args.push_back(m_builder.getIndexType());
//! nr_threads
func_args.push_back(m_builder.getIndexType());
auto func_type = m_builder.getFunctionType(func_args, llvm::None);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册