/**
 * \file src/jit/impl/mlir/ir/lower_to_affine_pass.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "./common.h"
#include "./each_mode.h"

#include "megbrain/common.h"
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/passes.h"
#include "megbrain/jit/mlir/ir/utils.h"

#include <llvm/ADT/Sequence.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>

#include <memory>
#include <utility>

using namespace mgb;
using namespace jit;

namespace {

/// Callback computing the value stored at one element of a lowered op's
/// result. It receives the builder positioned inside the innermost loop, the
/// operands of the op being lowered and the loop induction variables, and
/// returns the value to store at those indices.
using LoopIterationFn = function_ref<Value(
        OpBuilder& builder, ValueRange memref_operands, ValueRange loop_ivs)>;

/// Builds an affine loop nest with one loop per dimension of `memref_type`,
/// each running from 0 to that dimension's size with step 1, and invokes
/// `body_builder` in the innermost loop with the induction variables.
///
/// The upper bounds come from MemRefType::getShape(), so this presumably
/// requires a statically shaped memref — TODO confirm for dynamic shapes.
void build_elementwise_loop_nest(
        OpBuilder& builder, Location loc, MemRefType memref_type,
        function_ref<void(OpBuilder&, Location, ValueRange)> body_builder) {
    llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
    llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
    buildAffineLoopNest(builder, loc, lower_bounds, memref_type.getShape(),
                        steps, body_builder);
}

/// Lowers `op`, whose first result type is a MemRefType, into an affine loop
/// nest. A result buffer is obtained from jit::insert_alloc_and_dealloc; every
/// element of it is filled with the value returned by `process_iteration`,
/// and `op` is then replaced by that buffer.
void lower_op_to_loops(Operation* op, ValueRange operands,
                       PatternRewriter& rewriter,
                       LoopIterationFn process_iteration) {
    auto memref_type = (*op->result_type_begin()).cast<MemRefType>();
    auto loc = op->getLoc();

    auto alloc = jit::insert_alloc_and_dealloc(memref_type, loc, rewriter);

    build_elementwise_loop_nest(
            rewriter, loc, memref_type,
            [&](OpBuilder& nested_builder, Location nested_loc,
                ValueRange ivs) {
                Value value_to_store =
                        process_iteration(nested_builder, operands, ivs);
                nested_builder.create<AffineStoreOp>(nested_loc, value_to_store,
                                                     alloc, ivs);
            });

    // Replace this operation with the generated alloc.
    rewriter.replaceOp(op, alloc);
}

/// Lowers mgb::dialect::Elemwise to an affine loop nest over its result.
/// Every operand is read through get_affine_load_op using the result's
/// layout (with contiguous strides), and the per-element computation is
/// emitted by lower_elemwise_to_std.
struct ElemwiseLowering : public ConversionPattern {
    explicit ElemwiseLowering(MLIRContext* ctx)
            : ConversionPattern(mgb::dialect::Elemwise::getOperationName(), 1,
                                ctx) {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        auto dst_memref_type = (*op->result_type_begin()).cast<MemRefType>();
        megdnn::TensorLayout dst_layout = mlir_type_to_layout(dst_memref_type);
        // NOTE(review): presumably the contiguous destination layout is what
        // lets get_affine_load_op map loop indices onto (possibly broadcast)
        // inputs — confirm against get_affine_load_op.
        dst_layout.init_contiguous_stride();
        lower_op_to_loops(
                op, operands, rewriter,
                [dst_layout, loc, op](OpBuilder& builder,
                                      ValueRange memref_operands,
                                      ValueRange loop_ivs) {
                    auto inputs = llvm::to_vector<4>(llvm::map_range(
                            memref_operands, [&](mlir::Value val) {
                                return get_affine_load_op(builder, loc, val,
                                                          loop_ivs, dst_layout);
                            }));
                    return lower_elemwise_to_std(op, builder, loc, inputs);
                });
        return success();
    }
};

/// Lowers mgb::dialect::TypeCvt to an affine loop nest over its result: the
/// first operand is read with get_operand<AffineLoadOp> and converted by
/// lower_typecvt_to_std.
struct TypeCvtLowering : public ConversionPattern {
    explicit TypeCvtLowering(MLIRContext* ctx)
            : ConversionPattern(mgb::dialect::TypeCvt::getOperationName(), 1,
                                ctx) {}

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        lower_op_to_loops(
                op, operands, rewriter,
                [loc, op](OpBuilder& builder, ValueRange memref_operands,
                          ValueRange loop_ivs) {
                    mlir::Value input = get_operand<AffineLoadOp>(
                            builder, loc, memref_operands[0], loop_ivs);
                    return lower_typecvt_to_std(op, builder, loc, input);
                });
        return success();
    }
};

/// Lowers dialect::AssignOp to an element-wise copy: the loop nest spans the
/// type of operands[0], and at each index the element of assign_adaptor.lhs()
/// is loaded and stored into assign_adaptor.rhs(). The op is then erased.
/// NOTE(review): assumes rhs is at least as large as lhs — TODO confirm.
struct AssignOpLowering : public ConversionPattern {
    explicit AssignOpLowering(MLIRContext* ctx)
            : ConversionPattern(dialect::AssignOp::getOperationName(), 1, ctx) {
    }

    LogicalResult matchAndRewrite(
            Operation* op, ArrayRef<Value> operands,
            ConversionPatternRewriter& rewriter) const final {
        auto loc = op->getLoc();
        auto memref_type = operands[0].getType().cast<MemRefType>();
        dialect::AssignOpAdaptor assign_adaptor(operands);

        build_elementwise_loop_nest(
                rewriter, loc, memref_type,
                [&](OpBuilder& nested_builder, Location nested_loc,
                    ValueRange ivs) {
                    auto loaded_lhs = nested_builder.create<AffineLoadOp>(
                            nested_loc, assign_adaptor.lhs(), ivs);
                    nested_builder.create<AffineStoreOp>(
                            nested_loc, loaded_lhs, assign_adaptor.rhs(), ivs);
                });

        rewriter.eraseOp(op);
        return success();
    }
};

/// Lowers "mgb.return" directly to "std.return".
struct ReturnOpLowering : public OpRewritePattern<dialect::ReturnOp> {
    using OpRewritePattern<dialect::ReturnOp>::OpRewritePattern;

    LogicalResult matchAndRewrite(dialect::ReturnOp op,
                                  PatternRewriter& rewriter) const final {
        rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op);
        return success();
    }
};

/// Lowers dialect::ConstantScalarOp to a standard-dialect constant built from
/// the op's `value` attribute.
struct ConstantScalarOpLowering
        : public OpRewritePattern<dialect::ConstantScalarOp> {
    using OpRewritePattern<dialect::ConstantScalarOp>::OpRewritePattern;

    LogicalResult matchAndRewrite(dialect::ConstantScalarOp op,
                                  PatternRewriter& rewriter) const final {
        dialect::ConstantScalarOpAdaptor constant_scalar_adaptor(op);
        rewriter.replaceOpWithNewOp<mlir::ConstantOp>(
                op, constant_scalar_adaptor.value());
        return success();
    }
};

/// Function pass converting MGB dialect ops to the Affine and Standard
/// dialects with the patterns defined above; the pass signals failure if the
/// partial conversion fails.
class MgbToAffineLoweringPass
        : public PassWrapper<MgbToAffineLoweringPass, FunctionPass> {
public:
    void getDependentDialects(mlir::DialectRegistry& registry) const override {
        registry.insert<mlir::AffineDialect>();
        registry.insert<mlir::StandardOpsDialect>();
    }

    void runOnFunction() final {
        ConversionTarget target(getContext());
        target.addLegalDialect<AffineDialect, StandardOpsDialect>();
        target.addIllegalDialect<MgbDialect>();

        OwningRewritePatternList patterns;
        patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
                        AssignOpLowering, ConstantScalarOpLowering>(
                &getContext());

        // Partial conversion leaves ops that are not explicitly illegal in
        // place, but fails if any illegal op cannot be converted.
        if (failed(applyPartialConversion(getFunction(), target,
                                          std::move(patterns)))) {
            signalPassFailure();
        }
    }
};

}  // namespace

/// Creates a new instance of the MGB-to-Affine lowering pass.
std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_affine_pass() {
    return std::make_unique<MgbToAffineLoweringPass>();
}

namespace mgb {
namespace jit {
/// Registers the lowering pass with MLIR's pass registry under the argument
/// "mgb-convert-to-affine".
void register_test_mgb_to_affine_lowering_pass() {
    PassRegistration<MgbToAffineLoweringPass>(
            "mgb-convert-to-affine",
            "Perform conversion from MGB Dialect to Affine Dialect ",
            [] { return std::make_unique<MgbToAffineLoweringPass>(); });
}
}  // namespace jit
}  // namespace mgb

#endif  // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen