/**
 * \file src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/common.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h"

#include "../utils.h"

#include <mlir/Dialect/GPU/GPUDialect.h>
#include <mlir/Dialect/SCF/SCF.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>

#include <llvm/ADT/PointerUnion.h>
#include <llvm/ADT/Sequence.h>
#include <llvm/ADT/SetVector.h>
#include <llvm/ADT/Twine.h>
#include <llvm/IR/Type.h>
#include <llvm/Support/Debug.h>

using namespace mgb;
using namespace jit;

namespace {

//! load the element at \p index if \p val is a memref, otherwise forward the
//! scalar value unchanged
mlir::Value get_operand(ConversionPatternRewriter& rewriter,
                        const mlir::Location& loc, const mlir::Value& val,
                        const mlir::Value& index) {
    if (val.getType().isa<mlir::MemRefType>()) {
        return rewriter.create<LoadOp>(loc, val, index);
    } else {
        return val;
    }
}

//! compute the global thread id: blockIdx.x * blockDim.x + threadIdx.x
mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
    auto thread_idx = rewriter.create<gpu::ThreadIdOp>(
            loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
    auto block_idx = rewriter.create<gpu::BlockIdOp>(
            loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
    auto group_size = rewriter.create<gpu::BlockDimOp>(
            loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));

    Value index = rewriter.create<AddIOp>(
            loc, thread_idx,
            rewriter.create<MulIOp>(loc, block_idx, group_size));
    return index;
}

//! lower a mgb binary elemwise op into load + compute inside the gpu.launch
//! body
template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
    BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
            : ConversionPattern(BinaryOp::getOperationName(), 1, ctx),
              m_launch_op{launch_op} {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        typename BinaryOp::Adaptor binary_adaptor(operands);
        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));

        auto index = get_tid(rewriter, loc);

        auto loaded_lhs =
                get_operand(rewriter, loc, binary_adaptor.lhs(), index);
        auto loaded_rhs =
                get_operand(rewriter, loc, binary_adaptor.rhs(), index);

        auto binary_op =
                rewriter.create<LoweredBinaryOp>(loc, loaded_lhs, loaded_rhs);

        rewriter.replaceOp(op, binary_op.getResult());
        return success();
    }

private:
    gpu::LaunchOp* m_launch_op;
};

using AddOpLowering = BinaryOpLowering<jit::AddOp, AddFOp>;

struct ReturnOpLowering : public ConversionPattern {
    ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
            : ConversionPattern(jit::ReturnOp::getOperationName(), 1, ctx),
              m_launch_op{launch_op} {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value>,
            ConversionPatternRewriter& rewriter) const final {
        rewriter.replaceOpWithNewOp<gpu::TerminatorOp>(op);
        auto loc = op->getLoc();

        //! remove the first gpu.terminator
        m_launch_op->body().front().front().erase();

        //! insert `if (tid >= nr_tid) { return; }` at the beginning of the
        //! block
        rewriter.setInsertionPointToStart(&(m_launch_op->body().front()));
        Block* cond_block = rewriter.getInsertionBlock();
        Block::iterator op_position = rewriter.getInsertionPoint();
        Block* remaining_ops_block =
                rewriter.splitBlock(cond_block, op_position);

        rewriter.setInsertionPointToEnd(cond_block);
        auto index = get_tid(rewriter, loc);
        auto comparison = rewriter.create<mlir::CmpIOp>(
                loc, CmpIPredicate::sge, index,
                m_launch_op->getParentOfType<mlir::FuncOp>()
                        .getArguments()
                        .back());

        Block* then_block =
                rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
        rewriter.setInsertionPointToEnd(then_block);
        rewriter.create<gpu::TerminatorOp>(loc);

        rewriter.setInsertionPointToEnd(cond_block);
        rewriter.create<mlir::CondBranchOp>(
                loc, comparison, then_block, ArrayRef<Value>(),
                remaining_ops_block, ArrayRef<Value>());

        rewriter.setInsertionPointToEnd(remaining_ops_block);
        rewriter.create<gpu::TerminatorOp>(loc);

        return success();
    }

private:
    gpu::LaunchOp* m_launch_op;
};

//! lower jit.assign into a store of the computed value into the output memref
struct AssignOpLowering : public ConversionPattern {
    AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
            : ConversionPattern(jit::AssignOp::getOperationName(), 2, ctx),
              m_launch_op{launch_op} {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        AssignOpAdaptor assign_adaptor(operands);
        rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));

        auto index = get_tid(rewriter, loc);

        auto loaded_lhs =
                get_operand(rewriter, loc, assign_adaptor.lhs(), index);
        rewriter.create<StoreOp>(loc, loaded_lhs, assign_adaptor.rhs(), index);

        rewriter.eraseOp(op);
        return success();
    }

private:
    gpu::LaunchOp* m_launch_op;
};

//! wrap the function body in a single gpu.launch region and convert the mgb
//! dialect ops inside it to gpu/std dialect ops
class MgbToGpuLoweringPass
        : public PassWrapper<MgbToGpuLoweringPass, FunctionPass> {
public:
    void runOnFunction() override final {
        auto func_op = getFunction();
        Location loc = func_op.getLoc();
        OpBuilder builder(&func_op.getBody());
        Value constantOne = builder.create<ConstantIndexOp>(loc, 1);
        gpu::LaunchOp launch_op = builder.create<gpu::LaunchOp>(
                loc, constantOne, constantOne, constantOne, constantOne,
                constantOne, constantOne);
        builder.setInsertionPointToEnd(&(launch_op.body().front()));
        builder.create<gpu::TerminatorOp>(loc);

        OwningRewritePatternList patterns;
        ConversionTarget target(getContext());
        target.addLegalDialect<gpu::GPUDialect>();
        target.addLegalDialect<StandardOpsDialect>();
        target.addIllegalDialect<MgbDialect>();

        patterns.insert<AddOpLowering, ReturnOpLowering, AssignOpLowering>(
                &getContext(), &launch_op);

        if (failed(applyPartialConversion(func_op, target, patterns))) {
            signalPassFailure();
        }
    }
};

}  // namespace

std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_gpu_pass() {
    return std::make_unique<MgbToGpuLoweringPass>();
}

#endif  // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen