Commit 9682db98 authored by Megvii Engine Team

feat(mgb): add jit mlir elemwise broadcast

GitOrigin-RevId: 89d5e2f91eab46bc66fea014cf9170e49b5dfc4e
Parent 89303cd8
@@ -294,22 +294,6 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) {
cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input,
cond_mlir_specific = true;
#if MGB_JIT_MLIR
//! FIXME mlir doesn't support broadcast currently.
auto backend = MGB_GETENV("MGB_JIT_BACKEND");
if (backend && !strcmp(backend, "MLIR")) {
for (VarNode* var : opr->input()) {
if (!SymbolVar{var}.as_immutable_scalar().valid()) {
if (opr->node_prop().dep_map().at(var) &
DepType::DEV_VALUE) {
if (!var->shape().eq_shape(opr->output(0)->shape())) {
cond_mlir_specific = false;
}
}
}
}
}
#endif
if (cond_readers && cond_cn && cond_shp && cond_nr_inp &&
cond_mlir_specific) {
ig_gen->add_opr(opr);
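The deleted block above was the workaround for the missing broadcast support: any fusion candidate with a non-scalar DEV_VALUE input whose shape differed from the output shape disabled MLIR fusion. Below is a standalone sketch (not MegBrain code; `Shape` and `broadcastable_to` are hypothetical) of the shape relation the new code paths now have to handle:

```cpp
// Standalone sketch of the condition behind the removed guard: an input
// whose shape differs from the output is only fusable if it broadcasts to it.
#include <cassert>
#include <cstddef>
#include <vector>

using Shape = std::vector<std::size_t>;

// True if `src` broadcasts to `dst`: same rank, and each source extent is
// either 1 or equal to the corresponding destination extent.
bool broadcastable_to(const Shape& src, const Shape& dst) {
    if (src.size() != dst.size())
        return false;
    for (std::size_t i = 0; i < src.size(); ++i)
        if (src[i] != 1 && src[i] != dst[i])
            return false;
    return true;
}

int main() {
    assert(broadcastable_to({1, 20, 1, 1}, {10, 20, 5, 6}));   // now fusable
    assert(!broadcastable_to({10, 3, 5, 6}, {10, 20, 5, 6}));  // still rejected
}
```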
@@ -57,23 +57,23 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
}
};
for (const auto& arg : args.inputs) {
set_params(arg.from->dev_tensor().raw_ptr(), arg.layout);
set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
}
int64_t nr_elements = 0;
for (const auto& arg : args.outputs) {
if (nr_elements == 0) {
nr_elements = arg.layout.total_nr_elems();
nr_elements = arg.from->layout().total_nr_elems();
} else {
mgb_assert(static_cast<size_t>(nr_elements) ==
arg.layout.total_nr_elems(),
"The number of elements of outputs mismatch, expected: "
"%zu got: %zu(%s)",
static_cast<size_t>(nr_elements),
arg.layout.total_nr_elems(),
arg.layout.to_string().c_str());
arg.from->layout().total_nr_elems(),
arg.from->layout().to_string().c_str());
}
set_params(arg.from->dev_tensor().raw_ptr(), arg.layout);
set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
}
const CompNodeEnv& env =
CompNodeEnv::from_comp_node(fusion_opr->comp_node());
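Throughout the executable, `arg.layout` is replaced by `arg.from->layout()`. A plausible reading (an assumption here, not stated in the diff): the executor-side `arg.layout` can be a canonized view with contiguous axes collapsed, which is harmless for purely elementwise kernels but loses the per-axis extents that broadcast lowering needs, whereas the originating var node's layout keeps them. A toy illustration under that assumption:

```cpp
// Toy illustration (assumed canonization, not MegBrain's actual logic):
// collapsing a contiguous layout to rank 1 destroys the per-axis extents
// that broadcast lowering relies on.
#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<std::size_t> collapse_contiguous(
        const std::vector<std::size_t>& shape) {
    std::size_t total = 1;
    for (std::size_t s : shape)
        total *= s;
    return {total};  // {1, 20, 1, 1} -> {20}: the broadcast axes are gone
}

int main() {
    auto collapsed = collapse_contiguous({1, 20, 1, 1});
    std::printf("collapsed ndim = %zu, extent = %zu\n", collapsed.size(),
                collapsed[0]);
}
```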
@@ -134,8 +134,8 @@ void MLIRCUDAExecutable::FuncCache::exec(const JITExecutor* fusion_opr,
mgb_assert(fusion_opr->args().outputs.size() == 1,
"Currently only support 1 outputs, got %zu",
fusion_opr->args().outputs.size());
int out_dim = fusion_opr->args().outputs[0].layout.ndim;
DType dtype = fusion_opr->args().outputs[0].layout.dtype;
int out_dim = fusion_opr->args().outputs[0].from->layout().ndim;
DType dtype = fusion_opr->args().outputs[0].from->layout().dtype;
#define cb_outdim(_ndim, _dtype) \
if (_ndim == out_dim) { \
setup_and_launch<_ndim, _dtype>(fusion_opr, func->func, \
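`cb_outdim` (truncated above) expands to a chain of `if` tests that bridges the runtime `(out_dim, dtype)` pair to a compile-time template instantiation of `setup_and_launch`. A minimal standalone model of the same dispatch pattern, with hypothetical names:

```cpp
// Standalone model of macro-based dispatch: a runtime ndim value selects a
// compile-time <ndim, T> kernel instantiation.
#include <cstdio>
#include <stdexcept>

template <int ndim, typename T>
void launch_kernel() {
    std::printf("launch: ndim=%d, sizeof(T)=%zu\n", ndim, sizeof(T));
}

#define CB_NDIM(_ndim, _t)          \
    if (ndim == _ndim) {            \
        launch_kernel<_ndim, _t>(); \
        return;                     \
    }

void dispatch(int ndim) {
    CB_NDIM(1, float)
    CB_NDIM(2, float)
    CB_NDIM(3, float)
    CB_NDIM(4, float)
    throw std::runtime_error("unsupported ndim");
}
#undef CB_NDIM

int main() { dispatch(4); }
```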
@@ -14,8 +14,10 @@
#if MGB_JIT && MGB_JIT_MLIR
#include "./common.h"
#include "megbrain/jit/mlir/ir/utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include <mlir/Dialect/Affine/IR/AffineOps.h>
using namespace mgb;
using namespace jit;
@@ -28,9 +30,11 @@ cb(add, AddFOp);
cb(sub, SubFOp);
cb(mul, MulFOp);
cb(div, DivFOp);
cb(divI, SignedDivIOp);
cb(mod, RemFOp);
cb(bit_and, AndOp);
cb(bit_or, OrOp);
cb(modI, SignedRemIOp);
#undef cb
#define cb(name, mode) \
@@ -62,6 +66,11 @@ mlir::Value ValueBuilderHelper::const_val(float val) {
m_builder.getF32FloatAttr(val));
}
mlir::Value ValueBuilderHelper::constI(int32_t val) {
return m_builder.create<mlir::ConstantOp>(m_location,
m_builder.getIndexAttr(val));
}
#define cb(name, op) \
mlir::Value ValueBuilderHelper::name(mlir::Value lhs) { \
return m_builder.create<mlir::op>(m_location, lhs); \
@@ -97,6 +106,44 @@ mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val,
false_val);
}
mlir::AffineMap jit::get_affinemap(mlir::OpBuilder& builder,
const mlir::Value& val,
const megdnn::TensorLayout& layout) {
auto type = val.getType().cast<mlir::MemRefType>();
mgb_assert(type, "currently only supports MemRefType");
std::vector<mlir::AffineExpr> exprs;
for (int i = 0; i < type.getRank(); ++i) {
if (layout[i] == 1) {
exprs.push_back(builder.getAffineConstantExpr(0));
} else {
exprs.push_back(builder.getAffineDimExpr(i));
}
}
auto map = mlir::AffineMap::get(type.getRank(), 0, exprs,
builder.getContext());
return map;
}
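`get_affinemap` pins every broadcast dimension (extent 1) to the constant expression 0, so an affine load through the map always reads element 0 along those axes while passing the loop index through everywhere else. For layout {1, 20, 1, 1} against a rank-4 memref the map is `(d0, d1, d2, d3) -> (0, d1, 0, 0)`. A standalone illustration that prints exactly that (plain C++, illustrative only):

```cpp
// Standalone illustration of the affine map built above: extent-1 axes map
// to constant 0, the rest map to their own loop dimension.
#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

std::string broadcast_map(const std::vector<std::size_t>& layout) {
    std::string dims, exprs;
    for (std::size_t i = 0; i < layout.size(); ++i) {
        if (i) { dims += ", "; exprs += ", "; }
        dims += "d" + std::to_string(i);
        exprs += (layout[i] == 1) ? "0" : "d" + std::to_string(i);
    }
    return "(" + dims + ") -> (" + exprs + ")";
}

int main() {
    // prints (d0, d1, d2, d3) -> (0, d1, 0, 0)
    std::printf("%s\n", broadcast_map({1, 20, 1, 1}).c_str());
}
```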
mlir::Value jit::get_affine_load_op(mlir::OpBuilder& builder,
const mlir::Location& loc,
const mlir::Value& val,
const mlir::ValueRange& index,
const megdnn::TensorLayout& dst) {
if (val.getType().isa<mlir::MemRefType>()) {
auto type = val.getType().cast<mlir::MemRefType>();
megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
src_layout.init_contiguous_stride();
if (src_layout.eq_shape(dst)) {
return builder.create<mlir::AffineLoadOp>(loc, val, index);
} else {
auto lhs_map = get_affinemap(builder, val, src_layout);
return builder.create<mlir::AffineLoadOp>(loc, val, lhs_map, index);
}
} else {
return val;
}
}
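Loading through that map is equivalent to giving every extent-1 axis a stride of 0: the same source element is re-read for all destination coordinates along that axis. A plain-C++ analogue of such a broadcast-aware load (illustrative only, not the generated IR):

```cpp
// Plain-C++ analogue of an affine load through the broadcast map: axes of
// extent 1 contribute nothing to the source offset (stride treated as 0).
#include <cassert>
#include <cstddef>
#include <vector>

float broadcast_load(const std::vector<float>& src,
                     const std::vector<std::size_t>& src_shape,
                     const std::vector<std::size_t>& idx) {
    std::size_t offset = 0, stride = 1;
    for (int i = static_cast<int>(src_shape.size()) - 1; i >= 0; --i) {
        if (src_shape[i] != 1)
            offset += idx[i] * stride;
        stride *= src_shape[i];
    }
    return src[offset];
}

int main() {
    std::vector<float> src = {1.f, 2.f, 3.f};            // shape {1, 3}
    assert(broadcast_load(src, {1, 3}, {7, 2}) == 3.f);  // row index ignored
}
```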
#endif // MGB_JIT && MGB_JIT_MLIR
// vim: syntax=cpp.doxygen
@@ -14,7 +14,7 @@
#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR
#include "megbrain/tensor.h"
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/IR/Value.h>
@@ -39,9 +39,11 @@ public:
cb(sub);
cb(mul);
cb(div);
cb(divI);
cb(max);
cb(min);
cb(mod);
cb(modI);
cb(gt);
cb(ge);
cb(lt);
@@ -51,6 +53,7 @@ public:
cb(bit_or);
#undef cb
mlir::Value const_val(float val);
mlir::Value constI(int32_t val);
#define cb(name) \
mlir::Value name(mlir::ValueRange operands) { return name(operands[0]); } \
@@ -89,6 +92,15 @@ mlir::Value get_operand(mlir::OpBuilder& builder, const mlir::Location& loc,
}
}
mlir::AffineMap get_affinemap(mlir::OpBuilder& builder, const mlir::Value& val,
const TensorLayout& layout);
mlir::Value get_affine_load_op(mlir::OpBuilder& builder,
const mlir::Location& loc,
const mlir::Value& val,
const mlir::ValueRange& index,
const TensorLayout& dst);
} // namespace jit
} // namespace mgb
@@ -42,8 +42,8 @@ void lower_op_to_loops(Operation* op, ValueRange operands,
auto alloc = jit::insert_alloc_and_dealloc(memref_type, loc, rewriter);
SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
buildAffineLoopNest(
rewriter, loc, lower_bounds, memref_type.getShape(), steps,
[&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
@@ -96,17 +96,23 @@ struct BinaryOpLowering : public ConversionPattern {
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc();
auto dst_memref_type = (*op->result_type_begin()).cast<MemRefType>();
megdnn::TensorLayout dst_layout = mlir_type_to_layout(dst_memref_type);
dst_layout.init_contiguous_stride();
lower_op_to_loops(
op, operands, rewriter,
[loc](OpBuilder& builder, ValueRange memref_operands,
ValueRange loop_ivs) {
[dst_layout, loc, this](OpBuilder& builder,
ValueRange memref_operands,
ValueRange loop_ivs) {
typename Op::Adaptor binary_adaptor(memref_operands);
LoweredOp lower_op;
auto loaded_lhs = get_operand<AffineLoadOp>(
builder, loc, binary_adaptor.lhs(), loop_ivs);
auto loaded_rhs = get_operand<AffineLoadOp>(
builder, loc, binary_adaptor.rhs(), loop_ivs);
auto loaded_lhs = get_affine_load_op(builder, loc,
binary_adaptor.lhs(),
loop_ivs, dst_layout);
auto loaded_rhs = get_affine_load_op(builder, loc,
binary_adaptor.rhs(),
loop_ivs, dst_layout);
return lower_op(builder, loc, {loaded_lhs, loaded_rhs});
});
@@ -128,19 +134,26 @@ struct TernaryOpLowering : public ConversionPattern {
Operation* op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op->getLoc();
auto dst_memref_type = (*op->result_type_begin()).cast<MemRefType>();
megdnn::TensorLayout dst_layout = mlir_type_to_layout(dst_memref_type);
dst_layout.init_contiguous_stride();
lower_op_to_loops(
op, operands, rewriter,
[loc](OpBuilder& builder, ValueRange memref_operands,
ValueRange loop_ivs) {
[dst_layout, loc](OpBuilder& builder,
ValueRange memref_operands,
ValueRange loop_ivs) {
typename Op::Adaptor ternary_adaptor(memref_operands);
LoweredOp lower_op;
auto loaded_x = get_operand<AffineLoadOp>(
builder, loc, ternary_adaptor.x(), loop_ivs);
auto loaded_y = get_operand<AffineLoadOp>(
builder, loc, ternary_adaptor.y(), loop_ivs);
auto loaded_z = get_operand<AffineLoadOp>(
builder, loc, ternary_adaptor.z(), loop_ivs);
auto loaded_x = get_affine_load_op(builder, loc,
ternary_adaptor.x(),
loop_ivs, dst_layout);
auto loaded_y = get_affine_load_op(builder, loc,
ternary_adaptor.y(),
loop_ivs, dst_layout);
auto loaded_z = get_affine_load_op(builder, loc,
ternary_adaptor.z(),
loop_ivs, dst_layout);
return lower_op(builder, loc,
{loaded_x, loaded_y, loaded_z});
@@ -166,8 +179,8 @@ struct AssignOpLowering : public ConversionPattern {
auto memref_type = operands[0].getType().cast<MemRefType>();
AssignOpAdaptor assign_adaptor(operands);
SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
buildAffineLoopNest(
rewriter, loc, lower_bounds, memref_type.getShape(), steps,
[&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
@@ -52,6 +52,54 @@ mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
return index;
}
megdnn::TensorLayout output_layout(gpu::LaunchOp& launch_op) {
auto func_op = launch_op.getParentOfType<mlir::FuncOp>();
mgb_assert(func_op, "Unexpected launch op.");
for (auto block_iter = func_op.rbegin(); block_iter != func_op.rend();
block_iter++) {
for (auto op_iter = block_iter->rbegin(); op_iter != block_iter->rend();
op_iter++) {
auto op = llvm::dyn_cast_or_null<AssignOp>(&(*op_iter));
if (op && op.getNumOperands() > 0) {
return mlir_type_to_layout(*(op.operand_type_begin()));
}
}
}
mgb_throw(MegBrainError, "Unexpected launch op.");
}
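`output_layout` recovers the kernel's output layout by walking the surrounding FuncOp's blocks and operations in reverse until it meets an AssignOp; this relies on the store of the fused result being the last AssignOp in the function. The traversal itself is a plain reverse nested search, sketched standalone below (hypothetical data, to show only the shape of the algorithm):

```cpp
// Standalone sketch of the reverse scan above: find the last element of a
// nested sequence satisfying a predicate (here, "is an assign op").
#include <cassert>
#include <string>
#include <vector>

const std::string* find_last_assign(
        const std::vector<std::vector<std::string>>& blocks) {
    for (auto blk = blocks.rbegin(); blk != blocks.rend(); ++blk)
        for (auto op = blk->rbegin(); op != blk->rend(); ++op)
            if (*op == "assign")
                return &*op;
    return nullptr;
}

int main() {
    std::vector<std::vector<std::string>> func = {{"load", "mul"},
                                                  {"add", "assign"}};
    assert(find_last_assign(func) != nullptr);
}
```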
std::vector<mlir::Value> get_multidim_tid(ConversionPatternRewriter& rewriter,
const Location& loc,
const mlir::Value& val,
const megdnn::TensorLayout& dst) {
Value index = get_tid(rewriter, loc);
auto type = val.getType().dyn_cast_or_null<mlir::MemRefType>();
if (type) {
ValueBuilderHelper helper(rewriter, loc);
std::vector<mlir::Value> idxs;
idxs.resize(dst.ndim);
mlir::Value dim_index = index;
for (int i = dst.ndim - 1; i >= 0; i--) {
auto cur_index = helper.modI(dim_index, helper.constI(dst[i]));
idxs[i] = cur_index;
dim_index = helper.divI(dim_index, helper.constI(dst[i]));
}
megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
src_layout.init_contiguous_stride();
for (int i = 0; i < type.getRank(); ++i) {
if (src_layout[i] == 1) {
idxs[i] = helper.constI(0);
}
}
return idxs;
} else {
return {index};
}
}
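`get_multidim_tid` recovers per-axis coordinates from the linear GPU thread id by peeling extents off the innermost dimension with `modI`/`divI`, then pins the coordinates of broadcast source axes (extent 1) to 0 so every thread reads source element 0 along those axes. A standalone model with a worked value:

```cpp
// Standalone model of get_multidim_tid: decompose a linear thread id into
// output coordinates, then zero the coordinates of broadcast source axes.
#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<std::size_t> multidim_index(
        std::size_t tid, const std::vector<std::size_t>& dst_shape,
        const std::vector<std::size_t>& src_shape) {
    std::vector<std::size_t> idx(dst_shape.size());
    for (int i = static_cast<int>(dst_shape.size()) - 1; i >= 0; --i) {
        idx[i] = tid % dst_shape[i];  // helper.modI in the code above
        tid /= dst_shape[i];          // helper.divI
    }
    for (std::size_t i = 0; i < src_shape.size(); ++i)
        if (src_shape[i] == 1)        // broadcast axis: always read element 0
            idx[i] = 0;
    return idx;
}

int main() {
    auto idx = multidim_index(1234, {10, 20, 5, 6}, {1, 20, 1, 1});
    // prints (0, 1, 0, 0): only the non-broadcast axis d1 survives
    std::printf("(%zu, %zu, %zu, %zu)\n", idx[0], idx[1], idx[2], idx[3]);
}
```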
template <typename Op, typename LoweredOp>
struct UnaryOpLowering : public ConversionPattern {
UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)
@@ -66,7 +114,9 @@ struct UnaryOpLowering : public ConversionPattern {
typename Op::Adaptor binary_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto index = get_tid(rewriter, loc);
auto dst_layout = output_layout(m_launch_op);
auto index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(),
dst_layout);
auto loaded_lhs =
get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
@@ -99,11 +149,15 @@ struct BinaryOpLowering : public ConversionPattern {
typename Op::Adaptor binary_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto index = get_tid(rewriter, loc);
auto loaded_lhs =
get_operand<LoadOp>(rewriter, loc, binary_adaptor.lhs(), index);
auto loaded_rhs =
get_operand<LoadOp>(rewriter, loc, binary_adaptor.rhs(), index);
auto dst_layout = output_layout(m_launch_op);
auto lhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(),
dst_layout);
auto rhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.rhs(),
dst_layout);
auto loaded_lhs = get_operand<LoadOp>(rewriter, loc,
binary_adaptor.lhs(), lhs_index);
auto loaded_rhs = get_operand<LoadOp>(rewriter, loc,
binary_adaptor.rhs(), rhs_index);
LoweredOp lower_op;
@@ -135,13 +189,19 @@ struct TernaryOpLowering : public ConversionPattern {
typename Op::Adaptor ternary_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto index = get_tid(rewriter, loc);
auto loaded_x =
get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(), index);
auto loaded_y =
get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(), index);
auto loaded_z =
get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(), index);
auto dst_layout = output_layout(m_launch_op);
auto index_x = get_multidim_tid(rewriter, loc, ternary_adaptor.x(),
dst_layout);
auto index_y = get_multidim_tid(rewriter, loc, ternary_adaptor.y(),
dst_layout);
auto index_z = get_multidim_tid(rewriter, loc, ternary_adaptor.z(),
dst_layout);
auto loaded_x = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(),
index_x);
auto loaded_y = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(),
index_y);
auto loaded_z = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(),
index_z);
LoweredOp lower_op;
@@ -242,7 +302,9 @@ struct AssignOpLowering : public ConversionPattern {
AssignOpAdaptor assign_adaptor(operands);
rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
auto index = get_tid(rewriter, loc);
auto dst_layout = output_layout(m_launch_op);
auto index = get_multidim_tid(rewriter, loc, assign_adaptor.rhs(),
dst_layout);
auto loaded_lhs =
get_operand<LoadOp>(rewriter, loc, assign_adaptor.lhs(), index);
@@ -98,7 +98,6 @@ mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout,
for (size_t i = 0; i < layout.ndim; i++) {
shape.push_back(layout[i]);
}
switch (layout.dtype.enumv()) {
case megdnn::DTypeEnum::Float32:
return mlir::MemRefType::get(shape, builder.getF32Type());
@@ -73,10 +73,10 @@ private:
m_symbol_table);
std::vector<mlir::Type> func_args;
for (auto&& arg : args.inputs) {
func_args.push_back(get_type(arg.layout));
func_args.push_back(get_type(arg.from->layout()));
}
for (auto&& arg : args.outputs) {
func_args.push_back(get_type(arg.layout));
func_args.push_back(get_type(arg.from->layout()));
}
//! the last arg is nr_elements
func_args.push_back(m_builder.getIndexType());
@@ -44,7 +44,6 @@ megdnn::TensorLayout mlir_type_to_layout(mlir::Type type);
megdnn::DType mlir_type_to_dtype(mlir::Type type);
mlir::MemRefType layout_to_mlir_type(const megdnn::TensorLayout& layout,
mlir::Builder& builder);
} // namespace jit
} // namespace mgb
@@ -130,8 +130,8 @@ void run_mlir(CompNode cn) {
auto graph = ComputingGraph::make();
HostTensorGenerator<dtype::Float32> gen;
auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 42}, cn),
host_x2 = gen({23, 42}, cn), host_x3 = gen({23, 42}, cn);
auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 1}, cn),
host_x2 = gen({23, 42}, cn);
auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
b = opr::Host2DeviceCopy::make(*graph, host_x1),
@@ -159,6 +159,43 @@ void run_mlir(CompNode cn) {
MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
}
void run_mlir_broadcast(CompNode cn) {
set_backend(Backend::MLIR);
auto graph = ComputingGraph::make();
HostTensorGenerator<dtype::Float32> gen;
auto host_x0 = gen({10, 20, 5, 6}, cn), host_x1 = gen({1, 20, 1, 1}, cn),
host_x2 = gen({10, 1, 5, 1}, cn), host_x3 = gen({10, 1, 1, 1}, cn);
auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
b = opr::Host2DeviceCopy::make(*graph, host_x1),
c = opr::Host2DeviceCopy::make(*graph, host_x2),
d = opr::Host2DeviceCopy::make(*graph, host_x3);
auto y =
opr::Elemwise::make({a, b, c}, opr::Elemwise::Mode::FUSE_MUL_ADD3) +
opr::Elemwise::make({d}, opr::Elemwise::Mode::ABS) - 0.3f;
auto ig_gen =
std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
for (auto i : get_rev_topo_order(y)) {
if (!i->same_type<opr::Host2DeviceCopy>()) {
ig_gen->add_opr(i);
}
}
auto igraph = ig_gen->generate();
auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
HostTensorND host_y, host_y_jit;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_jit, host_y_jit)});
func->execute();
MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
}
struct MlirTestOpt {
float low;
float high;
@@ -252,12 +289,14 @@ TYPED_TEST(TestJITNvrtcCodeGen, run) {
TEST(TestJITMlirCodeGen, Basic) {
auto cn = CompNode::load("cpu0");
run_mlir(cn);
run_mlir_broadcast(cn);
}
TEST(TestJITMlirCodeGen, BasicGPU) {
REQUIRE_GPU(1);
auto cn = CompNode::load("gpu0");
run_mlir(cn);
run_mlir_broadcast(cn);
}
///////////////////////// unary ///////////////////////////////
@@ -1580,8 +1580,8 @@ void run_mlir(CompNode cn) {
JITExecutor* jit;
unpack_vector(find_oprs<JITExecutor>(*funcs.second), jit);
ASSERT_EQ(2u, find_oprs<opr::Elemwise>(*funcs.second).size());
ASSERT_EQ(3u, jit->input().size());
ASSERT_EQ(0u, find_oprs<opr::Elemwise>(*funcs.second).size());
ASSERT_EQ(5u, jit->input().size());
}
TEST(TestJITExecutor, TestJITMlirFusion) {