提交 f87bba68 编写于 作者: M Megvii Engine Team

feat(mgb/jit): add scalar support for mlir

GitOrigin-RevId: 27b1649c042c68c840a51657807c5ba0ea2e88d8
上级 11b121a7
......@@ -101,9 +101,9 @@ Compiler* Compiler::get(ComputingGraph& graph, CompNode comp_node) {
compiler = std::make_unique<CudaCompiler>();
break;
}
#endif
mgb_throw(InternalError, "No compiler support for cuda");
break;
#endif
case CompNode::DeviceType::CPU:
#if MGB_JIT_MLIR
if (!backend || !strcmp(backend, "MLIR")) {
......
......@@ -20,6 +20,10 @@
#if MGB_JIT
#if MGB_JIT_MLIR
#include "./mlir/ir/each_mode.h"
#endif
using namespace mgb;
using namespace gopt;
using namespace jit;
......@@ -339,35 +343,76 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
return false;
}
//! As the MLIR backend has some constraints
auto backend = MGB_GETENV("MGB_JIT_BACKEND");
// float elemwise
if (auto elem = gopt::try_cast_as_op<opr::Elemwise>(opr)) {
return ast_c::check_elem_mode(elem->param().mode) &&
bool ret = true;
#if MGB_JIT_MLIR
if (!strcmp(backend, "MLIR")) {
switch (elem->param().mode) {
#define cb(_, _mode) \
case opr::Elemwise::Mode::_mode: \
ret = true; \
break;
MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
default:
ret = false;
#undef cb
}
#define FOREACH_ELEMWISE_SKIP_MODE(cb) cb(SIN)
//! FIXME mlir on cuda doesn't support sin currently.
if (opr->output(0)->comp_node().device_type() ==
CompNode::DeviceType::CUDA) {
switch (elem->param().mode) {
#define cb(_mode) \
case opr::Elemwise::Mode::_mode: \
ret = false; \
break;
FOREACH_ELEMWISE_SKIP_MODE(cb)
default:
break;
#undef cb
}
}
#undef FOREACH_ELEMWISE_SKIP_MODE
}
#endif // MGB_JIT_MLIR
return ret && ast_c::check_elem_mode(elem->param().mode) &&
elem->output(0)->dtype().category() == DTypeCategory::FLOAT;
}
if (opr->same_type<opr::PowC>()) {
return true;
}
if (strcmp(backend, "MLIR")) {
if (opr->same_type<opr::PowC>()) {
return true;
}
// float typecvt (e.g. used in f16 training)
if (opr->same_type<opr::TypeCvt>()) {
auto category = opr->input(0)->dtype().category();
if (category != opr->output(0)->dtype().category())
return false;
return category == DTypeCategory::FLOAT;
}
// float typecvt (e.g. used in f16 training)
if (opr->same_type<opr::TypeCvt>()) {
auto category = opr->input(0)->dtype().category();
if (category != opr->output(0)->dtype().category())
return false;
return category == DTypeCategory::FLOAT;
}
// float reduce
if ((m_feature_bits & JITFeatureBits::REDUCE) &&
opr->same_type<opr::Reduce>()) {
return opr->output(0)->dtype().category() == DTypeCategory::FLOAT;
}
// float reduce
if ((m_feature_bits & JITFeatureBits::REDUCE) &&
opr->same_type<opr::Reduce>()) {
return opr->output(0)->dtype().category() == DTypeCategory::FLOAT;
}
// dimshuffle
if ((m_feature_bits & JITFeatureBits::DIMSHUFFLE) &&
opr->same_type<opr::Dimshuffle>()) {
auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
return param.pattern_len <= 4;
// dimshuffle
if ((m_feature_bits & JITFeatureBits::DIMSHUFFLE) &&
opr->same_type<opr::Dimshuffle>()) {
auto param = opr->cast_final_safe<opr::Dimshuffle>().param();
return param.pattern_len <= 4;
}
}
// existing JITExecutor
......
......@@ -10,7 +10,6 @@
* implied.
*/
#include "llvm/Pass.h"
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR
......@@ -40,6 +39,7 @@
#include <llvm/Support/TargetSelect.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Linker/Linker.h>
#include <llvm/Pass.h>
#include <dlfcn.h>
#include <dirent.h>
......
......@@ -77,6 +77,16 @@ private:
mlir::Location m_location;
};
//! Fetch an operand value for elementwise lowering.
//!
//! Memref-typed operands are materialized with a load op \p Op (e.g.
//! AffineLoadOp or LoadOp) at the given indices; anything else (such as the
//! plain SSA value produced by ConstantScalarOp) needs no load and is
//! returned as-is.
template <typename Op>
mlir::Value get_operand(mlir::OpBuilder& builder, const mlir::Location& loc,
                        const mlir::Value& val, const mlir::ValueRange& index) {
    if (!val.getType().isa<mlir::MemRefType>()) {
        //! already a scalar SSA value, pass it through
        return val;
    }
    return builder.create<Op>(loc, val, index);
}
} // namespace jit
} // namespace mgb
......
......@@ -14,6 +14,7 @@
#if MGB_JIT && MGB_JIT_MLIR
#include "megbrain/jit/mlir/ir/dialect.h"
#include "./types.h"
#include <mlir/IR/Builders.h>
#include <mlir/IR/OpImplementation.h>
......
......@@ -74,8 +74,8 @@ struct UnaryOpLowering : public ConversionPattern {
typename Op::Adaptor binary_adaptor(memref_operands);
LoweredOp lower_op;
auto loaded_lhs = builder.create<AffineLoadOp>(
loc, binary_adaptor.lhs(), loop_ivs);
auto loaded_lhs = get_operand<AffineLoadOp>(
builder, loc, binary_adaptor.lhs(), loop_ivs);
return lower_op(builder, loc, {loaded_lhs});
});
......@@ -104,10 +104,10 @@ struct BinaryOpLowering : public ConversionPattern {
typename Op::Adaptor binary_adaptor(memref_operands);
LoweredOp lower_op;
auto loaded_lhs = builder.create<AffineLoadOp>(
loc, binary_adaptor.lhs(), loop_ivs);
auto loaded_rhs = builder.create<AffineLoadOp>(
loc, binary_adaptor.rhs(), loop_ivs);
auto loaded_lhs = get_operand<AffineLoadOp>(
builder, loc, binary_adaptor.lhs(), loop_ivs);
auto loaded_rhs = get_operand<AffineLoadOp>(
builder, loc, binary_adaptor.rhs(), loop_ivs);
return lower_op(builder, loc, {loaded_lhs, loaded_rhs});
});
......@@ -136,12 +136,12 @@ struct TernaryOpLowering : public ConversionPattern {
typename Op::Adaptor ternary_adaptor(memref_operands);
LoweredOp lower_op;
auto loaded_x = builder.create<AffineLoadOp>(
loc, ternary_adaptor.x(), loop_ivs);
auto loaded_y = builder.create<AffineLoadOp>(
loc, ternary_adaptor.y(), loop_ivs);
auto loaded_z = builder.create<AffineLoadOp>(
loc, ternary_adaptor.z(), loop_ivs);
auto loaded_x = get_operand<AffineLoadOp>(
builder, loc, ternary_adaptor.x(), loop_ivs);
auto loaded_y = get_operand<AffineLoadOp>(
builder, loc, ternary_adaptor.y(), loop_ivs);
auto loaded_z = get_operand<AffineLoadOp>(
builder, loc, ternary_adaptor.z(), loop_ivs);
return lower_op(builder, loc,
{loaded_x, loaded_y, loaded_z});
......@@ -193,6 +193,19 @@ struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> {
}
};
//! Lower jit::ConstantScalarOp to the standard dialect's ConstantOp,
//! carrying the wrapped scalar attribute over unchanged.
struct ConstantScalarOpLowering
        : public OpRewritePattern<jit::ConstantScalarOp> {
    using OpRewritePattern<jit::ConstantScalarOp>::OpRewritePattern;

    LogicalResult matchAndRewrite(jit::ConstantScalarOp op,
                                  PatternRewriter& rewriter) const final {
        //! the adaptor exposes the op's "value" attribute
        ConstantScalarOpAdaptor adaptor(op);
        rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op, adaptor.value());
        return success();
    }
};
class MgbToAffineLoweringPass
: public PassWrapper<MgbToAffineLoweringPass, FunctionPass> {
public:
......@@ -207,7 +220,8 @@ public:
cb) MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
ReturnOpLowering,
AssignOpLowering>(&getContext());
AssignOpLowering, ConstantScalarOpLowering>(
&getContext());
#undef cb
if (failed(applyPartialConversion(getFunction(), target, patterns))) {
......
......@@ -38,16 +38,6 @@ using namespace jit;
namespace {
mlir::Value get_operand(ConversionPatternRewriter& rewriter,
const mlir::Location& loc, const mlir::Value& val,
const mlir::Value& index) {
if (val.getType().isa<mlir::MemRefType>()) {
return rewriter.create<LoadOp>(loc, val, index);
} else {
return val;
}
}
mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
auto thread_idx = rewriter.create<gpu::ThreadIdOp>(
loc, rewriter.getIndexType(), rewriter.getStringAttr("x"));
......@@ -64,7 +54,7 @@ mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
template <typename Op, typename LoweredOp>
struct UnaryOpLowering : public ConversionPattern {
UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(Op::getOperationName(), 1, ctx),
m_launch_op{launch_op} {}
......@@ -74,11 +64,11 @@ struct UnaryOpLowering : public ConversionPattern {
auto loc = op->getLoc();
typename Op::Adaptor binary_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto index = get_tid(rewriter, loc);
auto loaded_lhs =
get_operand(rewriter, loc, binary_adaptor.lhs(), index);
get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
LoweredOp lower_op;
......@@ -87,7 +77,7 @@ struct UnaryOpLowering : public ConversionPattern {
}
private:
gpu::LaunchOp* m_launch_op;
gpu::LaunchOp& m_launch_op;
};
#define cb(_op, _) \
......@@ -97,7 +87,7 @@ MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb)
template <typename Op, typename LoweredOp>
struct BinaryOpLowering : public ConversionPattern {
BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
BinaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(Op::getOperationName(), 1, ctx),
m_launch_op{launch_op} {}
......@@ -107,13 +97,13 @@ struct BinaryOpLowering : public ConversionPattern {
auto loc = op->getLoc();
typename Op::Adaptor binary_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto index = get_tid(rewriter, loc);
auto loaded_lhs =
get_operand(rewriter, loc, binary_adaptor.lhs(), index);
get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
auto loaded_rhs =
get_operand(rewriter, loc, binary_adaptor.rhs(), index);
get_operand<LoadOp>(rewriter, loc, binary_adaptor.rhs(), index);
LoweredOp lower_op;
......@@ -123,7 +113,7 @@ struct BinaryOpLowering : public ConversionPattern {
}
private:
gpu::LaunchOp* m_launch_op;
gpu::LaunchOp& m_launch_op;
};
#define cb(_op, _) \
......@@ -133,7 +123,7 @@ MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
template <typename Op, typename LoweredOp>
struct TernaryOpLowering : public ConversionPattern {
TernaryOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
TernaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(Op::getOperationName(), 1, ctx),
m_launch_op{launch_op} {}
......@@ -143,15 +133,15 @@ struct TernaryOpLowering : public ConversionPattern {
auto loc = op->getLoc();
typename Op::Adaptor ternary_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto index = get_tid(rewriter, loc);
auto loaded_x =
get_operand(rewriter, loc, ternary_adaptor.x(), index);
get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(), index);
auto loaded_y =
get_operand(rewriter, loc, ternary_adaptor.y(), index);
get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(), index);
auto loaded_z =
get_operand(rewriter, loc, ternary_adaptor.z(), index);
get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(), index);
LoweredOp lower_op;
......@@ -161,7 +151,7 @@ struct TernaryOpLowering : public ConversionPattern {
}
private:
gpu::LaunchOp* m_launch_op;
gpu::LaunchOp& m_launch_op;
};
#define cb(_op, _) \
......@@ -171,7 +161,7 @@ MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
#undef cb
struct ReturnOpLowering : public ConversionPattern {
ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
ReturnOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(jit::ReturnOp::getOperationName(), 1, ctx),
m_launch_op{launch_op} {}
......@@ -182,10 +172,10 @@ struct ReturnOpLowering : public ConversionPattern {
auto loc = op->getLoc();
//! remove the first gpu.terminator
m_launch_op->body().front().front().erase();
m_launch_op.body().front().front().erase();
//! if (tid >= nr_tid) {return;} in the begin of the block
rewriter.setInsertionPointToStart(&(m_launch_op->body().front()));
rewriter.setInsertionPointToStart(&(m_launch_op.body().front()));
Block* cond_block = rewriter.getInsertionBlock();
Block::iterator op_position = rewriter.getInsertionPoint();
Block* remaining_ops_block =
......@@ -195,7 +185,7 @@ struct ReturnOpLowering : public ConversionPattern {
auto index = get_tid(rewriter, loc);
auto comparison = rewriter.create<mlir::CmpIOp>(
loc, CmpIPredicate::sge, index,
m_launch_op->getParentOfType<mlir::FuncOp>()
m_launch_op.getParentOfType<mlir::FuncOp>()
.getArguments()
.back());
......@@ -216,11 +206,31 @@ struct ReturnOpLowering : public ConversionPattern {
}
private:
gpu::LaunchOp* m_launch_op;
gpu::LaunchOp& m_launch_op;
};
//! Lower jit::ConstantScalarOp to the standard dialect's ConstantOp for the
//! GPU path. Unlike the affine variant, the replacement must be emitted
//! inside the body of the surrounding gpu.launch region, so the pattern
//! keeps a reference to that launch op and moves the insertion point there
//! before rewriting.
struct ConstantScalarOpLowering
        : public OpRewritePattern<jit::ConstantScalarOp> {
    ConstantScalarOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
            : OpRewritePattern<jit::ConstantScalarOp>(ctx),
              m_launch_op{launch_op} {}

    LogicalResult matchAndRewrite(jit::ConstantScalarOp op,
                                  PatternRewriter& rewriter) const final {
        ConstantScalarOpAdaptor adaptor(op);
        //! emit the constant inside the gpu.launch body
        rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
        rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op, adaptor.value());
        return success();
    }

private:
    gpu::LaunchOp& m_launch_op;
};
struct AssignOpLowering : public ConversionPattern {
AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp* launch_op)
AssignOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
: ConversionPattern(jit::AssignOp::getOperationName(), 2, ctx),
m_launch_op{launch_op} {}
......@@ -230,12 +240,12 @@ struct AssignOpLowering : public ConversionPattern {
auto loc = op->getLoc();
AssignOpAdaptor assign_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op->body().front()));
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto index = get_tid(rewriter, loc);
auto loaded_lhs =
get_operand(rewriter, loc, assign_adaptor.lhs(), index);
get_operand<LoadOp>(rewriter, loc, assign_adaptor.lhs(), index);
rewriter.create<StoreOp>(loc, loaded_lhs, assign_adaptor.rhs(), index);
rewriter.eraseOp(op);
......@@ -243,7 +253,7 @@ struct AssignOpLowering : public ConversionPattern {
}
private:
gpu::LaunchOp* m_launch_op;
gpu::LaunchOp& m_launch_op;
};
class MgbToGpuLoweringPass
......@@ -271,7 +281,8 @@ public:
cb) MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb)
MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb)
ReturnOpLowering,
AssignOpLowering>(&getContext(), &launch_op);
ConstantScalarOpLowering, AssignOpLowering>(
&getContext(), launch_op);
#undef cb
if (failed(applyPartialConversion(func_op, target, patterns))) {
......
......@@ -17,6 +17,7 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "./interfaces.td"
include "./predicates.td"
def Mgb_Dialect : Dialect {
let name = "mgb";
......@@ -90,7 +91,7 @@ def RoundOp : ElemwiseUnaryOp<"round", [NoSideEffect]>;
class ElemwiseBinaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
ElemwiseOp<mnemonic, traits> {
let arguments = (ins F32MemRef:$lhs, F32MemRef:$rhs);
let arguments = (ins ElemwiseFloatAny:$lhs, ElemwiseFloatAny:$rhs);
let results = (outs F32MemRef);
let builders = [OpBuilder<
......@@ -141,7 +142,7 @@ def FuseAddHswishOp : ElemwiseBinaryOp<"fuse_add_hswish", [NoSideEffect]>;
class ElemwiseTernaryOp<string mnemonic, list<OpTrait> traits = [NoSideEffect]> :
ElemwiseOp<mnemonic, traits> {
let arguments = (ins F32MemRef:$x, F32MemRef:$y, F32MemRef:$z);
let arguments = (ins ElemwiseFloatAny:$x, ElemwiseFloatAny:$y, ElemwiseFloatAny:$z);
let results = (outs F32MemRef);
let builders = [OpBuilder<
......@@ -178,6 +179,25 @@ def ReturnOp : GenericOp<"return",
}
// mgb.sconst: wraps a compile-time scalar constant (e.g. an immutable scalar
// from the megbrain graph) as an SSA value, later lowered to std.constant.
def ConstantScalarOp: GenericOp<"sconst", [NoSideEffect]> {
    let summary = "scalar constant";
    // NOTE(review): declared as AnyAttr, but the builder below only ever
    // constructs an f32 attribute/result — presumably only f32 is supported
    // for now; confirm before tightening to F32Attr.
    let arguments = (ins AnyAttr:$value);
    let results = (outs F32:$result);
    let builders = [OpBuilder<
        "Builder* builder, OperationState& result, float value", [{
            result.addAttribute("value", builder->getF32FloatAttr(value));
            result.addTypes(builder->getF32Type());
        }]
    >];
    // convenience accessors for the stored attribute
    let extraClassDeclaration = [{
        Attribute getValue() { return getAttr("value"); }
        FloatAttr getFloatAttr() { return getAttrOfType<FloatAttr>("value"); }
    }];
}
def AssignOp : GenericOp<"assign", []> {
let summary = "assign op";
let description = [{
......
/**
* \file src/jit/impl/mlir/ir/predicates.td
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#ifndef MGB_MLIR_PREDICATES
#define MGB_MLIR_PREDICATES
#ifndef OP_BASE
include "mlir/IR/OpBase.td"
#endif
def ElemwiseFloatAny : TypeConstraint<
CPred<"is_elemwise_float($_self)">, "elemwise-float">;
#endif
/**
* \file src/jit/impl/mlir/ir/types.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR
#include <mlir/IR/StandardTypes.h>
namespace mgb {
namespace jit {
//! Predicate behind the ElemwiseFloatAny TableGen type constraint: a type is
//! accepted when it is either a memref of f32 elements or a plain (scalar)
//! float type.
inline bool is_elemwise_float(const mlir::Type& dt) {
    auto memref = dt.dyn_cast_or_null<mlir::MemRefType>();
    if (memref &&
        memref.getElementType().getKind() == mlir::StandardTypes::F32) {
        return true;
    }
    return dt.isa<mlir::FloatType>();
}
} // namespace jit
} // namespace mgb
#endif // MGB_JIT && MGB_JIT_MLIR
// vim: syntax=cpp.doxygen
......@@ -49,6 +49,9 @@ mlir::Type jit::deduce_result_type(mlir::ValueRange operands) {
megdnn::TensorShape dst;
megdnn::DType dst_type;
for (auto operand : operands) {
if (operand.getType().isa<mlir::FloatType>()) {
continue;
}
auto type = operand.getType().dyn_cast_or_null<mlir::MemRefType>();
mgb_assert(type, "currently only support MemRefType");
......
......@@ -137,6 +137,27 @@ private:
return;
}
if (opr->same_type<opr::ImmutableTensor>()) {
auto imm = SymbolVar{opr->output(0)}.as_immutable_scalar();
if (imm.valid()) {
auto dtype = imm->dtype();
float scalar_value;
if (dtype == dtype::Float32()) {
scalar_value = imm->get<float>();
} else {
mgb_throw(InternalError,
"mlir backend currently only support f32 "
"dtype, but got %s",
dtype.name());
}
auto&& out = m_builder.create<jit::ConstantScalarOp>(
m_builder.getUnknownLoc(), m_builder.getF32Type(),
m_builder.getF32FloatAttr(scalar_value));
mgb_assert(mlir::succeeded(
declare(opr->output(0)->name(), out)));
}
}
if (opr->same_type<opr::Elemwise>()) {
auto&& out = gen_op(opr->cast_final<opr::Elemwise>());
mgb_assert(
......
......@@ -137,7 +137,7 @@ void run_mlir(CompNode cn) {
b = opr::Host2DeviceCopy::make(*graph, host_x1),
c = opr::Host2DeviceCopy::make(*graph, host_x2);
auto y = a + b * c;
auto y = a + b * c + 0.3f;
auto ig_gen =
std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册