MegEngine / MegEngine, commit 9682db98

Authored on Oct 10, 2020 by Megvii Engine Team

feat(mgb): add jit mlir elemwise broadcast
GitOrigin-RevId: 89d5e2f91eab46bc66fea014cf9170e49b5dfc4e
Parent: 89303cd8
Showing 11 changed files with 219 additions and 64 deletions (+219, -64)
src/jit/impl/fusion_pass.cpp                     +0   -16
src/jit/impl/mlir/executable_cuda.cpp            +7   -7
src/jit/impl/mlir/ir/common.cpp                  +47  -0
src/jit/impl/mlir/ir/common.h                    +13  -1
src/jit/impl/mlir/ir/lower_to_affine_pass.cpp    +31  -18
src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp       +76  -14
src/jit/impl/mlir/ir/utils.cpp                   +0   -1
src/jit/impl/mlir/mlir_gen.cpp                   +2   -2
src/jit/include/megbrain/jit/mlir/ir/utils.h     +0   -1
src/jit/test/codegen.cpp                         +41  -2
src/jit/test/fusion.cpp                          +2   -2
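In outline, the change works as follows: the fusion pass no longer rejects MLIR-fused elementwise operators whose inputs need broadcasting; the affine (CPU) lowering gains get_affinemap and get_affine_load_op, which build an AffineMap that maps every size-1 source dimension to the constant index 0; the GPU lowering gains output_layout and get_multidim_tid, which decompose the flat thread id into a multi-dimensional index over the output shape and clamp broadcast dimensions to 0; kernel arguments and MLIR function signatures are now built from arg.from->layout() instead of arg.layout; and a run_mlir_broadcast test is added to the codegen tests. A standalone sketch of the indexing scheme appears after the lower_to_gpu_pass.cpp diff below.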
src/jit/impl/fusion_pass.cpp

@@ -294,22 +294,6 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) {
          cond_nr_inp = ig_gen->get_cnt_input_if_add(opr) <= max_nr_input,
          cond_mlir_specific = true;
-
-#if MGB_JIT_MLIR
-    //! FIXME mlir does't support broadcast currently.
-    auto backend = MGB_GETENV("MGB_JIT_BACKEND");
-    if (backend && !strcmp(backend, "MLIR")) {
-        for (VarNode* var : opr->input()) {
-            if (!SymbolVar{var}.as_immutable_scalar().valid()) {
-                if (opr->node_prop().dep_map().at(var) & DepType::DEV_VALUE) {
-                    if (!var->shape().eq_shape(opr->output(0)->shape())) {
-                        cond_mlir_specific = false;
-                    }
-                }
-            }
-        }
-    }
-#endif
     if (cond_readers && cond_cn && cond_shp && cond_nr_inp &&
         cond_mlir_specific) {
         ig_gen->add_opr(opr);
src/jit/impl/mlir/executable_cuda.cpp

@@ -57,23 +57,23 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
         }
     };
     for (const auto& arg : args.inputs) {
-        set_params(arg.from->dev_tensor().raw_ptr(), arg.layout);
+        set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
     }
     int64_t nr_elements = 0;
     for (const auto& arg : args.outputs) {
         if (nr_elements == 0) {
-            nr_elements = arg.layout.total_nr_elems();
+            nr_elements = arg.from->layout().total_nr_elems();
         } else {
             mgb_assert(static_cast<size_t>(nr_elements) ==
-                               arg.layout.total_nr_elems(),
+                               arg.from->layout().total_nr_elems(),
                        "The number of elements of outputs mismatch, expected: "
                        "%zu got: %zu(%s)",
                        static_cast<size_t>(nr_elements),
-                       arg.layout.total_nr_elems(),
-                       arg.layout.to_string().c_str());
+                       arg.from->layout().total_nr_elems(),
+                       arg.from->layout().to_string().c_str());
         }
-        set_params(arg.from->dev_tensor().raw_ptr(), arg.layout);
+        set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
     }
     const CompNodeEnv& env = CompNodeEnv::from_comp_node(fusion_opr->comp_node());

@@ -134,8 +134,8 @@ void MLIRCUDAExecutable::FuncCache::exec(const JITExecutor* fusion_opr,
     mgb_assert(fusion_opr->args().outputs.size() == 1,
               "Currently only support 1 outputs, got %zu",
               fusion_opr->args().outputs.size());
-    int out_dim = fusion_opr->args().outputs[0].layout.ndim;
-    DType dtype = fusion_opr->args().outputs[0].layout.dtype;
+    int out_dim = fusion_opr->args().outputs[0].from->layout().ndim;
+    DType dtype = fusion_opr->args().outputs[0].from->layout().dtype;
 #define cb_outdim(_ndim, _dtype)                                        \
     if (_ndim == out_dim) {                                             \
         setup_and_launch<_ndim, _dtype>(fusion_opr, func->func,         \
src/jit/impl/mlir/ir/common.cpp

@@ -14,8 +14,10 @@
 #if MGB_JIT && MGB_JIT_MLIR
 #include "./common.h"
+#include "megbrain/jit/mlir/ir/utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include <mlir/Dialect/Affine/IR/AffineOps.h>
 using namespace mgb;
 using namespace jit;

@@ -28,9 +30,11 @@ cb(add, AddFOp);
 cb(sub, SubFOp);
 cb(mul, MulFOp);
 cb(div, DivFOp);
+cb(divI, SignedDivIOp);
 cb(mod, RemFOp);
 cb(bit_and, AndOp);
 cb(bit_or, OrOp);
+cb(modI, SignedRemIOp);
 #undef cb
 #define cb(name, mode)                                                \

@@ -62,6 +66,11 @@ mlir::Value ValueBuilderHelper::const_val(float val) {
                                              m_builder.getF32FloatAttr(val));
 }
+mlir::Value ValueBuilderHelper::constI(int32_t val) {
+    return m_builder.create<mlir::ConstantOp>(m_location,
+                                              m_builder.getIndexAttr(val));
+}
 #define cb(name, op)                                                       \
     mlir::Value ValueBuilderHelper::name(mlir::Value lhs) {                \
         return m_builder.create<mlir::op>(m_location, lhs);                \

@@ -97,6 +106,44 @@ mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val,
                                          false_val);
 }
+mlir::AffineMap jit::get_affinemap(mlir::OpBuilder& builder,
+                                   const mlir::Value& val,
+                                   const megdnn::TensorLayout& layout) {
+    auto type = val.getType().cast<mlir::MemRefType>();
+    mgb_assert(type, "currently only support MemRefType");
+    std::vector<mlir::AffineExpr> exprs;
+    for (int i = 0; i < type.getRank(); ++i) {
+        if (layout[i] == 1) {
+            exprs.push_back(builder.getAffineConstantExpr(0));
+        } else {
+            exprs.push_back(builder.getAffineDimExpr(i));
+        }
+    }
+    auto map = mlir::AffineMap::get(type.getRank(), 0, exprs,
+                                    builder.getContext());
+    return map;
+}
+
+mlir::Value jit::get_affine_load_op(mlir::OpBuilder& builder,
+                                    const mlir::Location& loc,
+                                    const mlir::Value& val,
+                                    const mlir::ValueRange& index,
+                                    const megdnn::TensorLayout& dst) {
+    if (val.getType().isa<mlir::MemRefType>()) {
+        auto type = val.getType().cast<mlir::MemRefType>();
+        megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
+        src_layout.init_contiguous_stride();
+        if (src_layout.eq_shape(dst)) {
+            return builder.create<mlir::AffineLoadOp>(loc, val, index);
+        } else {
+            auto lhs_map = get_affinemap(builder, val, src_layout);
+            return builder.create<mlir::AffineLoadOp>(loc, val, lhs_map, index);
+        }
+    } else {
+        return val;
+    }
+}
+
 #endif  // MGB_JIT && MGB_JIT_MLIR
 // vim: syntax=cpp.doxygen
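To make the effect of get_affinemap concrete (an illustrative example, not taken from the commit): when a source memref of shape {1, 20, 1, 1} is loaded inside a loop nest that iterates over a {10, 20, 5, 6} destination, the expressions built above are the constant 0 for dimensions 0, 2 and 3 and the dimension expression d1 for dimension 1, so the resulting map is (d0, d1, d2, d3) -> (0, d1, 0, 0). An AffineLoadOp created with this map therefore reads element (0, j, 0, 0) for every output coordinate (i, j, k, l), which is exactly the broadcast semantics of the elementwise operator.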
src/jit/impl/mlir/ir/common.h

@@ -14,7 +14,7 @@
 #include "megbrain_build_config.h"
 #if MGB_JIT && MGB_JIT_MLIR
-
+#include "megbrain/tensor.h"
 #include <mlir/Dialect/StandardOps/IR/Ops.h>
 #include <mlir/IR/OperationSupport.h>
 #include <mlir/IR/Value.h>

@@ -39,9 +39,11 @@ public:
     cb(sub);
     cb(mul);
     cb(div);
+    cb(divI);
     cb(max);
     cb(min);
     cb(mod);
+    cb(modI);
     cb(gt);
     cb(ge);
     cb(lt);

@@ -51,6 +53,7 @@ public:
     cb(bit_or);
 #undef cb
     mlir::Value const_val(float val);
+    mlir::Value constI(int32_t val);
 #define cb(name)                                                              \
     mlir::Value name(mlir::ValueRange operands) { return name(operands[0]); } \

@@ -89,6 +92,15 @@ mlir::Value get_operand(mlir::OpBuilder& builder, const mlir::Location& loc,
     }
 }
+
+mlir::AffineMap get_affinemap(mlir::OpBuilder& builder, const mlir::Value& val,
+                              const TensorLayout& layout);
+
+mlir::Value get_affine_load_op(mlir::OpBuilder& builder,
+                               const mlir::Location& loc,
+                               const mlir::Value& val,
+                               const mlir::ValueRange& index,
+                               const TensorLayout& dst);
+
 }  // namespace jit
 }  // namespace mgb
src/jit/impl/mlir/ir/lower_to_affine_pass.cpp

@@ -42,8 +42,8 @@ void lower_op_to_loops(Operation* op, ValueRange operands,
     auto alloc = jit::insert_alloc_and_dealloc(memref_type, loc, rewriter);
-    SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
-    SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
+    llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
+    llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
     buildAffineLoopNest(
             rewriter, loc, lower_bounds, memref_type.getShape(), steps,
             [&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {

@@ -96,17 +96,23 @@ struct BinaryOpLowering : public ConversionPattern {
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
         auto loc = op->getLoc();
+        auto dst_memref_type = (*op->result_type_begin()).cast<MemRefType>();
+        megdnn::TensorLayout dst_layout = mlir_type_to_layout(dst_memref_type);
+        dst_layout.init_contiguous_stride();
         lower_op_to_loops(
                 op, operands, rewriter,
-                [loc](OpBuilder& builder, ValueRange memref_operands,
-                      ValueRange loop_ivs) {
+                [dst_layout, loc, this](OpBuilder& builder,
+                                        ValueRange memref_operands,
+                                        ValueRange loop_ivs) {
                     typename Op::Adaptor binary_adaptor(memref_operands);
                     LoweredOp lower_op;
-                    auto loaded_lhs = get_operand<AffineLoadOp>(
-                            builder, loc, binary_adaptor.lhs(), loop_ivs);
-                    auto loaded_rhs = get_operand<AffineLoadOp>(
-                            builder, loc, binary_adaptor.rhs(), loop_ivs);
+                    auto loaded_lhs = get_affine_load_op(
+                            builder, loc, binary_adaptor.lhs(), loop_ivs,
+                            dst_layout);
+                    auto loaded_rhs = get_affine_load_op(
+                            builder, loc, binary_adaptor.rhs(), loop_ivs,
+                            dst_layout);
                     return lower_op(builder, loc, {loaded_lhs, loaded_rhs});
                 });

@@ -128,19 +134,26 @@ struct TernaryOpLowering : public ConversionPattern {
             Operation* op, ArrayRef<Value> operands,
             ConversionPatternRewriter& rewriter) const final {
         auto loc = op->getLoc();
+        auto dst_memref_type = (*op->result_type_begin()).cast<MemRefType>();
+        megdnn::TensorLayout dst_layout = mlir_type_to_layout(dst_memref_type);
+        dst_layout.init_contiguous_stride();
         lower_op_to_loops(
                 op, operands, rewriter,
-                [loc](OpBuilder& builder, ValueRange memref_operands,
-                      ValueRange loop_ivs) {
+                [dst_layout, loc](OpBuilder& builder,
+                                  ValueRange memref_operands,
+                                  ValueRange loop_ivs) {
                     typename Op::Adaptor ternary_adaptor(memref_operands);
                     LoweredOp lower_op;
-                    auto loaded_x = get_operand<AffineLoadOp>(
-                            builder, loc, ternary_adaptor.x(), loop_ivs);
-                    auto loaded_y = get_operand<AffineLoadOp>(
-                            builder, loc, ternary_adaptor.y(), loop_ivs);
-                    auto loaded_z = get_operand<AffineLoadOp>(
-                            builder, loc, ternary_adaptor.z(), loop_ivs);
+                    auto loaded_x = get_affine_load_op(
+                            builder, loc, ternary_adaptor.x(), loop_ivs,
+                            dst_layout);
+                    auto loaded_y = get_affine_load_op(
+                            builder, loc, ternary_adaptor.y(), loop_ivs,
+                            dst_layout);
+                    auto loaded_z = get_affine_load_op(
+                            builder, loc, ternary_adaptor.z(), loop_ivs,
+                            dst_layout);
                     return lower_op(builder, loc, {loaded_x, loaded_y, loaded_z});

@@ -166,8 +179,8 @@ struct AssignOpLowering : public ConversionPattern {
         auto memref_type = operands[0].getType().cast<MemRefType>();
         AssignOpAdaptor assign_adaptor(operands);
-        SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
-        SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
+        llvm::SmallVector<int64_t, 4> lower_bounds(memref_type.getRank(), 0);
+        llvm::SmallVector<int64_t, 4> steps(memref_type.getRank(), 1);
         buildAffineLoopNest(
                 rewriter, loc, lower_bounds, memref_type.getShape(), steps,
                 [&](OpBuilder& nested_builder, Location loc, ValueRange ivs) {
src/jit/impl/mlir/ir/lower_to_gpu_pass.cpp

@@ -52,6 +52,54 @@ mlir::Value get_tid(ConversionPatternRewriter& rewriter, const Location& loc) {
     return index;
 }
+
+megdnn::TensorLayout output_layout(gpu::LaunchOp& launch_op) {
+    auto func_op = launch_op.getParentOfType<mlir::FuncOp>();
+    mgb_assert(func_op, "Unexpexted launch op.");
+    for (auto block_iter = func_op.rbegin(); block_iter != func_op.rend();
+         block_iter++) {
+        for (auto op_iter = block_iter->rbegin(); op_iter != block_iter->rend();
+             op_iter++) {
+            auto op = llvm::dyn_cast_or_null<AssignOp>(&(*op_iter));
+            if (op && op.getNumOperands() > 0) {
+                return mlir_type_to_layout(*(op.operand_type_begin()));
+            }
+        }
+    }
+    mgb_throw(MegBrainError, "Unexpexted launch op.");
+}
+
+std::vector<mlir::Value> get_multidim_tid(ConversionPatternRewriter& rewriter,
+                                          const Location& loc,
+                                          const mlir::Value& val,
+                                          const megdnn::TensorLayout& dst) {
+    Value index = get_tid(rewriter, loc);
+
+    auto type = val.getType().dyn_cast_or_null<mlir::MemRefType>();
+    if (type) {
+        ValueBuilderHelper helper(rewriter, loc);
+        std::vector<mlir::Value> idxs;
+        idxs.resize(dst.ndim);
+        mlir::Value dim_index = index;
+        for (int i = dst.ndim - 1; i >= 0; i--) {
+            auto cur_index = helper.modI(dim_index, helper.constI(dst[i]));
+            idxs[i] = cur_index;
+            dim_index = helper.divI(dim_index, helper.constI(dst[i]));
+        }
+
+        megdnn::TensorLayout src_layout = mlir_type_to_layout(type);
+        src_layout.init_contiguous_stride();
+        for (int i = 0; i < type.getRank(); ++i) {
+            if (src_layout[i] == 1) {
+                idxs[i] = helper.constI(0);
+            }
+        }
+        return idxs;
+    } else {
+        return {index};
+    }
+}
+
 template <typename Op, typename LoweredOp>
 struct UnaryOpLowering : public ConversionPattern {
     UnaryOpLowering(MLIRContext* ctx, gpu::LaunchOp& launch_op)

@@ -66,7 +114,9 @@ struct UnaryOpLowering : public ConversionPattern {
         typename Op::Adaptor binary_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
-        auto index = get_tid(rewriter, loc);
+        auto dst_layout = output_layout(m_launch_op);
+        auto index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(),
+                                      dst_layout);
         auto loaded_lhs = get_operand<LoadOp>(rewriter, loc,
                                               binary_adaptor.lhs(), index);

@@ -99,11 +149,15 @@ struct BinaryOpLowering : public ConversionPattern {
         typename Op::Adaptor binary_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
-        auto index = get_tid(rewriter, loc);
-        auto loaded_lhs = get_operand<LoadOp>(rewriter, loc,
-                                              binary_adaptor.lhs(), index);
-        auto loaded_rhs = get_operand<LoadOp>(rewriter, loc,
-                                              binary_adaptor.rhs(), index);
+        auto dst_layout = output_layout(m_launch_op);
+        auto lhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.lhs(),
+                                          dst_layout);
+        auto rhs_index = get_multidim_tid(rewriter, loc, binary_adaptor.rhs(),
+                                          dst_layout);
+        auto loaded_lhs = get_operand<LoadOp>(rewriter, loc,
+                                              binary_adaptor.lhs(), lhs_index);
+        auto loaded_rhs = get_operand<LoadOp>(rewriter, loc,
+                                              binary_adaptor.rhs(), rhs_index);
         LoweredOp lower_op;

@@ -135,13 +189,19 @@ struct TernaryOpLowering : public ConversionPattern {
         typename Op::Adaptor ternary_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
-        auto index = get_tid(rewriter, loc);
-        auto loaded_x = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(),
-                                            index);
-        auto loaded_y = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(),
-                                            index);
-        auto loaded_z = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(),
-                                            index);
+        auto dst_layout = output_layout(m_launch_op);
+        auto index_x = get_multidim_tid(rewriter, loc, ternary_adaptor.x(),
+                                        dst_layout);
+        auto index_y = get_multidim_tid(rewriter, loc, ternary_adaptor.y(),
+                                        dst_layout);
+        auto index_z = get_multidim_tid(rewriter, loc, ternary_adaptor.z(),
+                                        dst_layout);
+        auto loaded_x = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.x(),
+                                            index_x);
+        auto loaded_y = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.y(),
+                                            index_y);
+        auto loaded_z = get_operand<LoadOp>(rewriter, loc, ternary_adaptor.z(),
+                                            index_z);
         LoweredOp lower_op;

@@ -242,7 +302,9 @@ struct AssignOpLowering : public ConversionPattern {
         AssignOpAdaptor assign_adaptor(operands);
         rewriter.setInsertionPointToEnd(&(m_launch_op.body().front()));
-        auto index = get_tid(rewriter, loc);
+        auto dst_layout = output_layout(m_launch_op);
+        auto index = get_multidim_tid(rewriter, loc, assign_adaptor.rhs(),
+                                      dst_layout);
         auto loaded_lhs = get_operand<LoadOp>(rewriter, loc,
                                               assign_adaptor.lhs(), index);
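The GPU path above recovers a per-input index from the flat thread id by repeated modulo and division over the output shape, then forces the coordinate of every broadcast (size-1) input dimension to 0. The following standalone C++ sketch restates that arithmetic for a single element; it is illustrative only and not part of the commit, and the function name broadcast_index and the std::vector-based shapes are assumptions made for the example.

#include <cstdio>
#include <vector>

// Decompose a flat output index into per-dimension coordinates (row-major,
// innermost dimension last), then clamp coordinates of broadcast (size-1)
// input dimensions to 0 -- the same idea as get_multidim_tid in the diff above.
std::vector<long> broadcast_index(long tid, const std::vector<long>& out_shape,
                                  const std::vector<long>& in_shape) {
    std::vector<long> idx(out_shape.size());
    for (int i = static_cast<int>(out_shape.size()) - 1; i >= 0; --i) {
        idx[i] = tid % out_shape[i];
        tid /= out_shape[i];
    }
    for (size_t i = 0; i < in_shape.size(); ++i) {
        if (in_shape[i] == 1) {
            idx[i] = 0;  // broadcast dimension: always read element 0
        }
    }
    return idx;
}

int main() {
    // Output shape {10, 20, 5, 6} and input shape {1, 20, 1, 1}, as in the
    // new run_mlir_broadcast test.
    std::vector<long> out_shape{10, 20, 5, 6}, in_shape{1, 20, 1, 1};
    long tid = 1234;  // some flat thread id within the 6000-element output
    auto idx = broadcast_index(tid, out_shape, in_shape);
    std::printf("input index for tid %ld: (%ld, %ld, %ld, %ld)\n", tid, idx[0],
                idx[1], idx[2], idx[3]);
    return 0;
}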
src/jit/impl/mlir/ir/utils.cpp

@@ -98,7 +98,6 @@ mlir::MemRefType jit::layout_to_mlir_type(const megdnn::TensorLayout& layout,
     for (size_t i = 0; i < layout.ndim; i++) {
         shape.push_back(layout[i]);
     }
-
     switch (layout.dtype.enumv()) {
         case megdnn::DTypeEnum::Float32:
             return mlir::MemRefType::get(shape, builder.getF32Type());
src/jit/impl/mlir/mlir_gen.cpp

@@ -73,10 +73,10 @@ private:
                          m_symbol_table);
     std::vector<mlir::Type> func_args;
     for (auto&& arg : args.inputs) {
-        func_args.push_back(get_type(arg.layout));
+        func_args.push_back(get_type(arg.from->layout()));
     }
     for (auto&& arg : args.outputs) {
-        func_args.push_back(get_type(arg.layout));
+        func_args.push_back(get_type(arg.from->layout()));
     }
     //! the last arg is nr_elements
     func_args.push_back(m_builder.getIndexType());
src/jit/include/megbrain/jit/mlir/ir/utils.h

@@ -44,7 +44,6 @@ megdnn::TensorLayout mlir_type_to_layout(mlir::Type type);
 megdnn::DType mlir_type_to_dtype(mlir::Type type);
 mlir::MemRefType layout_to_mlir_type(const megdnn::TensorLayout& layout,
                                      mlir::Builder& builder);
-
 }  // namespace jit
 }  // namespace mgb
src/jit/test/codegen.cpp

@@ -130,8 +130,8 @@ void run_mlir(CompNode cn) {
     auto graph = ComputingGraph::make();
     HostTensorGenerator<dtype::Float32> gen;
-    auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 42}, cn),
-         host_x2 = gen({23, 42}, cn), host_x3 = gen({23, 42}, cn);
+    auto host_x0 = gen({23, 42}, cn), host_x1 = gen({23, 1}, cn),
+         host_x2 = gen({23, 42}, cn);
     auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
          b = opr::Host2DeviceCopy::make(*graph, host_x1),

@@ -159,6 +159,43 @@ void run_mlir(CompNode cn) {
     MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
 }
+
+void run_mlir_broadcast(CompNode cn) {
+    set_backend(Backend::MLIR);
+    auto graph = ComputingGraph::make();
+    HostTensorGenerator<dtype::Float32> gen;
+
+    auto host_x0 = gen({10, 20, 5, 6}, cn), host_x1 = gen({1, 20, 1, 1}, cn),
+         host_x2 = gen({10, 1, 5, 1}, cn), host_x3 = gen({10, 1, 1, 1}, cn);
+
+    auto a = opr::Host2DeviceCopy::make(*graph, host_x0),
+         b = opr::Host2DeviceCopy::make(*graph, host_x1),
+         c = opr::Host2DeviceCopy::make(*graph, host_x2),
+         d = opr::Host2DeviceCopy::make(*graph, host_x3);
+
+    auto y = opr::Elemwise::make({a, b, c},
+                                 opr::Elemwise::Mode::FUSE_MUL_ADD3) +
+             opr::Elemwise::make({d}, opr::Elemwise::Mode::ABS) - 0.3f;
+
+    auto ig_gen =
+            std::make_unique<InternalGraphGenerator>(y.node()->owner_opr());
+
+    for (auto i : get_rev_topo_order(y)) {
+        if (!i->same_type<opr::Host2DeviceCopy>()) {
+            ig_gen->add_opr(i);
+        }
+    }
+
+    auto igraph = ig_gen->generate();
+    auto y_jit = JITExecutor::make(igraph, ig_gen->orig_inps());
+
+    HostTensorND host_y, host_y_jit;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_jit, host_y_jit)});
+    func->execute();
+
+    MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit);
+}
+
 struct MlirTestOpt {
     float low;
     float high;

@@ -252,12 +289,14 @@ TYPED_TEST(TestJITNvrtcCodeGen, run) {
 TEST(TestJITMlirCodeGen, Basic) {
     auto cn = CompNode::load("cpu0");
     run_mlir(cn);
+    run_mlir_broadcast(cn);
 }

 TEST(TestJITMlirCodeGen, BasicGPU) {
     REQUIRE_GPU(1);
     auto cn = CompNode::load("gpu0");
     run_mlir(cn);
+    run_mlir_broadcast(cn);
 }

 ///////////////////////// unary ///////////////////////////////
src/jit/test/fusion.cpp

@@ -1580,8 +1580,8 @@ void run_mlir(CompNode cn) {
     JITExecutor* jit;
     unpack_vector(find_oprs<JITExecutor>(*funcs.second), jit);
-    ASSERT_EQ(2u, find_oprs<opr::Elemwise>(*funcs.second).size());
-    ASSERT_EQ(3u, jit->input().size());
+    ASSERT_EQ(0u, find_oprs<opr::Elemwise>(*funcs.second).size());
+    ASSERT_EQ(5u, jit->input().size());
 }

 TEST(TestJITExecutor, TestJITMlirFusion) {