Commit 5d0f8da4 authored by Megvii Engine Team

feat(mgb/jit): add Dimshuffle and lowering passes in jit mlir backend

GitOrigin-RevId: ce6f4ea42a876fafbb7ca67f30d5c0fa96d28096
Parent 0007b9e0
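Note for reviewers: Dimshuffle generalizes transpose. Entry i of the pattern names the input axis that feeds output axis i, and a -1 entry inserts a broadcast (size-1) axis. Below is a plain-C++ illustration of the element mapping this commit implements (our sketch, not MegEngine code):

```cpp
// Illustration only (plain C++, not MegEngine code): pattern {2, 0, 1} on a
// (3, 4, 5) tensor gives a (5, 3, 4) tensor with out[a][b][c] == in[b][c][a];
// a -1 pattern entry would insert a broadcast (size-1) axis instead.
#include <cassert>
#include <vector>

int main() {
    std::vector<int> in(3 * 4 * 5);
    for (size_t i = 0; i < in.size(); ++i) in[i] = static_cast<int>(i);
    std::vector<int> out(5 * 3 * 4);
    for (int a = 0; a < 5; ++a)          // out axis 0 <- in axis 2
        for (int b = 0; b < 3; ++b)      // out axis 1 <- in axis 0
            for (int c = 0; c < 4; ++c)  // out axis 2 <- in axis 1
                out[(a * 3 + b) * 4 + c] = in[(b * 4 + c) * 5 + a];
    assert(out[(2 * 3 + 1) * 4 + 3] == 37);  // out[2][1][3] == in[1][3][2] == 37
    return 0;
}
```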
@@ -34,7 +34,7 @@ public:
    Property property() const override {
        using F = Property::Flag;
        return Property{F::NEED_INPUT_COLLAPSE | F::BIND_NDIM,
-                       JITFeatureBits::NONE, 64};
+                       JITFeatureBits::DIMSHUFFLE, 64};
    }
    size_t get_nr_workspace_outputs(JITExecutor* opr) const override;
@@ -62,6 +62,7 @@ struct ElemwiseLowering : public ConversionPattern {
ElemwiseLowering(MLIRContext* ctx)
: ConversionPattern(mgb::dialect::Elemwise::getOperationName(), 1,
ctx) {}
LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
@@ -89,6 +90,7 @@ struct TypeCvtLowering : public ConversionPattern {
TypeCvtLowering(MLIRContext* ctx)
: ConversionPattern(mgb::dialect::TypeCvt::getOperationName(), 1,
ctx) {}
LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
@@ -105,6 +107,41 @@ struct TypeCvtLowering : public ConversionPattern {
}
};
struct DimshuffleLowering : public ConversionPattern {
DimshuffleLowering(MLIRContext* ctx)
: ConversionPattern(mgb::dialect::Dimshuffle::getOperationName(), 1,
ctx) {}
static mlir::AffineMap get_affinemap_from_pattern(
const std::vector<int32_t>& pattern, mlir::MLIRContext* ctx) {
size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
std::vector<mlir::AffineExpr> exprs(ndim);
for (size_t i = 0; i < pattern.size(); i++) {
int32_t j = pattern[i];
if (j >= 0) {
exprs[j] = mlir::getAffineDimExpr(i, ctx);
}
}
return mlir::AffineMap::get(pattern.size(), 0, exprs, ctx);
}
LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc();
auto pattern = llvm::dyn_cast<dialect::Dimshuffle>(op).pattern();
auto map = get_affinemap_from_pattern(pattern, op->getContext());
lower_op_to_loops(
op, operands, rewriter,
[loc, op, &map](OpBuilder& builder, ValueRange memref_operands,
ValueRange loop_ivs) {
return builder.create<AffineLoadOp>(loc, memref_operands[0],
map, loop_ivs);
});
return success();
}
};
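On the CPU path the whole permutation is folded into the affine map of a single AffineLoadOp, emitted inside the loop nest that lower_op_to_loops builds over the output. A standalone sketch of the expression layout that get_affinemap_from_pattern produces (plain C++, no MLIR dependency; variable names are ours):

```cpp
// Standalone sketch of get_affinemap_from_pattern: one result expression per
// input axis j, equal to the output loop iv i with pattern[i] == j. Broadcast
// output axes (pattern entry -1) feed no input axis, so they appear in the
// dim list but in no result.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> pattern{1, -1, 0, 2};  // e.g. (A, B, C) -> (B, 1, A, C)
    int ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
    std::vector<int> exprs(ndim, -1);
    for (int i = 0; i < static_cast<int>(pattern.size()); ++i)
        if (pattern[i] >= 0) exprs[pattern[i]] = i;
    printf("(");
    for (size_t i = 0; i < pattern.size(); ++i)
        printf(i ? ", d%zu" : "d%zu", i);
    printf(") -> (");
    for (int j = 0; j < ndim; ++j)
        printf(j ? ", d%d" : "d%d", exprs[j]);
    printf(")\n");  // prints: (d0, d1, d2, d3) -> (d2, d0, d3)
    return 0;
}
```

For this pattern the generated AffineLoadOp reads input[i2, i0, i3] at output ivs (i0, i1, i2, i3), which is exactly the map printed above.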
struct AssignOpLowering : public ConversionPattern {
AssignOpLowering(MLIRContext* ctx)
: ConversionPattern(dialect::AssignOp::getOperationName(), 1, ctx) {
@@ -172,9 +209,9 @@ public:
target.addIllegalDialect<MgbDialect>();
OwningRewritePatternList patterns;
-        patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
-                        AssignOpLowering, ConstantScalarOpLowering>(
-                &getContext());
+        patterns.insert<ElemwiseLowering, TypeCvtLowering, DimshuffleLowering,
+                        ReturnOpLowering, AssignOpLowering,
+                        ConstantScalarOpLowering>(&getContext());
if (failed(applyPartialConversion(getFunction(), target,
std::move(patterns)))) {
@@ -152,6 +152,47 @@ private:
gpu::LaunchOp& m_launch_op;
};
struct DimshuffleLowering : public ConversionPattern {
DimshuffleLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(dialect::Dimshuffle::getOperationName(), 1,
ctx),
m_launch_op{launch_op} {}
static std::vector<mlir::Value> get_index_from_pattern(
const std::vector<int32_t>& pattern,
const std::vector<mlir::Value>& index) {
size_t ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
std::vector<mlir::Value> res(ndim);
for (size_t i = 0; i < pattern.size(); i++) {
int32_t j = pattern[i];
if (j >= 0) {
res[j] = index[i];
}
}
return res;
}
LogicalResult matchAndRewrite(
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc();
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto dst_layout = output_layout(m_launch_op);
auto index = get_multidim_tid(rewriter, loc, operands[0], dst_layout);
auto pattern = llvm::dyn_cast<dialect::Dimshuffle>(op).pattern();
auto shuffled_index = get_index_from_pattern(pattern, index);
rewriter.replaceOp(op, get_operand<LoadOp>(rewriter, loc, operands[0],
shuffled_index));
return success();
}
private:
gpu::LaunchOp& m_launch_op;
};
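The GPU path cannot attach an affine map to a loop nest; each thread handles one output element instead. Assuming get_multidim_tid delinearizes the global thread id over the output layout (that helper is outside this diff, so this is an assumption), get_index_from_pattern permutes the resulting coordinates into input coordinates for an ordinary LoadOp. A host-side mirror of that indexing:

```cpp
// Hedged host-side mirror (plain C++; helper behavior assumed, see above) of
// the GPU indexing: delinearize a global thread id over the output shape,
// then permute coordinates the way get_index_from_pattern does.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> pattern{2, 0, 1};
    std::vector<long> oshape{5, 3, 4};  // output shape for a (3, 4, 5) input
    long tid = 23;                      // one global thread id
    std::vector<long> index(oshape.size());
    for (int i = static_cast<int>(oshape.size()) - 1; i >= 0; --i) {
        index[i] = tid % oshape[i];     // row-major delinearization
        tid /= oshape[i];
    }
    int ndim = *std::max_element(pattern.begin(), pattern.end()) + 1;
    std::vector<long> shuffled(ndim);
    for (size_t i = 0; i < pattern.size(); ++i)
        if (pattern[i] >= 0) shuffled[pattern[i]] = index[i];
    printf("out");
    for (long v : index) printf("[%ld]", v);
    printf(" <- in");
    for (long v : shuffled) printf("[%ld]", v);
    printf("\n");  // prints: out[1][2][3] <- in[2][3][1]
    return 0;
}
```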
struct ReturnOpLowering : public ConversionPattern {
ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(dialect::ReturnOp::getOperationName(), 1, ctx),
@@ -275,9 +316,9 @@ public:
target.addLegalDialect<gpu::GPUDialect>();
target.addIllegalDialect<MgbDialect>();
-        patterns.insert<ElemwiseLowering, TypeCvtLowering, ReturnOpLowering,
-                        ConstantScalarOpLowering, AssignOpLowering>(
-                &getContext(), launch_op);
+        patterns.insert<ElemwiseLowering, TypeCvtLowering, DimshuffleLowering,
+                        ReturnOpLowering, ConstantScalarOpLowering,
+                        AssignOpLowering>(&getContext(), launch_op);
if (failed(applyPartialConversion(func_op, target,
std::move(patterns)))) {
@@ -20,6 +20,7 @@
#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/utils.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/tensor_manip.h"
#include "megdnn/dtype.h"
#include <mlir/Dialect/Affine/IR/AffineOps.h>
@@ -160,6 +161,10 @@ private:
mgb_assert(
mlir::succeeded(declare(opr->output(0)->name(), out)));
return;
} else if (opr->same_type<opr::Dimshuffle>()) {
auto&& out = gen_dimshuffle(opr->cast_final<opr::Dimshuffle>());
mgb_assert(
mlir::succeeded(declare(opr->output(0)->name(), out)));
} else if (opr->same_type<opr::TypeCvt>()) {
auto&& out = gen_typecvt(opr->cast_final<opr::TypeCvt>());
mgb_assert(
@@ -186,18 +191,44 @@
}
mlir::Value gen_typecvt(const opr::TypeCvt& opr) {
-        auto shape = get(opr.input(0))
-                             .getType()
-                             .dyn_cast_or_null<mlir::MemRefType>()
-                             .getShape();
+        auto itype = get(opr.input(0))
+                             .getType()
+                             .dyn_cast_or_null<mlir::MemRefType>();
+        mgb_assert(itype, "currently only support MemRefType");
         auto res_type = mlir::MemRefType::get(
-                shape,
+                itype.getShape(),
                 megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext()));
return m_builder.create<dialect::TypeCvt>(
m_builder.getUnknownLoc(), res_type, get(opr.input(0)),
opr.input(0)->dtype(), opr.param());
}
mlir::Value gen_dimshuffle(const opr::Dimshuffle& opr) {
auto itype = get(opr.input(0))
.getType()
.dyn_cast_or_null<mlir::MemRefType>();
mgb_assert(itype, "the input type of Dimshuffle must be MemRefType");
auto ishape = itype.getShape();
auto param = opr.param();
std::vector<int32_t> pattern;
std::vector<int64_t> oshape;
for (size_t i = 0; i < param.pattern_len; i++) {
int32_t j = param.pattern[i];
pattern.push_back(j);
if (j < 0) {
oshape.push_back(1);
} else {
oshape.push_back(ishape[j]);
}
}
auto res_type = mlir::MemRefType::get(oshape, itype.getElementType());
return m_builder.create<dialect::Dimshuffle>(
m_builder.getUnknownLoc(), res_type, get(opr.input(0)),
pattern);
}
mlir::Type get_type(const TensorLayout& layout) {
return layout_to_mlir_type(layout, m_builder);
}
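gen_dimshuffle infers the result memref type from the pattern alone: a -1 entry becomes a size-1 broadcast axis, every other entry copies the matching input extent. A minimal sketch of that shape rule (our mirror, not a MegEngine API):

```cpp
// Minimal sketch (plain C++) of gen_dimshuffle's result-shape rule.
#include <cstdio>
#include <vector>

int main() {
    std::vector<long> ishape{3, 4, 5};
    std::vector<int> pattern{0, -1, 1, 2};
    std::vector<long> oshape;
    for (int j : pattern)
        oshape.push_back(j < 0 ? 1 : ishape[j]);  // -1 -> broadcast axis
    for (long d : oshape) printf("%ld ", d);      // prints: 3 1 4 5
    printf("\n");  // so an f32 input memref<3x4x5xf32> yields memref<3x1x4x5xf32>
    return 0;
}
```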
@@ -15,6 +15,7 @@
#include "megbrain/jit/executor_opr.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/test/helper.h"
#include "megdnn/dtype.h"
@@ -539,6 +540,51 @@ add_typecvt_gtest(Uint8, Float32);
#undef add_typecvt_gtest
/* ===================== TestJITMlirDimshuffle ===================== */
void run_dimshuffle(CompNode cn, TensorShape ishape,
const std::vector<int>& pattern) {
set_backend(Backend::MLIR);
auto graph = ComputingGraph::make();
HostTensorGenerator<> gen;
auto host_x = gen(ishape, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
auto y = opr::Dimshuffle::make(x, pattern);
auto ig_gen = std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
for (auto i : get_rev_topo_order(y)) {
if (!i->template same_type<opr::Host2DeviceCopy>()) {
ig_gen->add_opr(i);
}
}
auto igraph = ig_gen->generate();
auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
HostTensorND host_y, host_y_jit;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_jit, host_y_jit)});
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
}
void run_dimshuffle_cases(CompNode cn) {
run_dimshuffle(cn, {3, 4, 5}, {2, 0, 1});
run_dimshuffle(cn, {3, 4, 5}, {1, -1, 0, 2});
}
TEST(TestJITMlirDimshuffle, Basic) {
run_dimshuffle_cases(CompNode::load("cpu0"));
}
TEST(TestJITMlirDimshuffle, BasicGPU) {
REQUIRE_GPU(1);
run_dimshuffle_cases(CompNode::load("gpu0"));
}
#endif // MGB_JIT_MLIR
#endif // MGB_JIT