Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
409c9881
MegEngine
项目概览
MegEngine 天元
/
MegEngine
10 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
409c9881
编写于
4月 14, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(imperative): add matmul apply_on_varnode
GitOrigin-RevId: 2cf6bf237cb573f0c78fcb5cacc0257f99ebcecb
上级
d52ba79d
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
330 addition
and
357 deletion
+330
-357
dnn/include/megdnn/oprs/linalg.h
dnn/include/megdnn/oprs/linalg.h
+2
-2
imperative/python/megengine/core/tensor/array_method.py
imperative/python/megengine/core/tensor/array_method.py
+0
-252
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+3
-3
imperative/python/src/tensor_utils.cpp
imperative/python/src/tensor_utils.cpp
+48
-60
imperative/python/test/integration/test_trace_dump.py
imperative/python/test/integration/test_trace_dump.py
+0
-1
imperative/src/impl/ops/matmul.cpp
imperative/src/impl/ops/matmul.cpp
+223
-30
imperative/tablegen/helper.h
imperative/tablegen/helper.h
+6
-1
imperative/tablegen/targets/cpp_class.cpp
imperative/tablegen/targets/cpp_class.cpp
+36
-6
src/core/include/megbrain/ir/ops.td
src/core/include/megbrain/ir/ops.td
+12
-2
未找到文件。
dnn/include/megdnn/oprs/linalg.h
浏览文件 @
409c9881
...
...
@@ -36,7 +36,7 @@ public:
virtual
void
exec
(
_megdnn_tensor_in
A
,
_megdnn_tensor_in
B
,
_megdnn_tensor_out
C
,
_megdnn_workspace
workspace
)
=
0
;
void
deduce_dtype
(
DType
A
,
DType
B
,
DType
&
C
);
MGE_WIN_DECLSPEC_FUC
void
deduce_dtype
(
DType
A
,
DType
B
,
DType
&
C
);
void
deduce_layout
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
TensorLayout
&
C
);
virtual
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
=
0
;
...
...
@@ -73,7 +73,7 @@ public:
virtual
void
exec
(
_megdnn_tensor_in
A
,
_megdnn_tensor_in
B
,
_megdnn_tensor_out
C
,
_megdnn_workspace
workspace
)
=
0
;
void
deduce_dtype
(
DType
A
,
DType
B
,
DType
&
C
);
MGE_WIN_DECLSPEC_FUC
void
deduce_dtype
(
DType
A
,
DType
B
,
DType
&
C
);
void
deduce_layout
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
TensorLayout
&
C
);
virtual
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
=
0
;
...
...
imperative/python/megengine/core/tensor/array_method.py
浏览文件 @
409c9881
...
...
@@ -44,216 +44,6 @@ def _elwise(*args, mode):
return
_elwise_apply
(
args
,
mode
)
@
lru_cache
(
maxsize
=
None
)
def
_get_extentedMatrixMulOp
(
device
,
dtype
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
,
):
@
subgraph
(
"extentedMatrixMulOp"
,
dtype
,
device
,
2
,
gopt_level
=
2
)
def
extentedMatrixMulOp
(
inputs
,
f
,
c
):
assert
len
(
inputs
)
==
2
inp1
,
inp2
=
inputs
_dim1
,
_dim2
=
dim1
,
dim2
def
build_shape_head
(
shape
,
idx
=-
1
):
# shape[:idx]
return
f
(
builtin
.
Subtensor
(
items
=
[[
0
,
False
,
True
,
False
,
False
]]),
shape
,
c
(
idx
,
"int32"
),
)
def
build_shape_tail
(
shape
,
idx
=-
1
):
# shape[idx:]
return
f
(
builtin
.
Subtensor
(
items
=
[[
0
,
True
,
False
,
False
,
False
]]),
shape
,
c
(
idx
,
"int32"
),
)
remove_row
,
remove_col
=
False
,
False
if
_dim1
==
1
:
_dim1
=
2
remove_row
=
True
if
_dim2
==
1
:
_dim2
=
2
remove_col
=
True
if
remove_row
:
inp1
=
f
(
builtin
.
AddAxis
(
axis
=
[
0
,]),
inp1
)
if
remove_col
:
inp2
=
f
(
builtin
.
AddAxis
(
axis
=
[
1
,]),
inp2
)
shape1
=
f
(
builtin
.
GetVarShape
(),
inp1
)
shape2
=
f
(
builtin
.
GetVarShape
(),
inp2
)
if
_dim1
>
2
:
inp1
=
f
(
builtin
.
Reshape
(),
inp1
,
f
(
builtin
.
Concat
(
axis
=
0
,
comp_node
=
device
),
f
(
builtin
.
Reduce
(
mode
=
"product"
,
axis
=
0
),
build_shape_head
(
shape1
)),
build_shape_tail
(
shape1
),
),
)
if
_dim2
>
2
:
inp2
=
f
(
builtin
.
Reshape
(),
inp2
,
f
(
builtin
.
Concat
(
axis
=
0
,
comp_node
=
device
),
f
(
builtin
.
Reduce
(
mode
=
"product"
,
axis
=
0
),
build_shape_head
(
shape2
)),
build_shape_tail
(
shape2
),
),
)
op
=
builtin
.
MatrixMul
(
transposeA
=
transpose_a
,
transposeB
=
transpose_b
,
compute_mode
=
compute_mode
,
format
=
format
,
strategy
=
strategy
.
value
,
)
result
=
f
(
op
,
inp1
,
inp2
)
result_shape
=
f
(
builtin
.
GetVarShape
(),
result
)
if
_dim1
>
2
:
result
=
f
(
builtin
.
Reshape
(),
result
,
f
(
builtin
.
Concat
(
axis
=
0
,
comp_node
=
device
),
build_shape_head
(
shape1
),
build_shape_tail
(
result_shape
),
),
)
if
_dim2
>
2
:
result
=
f
(
builtin
.
Reshape
(),
result
,
f
(
builtin
.
Concat
(
axis
=
0
,
comp_node
=
device
),
build_shape_head
(
shape2
),
build_shape_tail
(
result_shape
),
),
)
maxdim
=
_dim1
if
_dim1
>
_dim2
else
_dim2
if
remove_row
:
result
=
f
(
builtin
.
RemoveAxis
(
axis
=
[
maxdim
-
2
]),
result
)
if
remove_col
:
result
=
f
(
builtin
.
RemoveAxis
(
axis
=
[
maxdim
-
1
]),
result
)
return
(
result
,),
(
True
,)
return
extentedMatrixMulOp
@
lru_cache
(
maxsize
=
None
)
def
_get_extentedBatchedMatrixMulOp
(
device
,
dtype
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
,
):
@
subgraph
(
"extentedBatchedMatrixMulOp"
,
dtype
,
device
,
2
,
gopt_level
=
2
)
def
extentedBatchedMatrixMulOp
(
inputs
,
f
,
c
):
assert
len
(
inputs
)
==
2
inp1
,
inp2
=
inputs
_dim1
,
_dim2
=
dim1
,
dim2
def
build_shape_head
(
shape
,
idx
=-
2
):
# shape[:idx]
return
f
(
builtin
.
Subtensor
(
items
=
[[
0
,
False
,
True
,
False
,
False
]]),
shape
,
c
(
idx
,
"int32"
),
)
def
build_shape_tail
(
shape
,
idx
=-
2
):
# shape[idx:]
return
f
(
builtin
.
Subtensor
(
items
=
[[
0
,
True
,
False
,
False
,
False
]]),
shape
,
c
(
idx
,
"int32"
),
)
remove_row
,
remove_col
=
False
,
False
if
_dim1
==
1
:
_dim1
=
2
remove_row
=
True
if
_dim2
==
1
:
_dim2
=
2
remove_col
=
True
if
remove_row
:
inp1
=
f
(
builtin
.
AddAxis
(
axis
=
[
0
,]),
inp1
)
if
remove_col
:
inp2
=
f
(
builtin
.
AddAxis
(
axis
=
[
1
,]),
inp2
)
shape1
=
f
(
builtin
.
GetVarShape
(),
inp1
)
shape2
=
f
(
builtin
.
GetVarShape
(),
inp2
)
maxdim
=
_dim1
if
_dim1
>
_dim2
else
_dim2
if
_dim1
>
_dim2
:
# broadcast
shape2
=
f
(
builtin
.
Concat
(
axis
=
0
,
comp_node
=
device
),
build_shape_head
(
shape1
,
idx
=-
_dim2
),
# shape1[:-_dim2]
shape2
,
)
inp2
=
f
(
builtin
.
Broadcast
(),
inp2
,
shape2
)
batch_shape
=
build_shape_head
(
shape1
)
if
_dim2
>
_dim1
:
# broadcast
shape1
=
f
(
builtin
.
Concat
(
axis
=
0
,
comp_node
=
device
),
build_shape_head
(
shape2
,
idx
=-
_dim1
),
# shape2[:-_dim1]
shape1
,
)
inp1
=
f
(
builtin
.
Broadcast
(),
inp1
,
shape1
)
batch_shape
=
build_shape_head
(
shape2
)
if
_dim1
==
_dim2
:
batch_shape
=
build_shape_head
(
shape1
)
# compress inputs to 3d
if
maxdim
>
3
:
inp1
=
f
(
builtin
.
Reshape
(),
inp1
,
f
(
builtin
.
Concat
(
axis
=
0
,
comp_node
=
device
),
f
(
builtin
.
Reduce
(
mode
=
"product"
,
axis
=
0
),
batch_shape
),
build_shape_tail
(
shape1
),
),
)
inp2
=
f
(
builtin
.
Reshape
(),
inp2
,
f
(
builtin
.
Concat
(
axis
=
0
,
comp_node
=
device
),
f
(
builtin
.
Reduce
(
mode
=
"product"
,
axis
=
0
),
batch_shape
),
build_shape_tail
(
shape2
),
),
)
op
=
builtin
.
BatchedMatrixMul
(
transposeA
=
transpose_a
,
transposeB
=
transpose_b
,
compute_mode
=
compute_mode
,
format
=
format
,
strategy
=
strategy
.
value
,
)
result
=
f
(
op
,
inp1
,
inp2
)
if
maxdim
>
3
:
result
=
f
(
builtin
.
Reshape
(),
result
,
f
(
builtin
.
Concat
(
axis
=
0
,
comp_node
=
device
),
batch_shape
,
build_shape_tail
(
f
(
builtin
.
GetVarShape
(),
result
)),
),
)
if
remove_row
:
result
=
f
(
builtin
.
RemoveAxis
(
axis
=
[
maxdim
-
2
]),
result
)
if
remove_col
:
result
=
f
(
builtin
.
RemoveAxis
(
axis
=
[
maxdim
-
1
]),
result
)
return
(
result
,),
(
True
,)
return
extentedBatchedMatrixMulOp
class
_Hashable
:
def
__init__
(
self
,
value
)
->
None
:
self
.
value
=
value
...
...
@@ -267,42 +57,6 @@ class _Hashable:
return
self
.
value
==
o
.
value
def
symbolicMatrixMul
(
inp1
,
inp2
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
):
extentedMatrixMulOp
=
_get_extentedMatrixMulOp
(
inp1
.
device
,
inp1
.
dtype
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
=
_Hashable
(
strategy
),
)
(
result
,)
=
apply
(
extentedMatrixMulOp
(),
inp1
,
inp2
)
return
result
def
symbolicBatchedMatrixMul
(
inp1
,
inp2
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
):
extentedBatchedMatrixMulOp
=
_get_extentedBatchedMatrixMulOp
(
inp1
.
device
,
inp1
.
dtype
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
=
_Hashable
(
strategy
),
)
(
result
,)
=
apply
(
extentedBatchedMatrixMulOp
(),
inp1
,
inp2
)
return
result
def
_matmul
(
inp1
,
inp2
,
...
...
@@ -342,11 +96,8 @@ def _matmul(
transpose_a
,
transpose_b
,
compute_mode
,
format
,
_config
.
_benchmark_kernel
,
_config
.
_deterministic_kernel
,
strategy
,
symbolicMatrixMul
,
)
else
:
# dispath to BatchedMatrixMul
# nx1(transpose_a=True), n>=3
...
...
@@ -362,11 +113,8 @@ def _matmul(
transpose_a
,
transpose_b
,
compute_mode
,
format
,
_config
.
_benchmark_kernel
,
_config
.
_deterministic_kernel
,
strategy
,
symbolicBatchedMatrixMul
,
)
...
...
imperative/python/megengine/functional/nn.py
浏览文件 @
409c9881
...
...
@@ -32,7 +32,7 @@ from ..core.ops.builtin import (
TypeCvt
,
)
from
..core.tensor
import
amp
,
megbrain_graph
from
..core.tensor.array_method
import
_
elwise_apply
from
..core.tensor.array_method
import
_
matmul
from
..core.tensor.utils
import
(
astensor1d
,
cast_tensors
,
...
...
@@ -49,7 +49,7 @@ from ..utils.deprecation import deprecated_func
from
.debug_param
import
get_execution_strategy
from
.distributed
import
all_reduce_sum
from
.elemwise
import
_elwise
,
exp
,
log
,
log1p
,
maximum
,
minimum
from
.math
import
ma
tmul
,
ma
x
,
sum
from
.math
import
max
,
sum
from
.tensor
import
broadcast_to
,
concat
,
expand_dims
,
ones
,
squeeze
,
zeros
__all__
=
[
...
...
@@ -127,7 +127,7 @@ def linear(
bias: bias with shape `(out_features,)`. Default: None
"""
compute_mode
=
_config
.
_get_actual_op_param
(
compute_mode
,
_config
.
__compute_mode
)
ret
=
matmul
(
inp
,
weight
,
transpose_b
=
True
,
compute_mode
=
compute_mode
)
ret
=
_
matmul
(
inp
,
weight
,
transpose_b
=
True
,
compute_mode
=
compute_mode
)
if
bias
is
not
None
:
if
amp
.
_enabled
:
bias
=
bias
.
astype
(
"float16"
)
...
...
imperative/python/src/tensor_utils.cpp
浏览文件 @
409c9881
...
...
@@ -1494,73 +1494,61 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
py
::
object
_matmul_cpp
(
py
::
handle
inp1
,
py
::
handle
inp2
,
py
::
handle
dim1
,
py
::
handle
dim2
,
py
::
handle
transpose_a
,
py
::
handle
transpose_b
,
py
::
handle
compute_mode
,
py
::
handle
format
,
py
::
handle
profile
,
py
::
handle
determistic
,
py
::
handle
strategy
,
py
::
handle
func
)
{
if
(
enable_fastpath
(
inp1
))
{
::
megdnn
::
param
::
MatrixMul
::
ComputeMode
mode
=
::
megdnn
::
param
::
MatrixMul
::
ComputeMode
::
DEFAULT
;
if
(
compute_mode
.
cast
<
std
::
string
>
().
compare
(
std
::
string
(
"float32"
))
==
0
)
{
mode
=
::
megdnn
::
param
::
MatrixMul
::
ComputeMode
::
FLOAT32
;
}
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
cstrategy
;
if
(
profile
.
cast
<
bool
>
())
{
cstrategy
|=
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
::
PROFILE
;
}
else
{
cstrategy
|=
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
::
HEURISTIC
;
}
if
(
determistic
.
cast
<
bool
>
())
{
cstrategy
|=
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
::
REPRODUCIBLE
;
}
std
::
shared_ptr
<
OpDef
>
op
=
MatrixMul
::
make
(
transpose_a
.
cast
<
bool
>
(),
transpose_b
.
cast
<
bool
>
(),
mode
,
::
megdnn
::
param
::
MatrixMul
::
Format
::
DEFAULT
,
cstrategy
,
UINT64_MAX
);
py
::
object
Op
=
py
::
cast
(
op
);
PyObject
*
p
[
3
]
=
{
Op
.
ptr
(),
inp1
.
ptr
(),
inp2
.
ptr
()};
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
,
3
));
return
ret
[
0
];
py
::
handle
profile
,
py
::
handle
determistic
)
{
::
megdnn
::
param
::
MatrixMul
::
ComputeMode
mode
=
::
megdnn
::
param
::
MatrixMul
::
ComputeMode
::
DEFAULT
;
if
(
compute_mode
.
cast
<
std
::
string
>
().
compare
(
std
::
string
(
"float32"
))
==
0
)
{
mode
=
::
megdnn
::
param
::
MatrixMul
::
ComputeMode
::
FLOAT32
;
}
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
cstrategy
=
static_cast
<::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
>
(
0
);
if
(
profile
.
cast
<
bool
>
())
{
cstrategy
|=
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
::
PROFILE
;
}
else
{
// fallback to traceable implementation
return
func
(
inp1
,
inp2
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
);
cstrategy
|=
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
::
HEURISTIC
;
}
if
(
determistic
.
cast
<
bool
>
())
{
cstrategy
|=
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
::
REPRODUCIBLE
;
}
std
::
shared_ptr
<
OpDef
>
op
=
MatrixMul
::
make
(
transpose_a
.
cast
<
bool
>
(),
transpose_b
.
cast
<
bool
>
(),
mode
,
::
megdnn
::
param
::
MatrixMul
::
Format
::
DEFAULT
,
cstrategy
,
UINT64_MAX
,
dim1
.
cast
<
uint32_t
>
(),
dim2
.
cast
<
uint32_t
>
());
py
::
object
Op
=
py
::
cast
(
op
);
PyObject
*
p
[
3
]
=
{
Op
.
ptr
(),
inp1
.
ptr
(),
inp2
.
ptr
()};
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
,
3
));
return
ret
[
0
];
}
py
::
object
_batched_matmul_cpp
(
py
::
handle
inp1
,
py
::
handle
inp2
,
py
::
handle
dim1
,
py
::
handle
dim2
,
py
::
handle
transpose_a
,
py
::
handle
transpose_b
,
py
::
handle
compute_mode
,
py
::
handle
format
,
py
::
handle
profile
,
py
::
handle
determistic
,
py
::
handle
strategy
,
py
::
handle
func
)
{
if
(
enable_fastpath
(
inp1
))
{
::
megdnn
::
param
::
MatrixMul
::
ComputeMode
mode
=
::
megdnn
::
param
::
MatrixMul
::
ComputeMode
::
DEFAULT
;
if
(
compute_mode
.
cast
<
std
::
string
>
().
compare
(
std
::
string
(
"float32"
))
==
0
)
{
mode
=
::
megdnn
::
param
::
MatrixMul
::
ComputeMode
::
FLOAT32
;
}
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
cstrategy
;
if
(
profile
.
cast
<
bool
>
())
{
cstrategy
|=
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
::
PROFILE
;
}
else
{
cstrategy
|=
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
::
HEURISTIC
;
}
if
(
determistic
.
cast
<
bool
>
())
{
cstrategy
|=
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
::
REPRODUCIBLE
;
}
std
::
shared_ptr
<
OpDef
>
op
=
BatchedMatrixMul
::
make
(
transpose_a
.
cast
<
bool
>
(),
transpose_b
.
cast
<
bool
>
(),
mode
,
::
megdnn
::
param
::
MatrixMul
::
Format
::
DEFAULT
,
cstrategy
,
UINT64_MAX
);
py
::
object
Op
=
py
::
cast
(
op
);
PyObject
*
p
[
3
]
=
{
Op
.
ptr
(),
inp1
.
ptr
(),
inp2
.
ptr
()};
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
,
3
));
return
ret
[
0
];
py
::
handle
profile
,
py
::
handle
determistic
)
{
::
megdnn
::
param
::
MatrixMul
::
ComputeMode
mode
=
::
megdnn
::
param
::
MatrixMul
::
ComputeMode
::
DEFAULT
;
if
(
compute_mode
.
cast
<
std
::
string
>
().
compare
(
std
::
string
(
"float32"
))
==
0
)
{
mode
=
::
megdnn
::
param
::
MatrixMul
::
ComputeMode
::
FLOAT32
;
}
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
cstrategy
=
static_cast
<::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
>
(
0
);
if
(
profile
.
cast
<
bool
>
())
{
cstrategy
|=
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
::
PROFILE
;
}
else
{
// fallback to traceable implementation
return
func
(
inp1
,
inp2
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
);
cstrategy
|=
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
::
HEURISTIC
;
}
if
(
determistic
.
cast
<
bool
>
())
{
cstrategy
|=
::
megdnn
::
param
::
ExecutionPolicy
::
Strategy
::
REPRODUCIBLE
;
}
std
::
shared_ptr
<
OpDef
>
op
=
BatchedMatrixMul
::
make
(
transpose_a
.
cast
<
bool
>
(),
transpose_b
.
cast
<
bool
>
(),
mode
,
::
megdnn
::
param
::
MatrixMul
::
Format
::
DEFAULT
,
cstrategy
,
UINT64_MAX
,
dim1
.
cast
<
uint32_t
>
(),
dim2
.
cast
<
uint32_t
>
());
py
::
object
Op
=
py
::
cast
(
op
);
PyObject
*
p
[
3
]
=
{
Op
.
ptr
(),
inp1
.
ptr
(),
inp2
.
ptr
()};
py
::
tuple
ret
=
py
::
reinterpret_steal
<
py
::
object
>
(
py_apply
(
NULL
,
p
,
3
));
return
ret
[
0
];
}
py
::
object
_pixel_shuffle_cpp
(
py
::
handle
inp
,
py
::
handle
val
,
py
::
handle
func
)
{
...
...
@@ -1671,7 +1659,7 @@ PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try
{
return
_matmul_cpp
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
],
args
[
4
],
args
[
5
],
args
[
6
],
args
[
7
],
args
[
8
]
,
args
[
9
],
args
[
10
],
args
[
11
]
)
args
[
7
],
args
[
8
])
.
release
()
.
ptr
();
}
...
...
@@ -1682,7 +1670,7 @@ PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs
try
{
return
_batched_matmul_cpp
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
],
args
[
4
],
args
[
5
],
args
[
6
],
args
[
7
],
args
[
8
]
,
args
[
9
],
args
[
10
],
args
[
11
]
)
args
[
7
],
args
[
8
])
.
release
()
.
ptr
();
}
...
...
imperative/python/test/integration/test_trace_dump.py
浏览文件 @
409c9881
...
...
@@ -20,7 +20,6 @@ import megengine.optimizer as optim
from
megengine
import
tensor
from
megengine.autodiff
import
GradManager
from
megengine.jit
import
trace
from
megengine.traced_module
import
trace_module
@
contextlib
.
contextmanager
...
...
imperative/src/impl/ops/matmul.cpp
浏览文件 @
409c9881
...
...
@@ -2,8 +2,12 @@
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
#include "megbrain/graph/symbol_var.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "../algo_chooser.h"
...
...
@@ -12,12 +16,93 @@ namespace imperative {
namespace
{
namespace
matrix_mul
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
matmul
=
def
.
cast_final_safe
<
MatrixMul
>
();
mgb_assert
(
inputs
.
size
()
==
2
);
OperatorNodeConfig
config
{
matmul
.
make_name
()};
return
opr
::
MatrixMul
::
make
(
inputs
[
0
],
inputs
[
1
],
matmul
.
param
(),
matmul
.
policy
(),
config
);
auto
inp1
=
SymbolVar
{
inputs
[
0
]},
inp2
=
SymbolVar
{
inputs
[
1
]};
auto
dim1
=
matmul
.
dimA
,
dim2
=
matmul
.
dimB
;
auto
cn
=
inputs
[
0
]
->
comp_node
();
using
Desc
=
opr
::
AxisAddRemove
::
AxisDesc
;
using
IndexDesc
=
opr
::
Subtensor
::
IndexDesc
;
OperatorNodeConfig
config
{
matmul
.
make_name
(),
cn
};
DTypeScalar
vi
{
-
1
};
auto
graph
=
inputs
[
0
]
->
owner_graph
();
bool
remove_row
=
false
,
remove_col
=
false
;
if
(
dim1
==
1
)
{
dim1
=
2
;
remove_row
=
true
;
inp1
=
inp1
.
add_axis
(
0
);
}
if
(
dim2
==
1
)
{
dim2
=
2
;
remove_col
=
true
;
inp2
=
inp2
.
add_axis
(
1
);
}
SymbolVar
shp1_head
,
shp1_tail
,
shp2_head
,
shp2_tail
;
if
(
dim1
>
2
)
{
auto
idx
=
opr
::
ImmutableTensor
::
make
(
*
graph
,
vi
,
config
);
auto
shp1
=
inp1
.
symshape
();
IndexDesc
head_desc
(
1
);
head_desc
[
0
].
end
=
idx
;
shp1_head
=
opr
::
Subtensor
::
make
(
shp1
,
head_desc
);
auto
batch
=
opr
::
Reduce
::
make
(
shp1_head
,
{
Reduce
::
Mode
::
PRODUCT
,
0
});
IndexDesc
tail_desc
(
1
);
tail_desc
[
0
].
begin
=
idx
;
shp1_tail
=
opr
::
Subtensor
::
make
(
shp1
,
tail_desc
);
auto
tshp
=
opr
::
Concat
::
make
({
batch
,
shp1_tail
},
0
,
cn
);
inp1
=
inp1
.
reshape
(
tshp
);
}
if
(
dim2
>
2
)
{
auto
idx
=
opr
::
ImmutableTensor
::
make
(
*
graph
,
vi
,
config
);
auto
shp2
=
inp2
.
symshape
();
IndexDesc
head_desc
(
1
);
head_desc
[
0
].
end
=
idx
;
shp2_head
=
opr
::
Subtensor
::
make
(
shp2
,
head_desc
);
auto
batch
=
opr
::
Reduce
::
make
(
shp2_head
,
{
Reduce
::
Mode
::
PRODUCT
,
0
});
IndexDesc
tail_desc
(
1
);
tail_desc
[
0
].
begin
=
idx
;
auto
shp2_tail
=
opr
::
Subtensor
::
make
(
shp2
,
tail_desc
);
auto
tshp
=
opr
::
Concat
::
make
({
batch
,
shp2_tail
},
0
,
cn
);
inp2
=
inp2
.
reshape
(
tshp
);
}
auto
result
=
opr
::
MatrixMul
::
make
(
inp1
,
inp2
,
matmul
.
param
(),
matmul
.
policy
(),
config
);
if
(
dim1
>
2
)
{
auto
idx
=
opr
::
ImmutableTensor
::
make
(
*
graph
,
vi
,
config
);
auto
result_shape
=
result
.
symshape
();
IndexDesc
tail_desc
(
1
);
tail_desc
[
0
].
begin
=
idx
;
auto
shp_tail
=
opr
::
Subtensor
::
make
(
result_shape
,
tail_desc
);
auto
tshp
=
opr
::
Concat
::
make
({
shp1_head
,
shp_tail
},
0
,
cn
);
result
=
result
.
reshape
(
tshp
);
}
if
(
dim2
>
2
)
{
auto
idx
=
opr
::
ImmutableTensor
::
make
(
*
graph
,
vi
,
config
);
auto
result_shape
=
result
.
symshape
();
IndexDesc
tail_desc
(
1
);
tail_desc
[
0
].
begin
=
idx
;
auto
shp_tail
=
opr
::
Subtensor
::
make
(
result_shape
,
tail_desc
);
auto
tshp
=
opr
::
Concat
::
make
({
shp2_head
,
shp_tail
},
0
,
cn
);
result
=
result
.
reshape
(
tshp
);
}
auto
maxdim
=
dim1
>
dim2
?
dim1
:
dim2
;
if
(
remove_row
)
{
std
::
vector
<
Desc
>
remove_param
;
remove_param
.
push_back
(
Desc
::
make_remove
(
maxdim
-
2
));
result
=
opr
::
AxisAddRemove
::
make
(
result
,
remove_param
);
}
if
(
remove_col
)
{
std
::
vector
<
Desc
>
remove_param
;
remove_param
.
push_back
(
Desc
::
make_remove
(
maxdim
-
1
));
result
=
opr
::
AxisAddRemove
::
make
(
result
,
remove_param
);
}
return
result
;
}
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
...
...
@@ -27,8 +112,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
auto
layout2
=
inputs
[
1
].
layout
;
size_t
dim1
=
layout1
.
ndim
,
dim2
=
layout2
.
ndim
;
DType
dst_dtype
;
DnnOprCaller
<
megdnn
::
MatrixMul
>
dnn_opr
(
inputs
[
0
].
comp_node
);
dnn_opr
.
op
->
param
()
=
matmul
.
param
();
dnn_opr
.
op
->
deduce_dtype
(
layout1
.
dtype
,
layout1
.
dtype
,
dst_dtype
);
if
(
dim1
==
0
||
dim2
==
0
)
{
return
{{{
TensorLayout
(
layout1
.
dtype
),
inputs
[
0
].
comp_node
}},
false
};
return
{{{
TensorLayout
(
dst_
dtype
),
inputs
[
0
].
comp_node
}},
false
};
}
if
(
matmul
.
transposeA
)
...
...
@@ -37,7 +128,8 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
std
::
swap
(
layout2
[
0
],
layout2
[
1
]);
mgb_assert
(
layout1
[
dim1
-
1
]
==
layout2
[
0
]);
TensorLayout
dst_layout
(
layout1
.
dtype
);
TensorLayout
dst_layout
(
dst_dtype
);
size_t
ci
=
0
;
for
(
size_t
i
=
0
;
i
<
dim1
-
1
;
i
++
)
dst_layout
[
ci
++
]
=
layout1
[
i
];
...
...
@@ -61,6 +153,12 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
SmallVector
<
TensorND
>
inp_tensornds
(
inputs
.
size
());
TensorLayout
layout1
=
inputs
[
0
]
->
layout
(),
layout2
=
inputs
[
1
]
->
layout
();
DnnOprCaller
<
megdnn
::
MatrixMul
>
dnn_opr
(
cn
);
dnn_opr
.
op
->
param
()
=
matmul
.
param
();
DType
dst_dtype
;
dnn_opr
.
op
->
deduce_dtype
(
layout1
.
dtype
,
layout1
.
dtype
,
dst_dtype
);
// only matters when layout1 has dim 2
if
(
matmul
.
transposeA
)
std
::
swap
(
layout1
.
shape
[
0
],
layout1
.
shape
[
1
]);
...
...
@@ -69,7 +167,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
std
::
swap
(
layout2
.
shape
[
0
],
layout2
.
shape
[
1
]);
size_t
dim1
=
layout1
.
ndim
,
dim2
=
layout2
.
ndim
;
TensorLayout
real_dst_layout
(
layout1
.
dtype
);
TensorLayout
real_dst_layout
(
dst_
dtype
);
if
(
validated
)
{
real_dst_layout
=
output_descs
[
0
].
layout
;
}
else
{
...
...
@@ -126,12 +224,9 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
inp_tensornds
[
1
]
=
inputs
[
1
]
->
dnn_tensor
();
}
TensorLayout
dst_layout
=
TensorLayout
({
layout_a
[
0
],
layout_b
[
1
]},
layout_a
.
dtype
);
TensorLayout
dst_layout
=
TensorLayout
({
layout_a
[
0
],
layout_b
[
1
]},
dst_
dtype
);
dst_layout
.
init_contiguous_stride
();
DnnOprCaller
<
megdnn
::
MatrixMul
>
dnn_opr
(
cn
);
dnn_opr
.
op
->
param
()
=
matmul
.
param
();
DeviceTensorND
out
=
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
cn
,
dst_layout
);
size_t
sz
=
setup_algo
<
megdnn
::
MatrixMul
>
(
...
...
@@ -167,9 +262,99 @@ namespace batched_matrix_mul {
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
matmul
=
def
.
cast_final_safe
<
BatchedMatrixMul
>
();
mgb_assert
(
inputs
.
size
()
==
2
);
OperatorNodeConfig
config
{
matmul
.
make_name
()};
return
opr
::
BatchedMatrixMul
::
make
(
inputs
[
0
],
inputs
[
1
],
matmul
.
param
(),
matmul
.
policy
(),
config
);
auto
inp1
=
SymbolVar
{
inputs
[
0
]},
inp2
=
SymbolVar
{
inputs
[
1
]};
auto
dim1
=
matmul
.
dimA
,
dim2
=
matmul
.
dimB
;
auto
cn
=
inputs
[
0
]
->
comp_node
();
using
Desc
=
opr
::
AxisAddRemove
::
AxisDesc
;
using
IndexDesc
=
opr
::
Subtensor
::
IndexDesc
;
OperatorNodeConfig
config
{
matmul
.
make_name
(),
cn
};
DTypeScalar
vi
{
-
2
};
auto
graph
=
inputs
[
0
]
->
owner_graph
();
auto
idx
=
opr
::
ImmutableTensor
::
make
(
*
graph
,
vi
,
config
);
bool
remove_row
=
false
,
remove_col
=
false
;
if
(
dim1
==
1
)
{
dim1
=
2
;
remove_row
=
true
;
inp1
=
inp1
.
add_axis
(
0
);
}
if
(
dim2
==
1
)
{
dim2
=
2
;
remove_col
=
true
;
inp2
=
inp2
.
add_axis
(
1
);
}
auto
shp1
=
inp1
.
symshape
();
auto
shp2
=
inp2
.
symshape
();
SymbolVar
shp1_head
,
shp1_tail
,
shp2_head
,
shp2_tail
;
SymbolVar
batch_shape
;
if
(
dim1
>
dim2
)
{
HostTensorND
hv
=
HostTensorND
(
cn
,
{
1
},
dtype
::
Int32
());
auto
*
ptr
=
hv
.
ptr
<
dt_int32
>
();
ptr
[
0
]
=
-
dim2
;
IndexDesc
head_desc
(
1
);
head_desc
[
0
].
end
=
opr
::
ImmutableTensor
::
make
(
*
graph
,
hv
,
config
);
shp1_head
=
opr
::
Subtensor
::
make
(
shp1
,
head_desc
);
shp2
=
opr
::
Concat
::
make
({
shp1_head
,
shp2
},
0
,
cn
);
inp2
=
inp2
.
broadcast
(
shp2
);
head_desc
[
0
].
end
=
idx
;
batch_shape
=
opr
::
Subtensor
::
make
(
shp1
,
head_desc
);
}
if
(
dim2
>
dim1
)
{
HostTensorND
hv
=
HostTensorND
(
cn
,
{
1
},
dtype
::
Int32
());
auto
*
ptr
=
hv
.
ptr
<
dt_int32
>
();
ptr
[
0
]
=
-
dim1
;
IndexDesc
head_desc
(
1
);
head_desc
[
0
].
end
=
opr
::
ImmutableTensor
::
make
(
*
graph
,
hv
,
config
);
shp2_head
=
opr
::
Subtensor
::
make
(
shp2
,
head_desc
);
shp1
=
opr
::
Concat
::
make
({
shp2_head
,
shp1
},
0
,
cn
);
inp1
=
inp1
.
broadcast
(
shp1
);
head_desc
[
0
].
end
=
idx
;
batch_shape
=
opr
::
Subtensor
::
make
(
shp2
,
head_desc
);
}
if
(
dim1
==
dim2
)
{
IndexDesc
head_desc
(
1
);
head_desc
[
0
].
end
=
idx
;
batch_shape
=
opr
::
Subtensor
::
make
(
shp1
,
head_desc
);
}
auto
maxdim
=
dim1
>
dim2
?
dim1
:
dim2
;
if
(
maxdim
>
3
)
{
IndexDesc
tail_desc
(
1
);
tail_desc
[
0
].
begin
=
idx
;
shp1_tail
=
opr
::
Subtensor
::
make
(
shp1
,
tail_desc
);
auto
batch
=
opr
::
Reduce
::
make
(
batch_shape
,
{
Reduce
::
Mode
::
PRODUCT
,
0
});
shp1
=
opr
::
Concat
::
make
({
batch
,
shp1_tail
},
0
,
cn
);
inp1
=
inp1
.
reshape
(
shp1
);
shp2_tail
=
opr
::
Subtensor
::
make
(
shp2
,
tail_desc
);
shp2
=
opr
::
Concat
::
make
({
batch
,
shp2_tail
},
0
,
cn
);
inp2
=
inp2
.
reshape
(
shp2
);
}
auto
result
=
opr
::
BatchedMatrixMul
::
make
(
inp1
,
inp2
,
matmul
.
param
(),
matmul
.
policy
(),
config
);
if
(
maxdim
>
3
)
{
auto
result_shp
=
result
.
symshape
();
IndexDesc
tail_desc
(
1
);
tail_desc
[
0
].
begin
=
idx
;
auto
shp_tail
=
opr
::
Subtensor
::
make
(
result_shp
,
tail_desc
);
result_shp
=
opr
::
Concat
::
make
({
batch_shape
,
shp_tail
},
0
,
cn
);
result
=
result
.
reshape
(
result_shp
);
}
if
(
remove_row
)
{
std
::
vector
<
Desc
>
remove_param
;
remove_param
.
push_back
(
Desc
::
make_remove
(
maxdim
-
2
));
result
=
opr
::
AxisAddRemove
::
make
(
result
,
remove_param
);
}
if
(
remove_col
)
{
std
::
vector
<
Desc
>
remove_param
;
remove_param
.
push_back
(
Desc
::
make_remove
(
maxdim
-
1
));
result
=
opr
::
AxisAddRemove
::
make
(
result
,
remove_param
);
}
return
result
;
}
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
...
...
@@ -178,8 +363,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
TensorLayout
layout1
=
inputs
[
0
].
layout
,
layout2
=
inputs
[
1
].
layout
;
size_t
dim1
=
layout1
.
ndim
,
dim2
=
layout2
.
ndim
;
DType
dst_dtype
;
DnnOprCaller
<
megdnn
::
MatrixMul
>
dnn_opr
(
inputs
[
0
].
comp_node
);
dnn_opr
.
op
->
param
()
=
matmul
.
param
();
dnn_opr
.
op
->
deduce_dtype
(
layout1
.
dtype
,
layout1
.
dtype
,
dst_dtype
);
if
(
dim1
==
0
||
dim2
==
0
)
{
return
{{{
TensorLayout
(
layout1
.
dtype
),
inputs
[
0
].
comp_node
}},
false
};
return
{{{
TensorLayout
(
dst_
dtype
),
inputs
[
0
].
comp_node
}},
false
};
}
if
(
matmul
.
transposeA
)
...
...
@@ -187,7 +378,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
if
(
matmul
.
transposeB
)
std
::
swap
(
layout2
[
dim2
-
1
],
layout2
[
dim2
-
2
]);
TensorLayout
dst_layout
(
layout1
.
dtype
);
TensorLayout
dst_layout
(
dst_
dtype
);
size_t
di
=
0
;
if
(
dim1
>
dim2
)
{
for
(
size_t
i
=
0
;
i
<
dim1
-
2
;
i
++
)
...
...
@@ -217,6 +408,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
TensorLayout
layout1
=
inputs
[
0
]
->
layout
(),
layout2
=
inputs
[
1
]
->
layout
();
size_t
dim1
=
layout1
.
ndim
,
dim2
=
layout2
.
ndim
;
DnnOprCaller
<
megdnn
::
BatchedMatrixMul
>
dnn_opr
(
cn
);
dnn_opr
.
op
->
param
()
=
matmul
.
param
();
DType
dst_dtype
;
dnn_opr
.
op
->
deduce_dtype
(
layout1
.
dtype
,
layout1
.
dtype
,
dst_dtype
);
bool
remove_row
=
false
,
remove_col
=
false
;
if
(
dim1
==
1
)
{
dim1
=
2
;
...
...
@@ -234,6 +430,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
TensorShape
tshp
,
batch_shp
;
size_t
j
=
0
;
auto
inp1
=
inputs
[
0
],
inp2
=
inputs
[
1
];
if
(
dim1
>
dim2
)
{
for
(
size_t
i
=
0
;
i
<
dim1
-
2
;
i
++
)
tshp
[
j
++
]
=
layout1
.
shape
[
i
];
...
...
@@ -266,7 +463,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
shp2
.
ndim
+=
2
;
size_t
maxdim
=
dim1
>
dim2
?
dim1
:
dim2
;
size_t
nbatch
=
batch_shp
[
0
];
auto
inp1
=
inputs
[
0
],
inp2
=
inputs
[
1
];
if
(
maxdim
>
3
)
{
nbatch
=
std
::
accumulate
(
batch_shp
.
shape
,
batch_shp
.
shape
+
batch_shp
.
ndim
,
(
size_t
)
1
,
...
...
@@ -274,29 +470,29 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
TensorLayout
layout_a
;
// batched_matmul does not support memory forwarding, so ensure contiguous
// manually
TensorShape
nl1
=
TensorShape
(
{
nbatch
,
layout1
[
layout1
.
ndim
-
2
],
layout1
[
layout1
.
ndim
-
1
]});
if
(
!
layout1
.
try_reshape
(
layout_a
,
nl1
))
{
inp1
=
Tensor
::
make
(
inputs
[
0
]
->
blob
(),
inputs
[
0
]
->
offset
(),
layout1
);
inp1
->
to_contiguous_inplace
();
layout1
=
inp1
->
layout
();
}
inp1
=
Tensor
::
make
(
inputs
[
0
]
->
blob
(),
inputs
[
0
]
->
offset
(),
layout1
);
inp1
->
to_contiguous_inplace
();
layout1
=
inp1
->
layout
();
layout_a
=
layout1
.
reshape
(
nl1
);
layout1
=
layout_a
;
TensorShape
nl2
=
TensorShape
(
{
nbatch
,
layout2
[
layout2
.
ndim
-
2
],
layout2
[
layout2
.
ndim
-
1
]});
if
(
!
layout2
.
try_reshape
(
layout_a
,
nl2
))
{
inp2
=
Tensor
::
make
(
inputs
[
1
]
->
blob
(),
inputs
[
1
]
->
offset
(),
layout2
);
inp2
->
to_contiguous_inplace
();
layout2
=
inp2
->
layout
();
}
inp2
=
Tensor
::
make
(
inputs
[
1
]
->
blob
(),
inputs
[
1
]
->
offset
(),
layout2
);
inp2
->
to_contiguous_inplace
();
layout2
=
inp2
->
layout
();
layout_a
=
layout2
.
reshape
(
nl2
);
layout2
=
layout_a
;
}
TensorLayout
dst_layout
(
{
nbatch
,
matmul
.
transposeA
?
layout1
[
2
]
:
layout1
[
1
],
matmul
.
transposeB
?
layout2
[
1
]
:
layout2
[
2
]},
layout1
.
dtype
);
dst_
dtype
);
dst_layout
.
init_contiguous_stride
();
if
(
dim1
==
0
||
dim2
==
0
||
layout1
[
layout1
.
ndim
-
1
]
==
0
)
{
...
...
@@ -317,9 +513,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
DeviceTensorND
out
=
BlobManager
::
inst
()
->
alloc_workspace_with_defrag
(
cn
,
dst_layout
);
DnnOprCaller
<
megdnn
::
BatchedMatrixMul
>
dnn_opr
(
cn
);
dnn_opr
.
op
->
param
()
=
matmul
.
param
();
size_t
sz
=
setup_algo
<
megdnn
::
BatchedMatrixMul
>
(
{
layout1
,
layout2
,
dst_layout
},
dnn_opr
.
op
.
get
(),
0
,
false
,
false
,
cn
,
matmul
.
policy
(),
false
);
...
...
imperative/tablegen/helper.h
浏览文件 @
409c9881
...
...
@@ -246,7 +246,12 @@ private:
it
.
name
,
enumMember
.
substr
(
0
,
d
));
body
+=
" break;
\n
"
;
}
body
+=
" default: break;
\n
"
;
body
+=
" default:
\n
"
;
body
+=
formatv
(
" props_.emplace_back(
\"
{0}
\"
, "
"
\"
INVALID
\"
);
\n
"
,
it
.
name
);
body
+=
" break;
\n
"
;
body
+=
" }
\n
"
;
}
else
{
auto
&&
attr
=
llvm
::
cast
<
MgbHashableAttrMixin
>
(
it
.
attr
);
...
...
imperative/tablegen/targets/cpp_class.cpp
浏览文件 @
409c9881
...
...
@@ -89,19 +89,35 @@ void OpDefEmitter::emit_header() {
gen_ctor
(
""
,
""
,
" = default;"
);
if
(
!
op
.
getMgbAttributes
().
empty
())
{
std
::
string
strategy_val
=
""
;
std
::
vector
<
std
::
string
>
paramList
,
initList
;
for
(
auto
&&
i
:
op
.
getMgbAttributes
())
{
if
(
attr_to_ctype
(
i
.
attr
).
compare
(
"Strategy"
)
==
0
)
{
strategy_val
=
i
.
name
;
}
paramList
.
push_back
(
formatv
(
"{0} {1}_"
,
attr_to_ctype
(
i
.
attr
),
i
.
name
));
initList
.
push_back
(
formatv
(
"{0}({0}_)"
,
i
.
name
));
}
paramList
.
push_back
(
"std::string scope_ = {}"
);
gen_ctor
(
llvm
::
join
(
paramList
,
", "
),
": "
+
llvm
::
join
(
initList
,
", "
),
" { set_scope(scope_); }"
);
if
(
!
strategy_val
.
empty
())
{
gen_ctor
(
llvm
::
join
(
paramList
,
", "
),
": "
+
llvm
::
join
(
initList
,
", "
),
formatv
(
" {"
"
\n
set_scope(scope_);"
"
\n
mgb_assert(static_cast<uint32_t>({0}) <= "
"uint32_t(8));"
"
\n
}"
,
strategy_val
));
}
else
{
gen_ctor
(
llvm
::
join
(
paramList
,
", "
),
": "
+
llvm
::
join
(
initList
,
", "
),
" { set_scope(scope_); }"
);
}
}
auto
packedParams
=
op
.
getPackedParams
();
if
(
!
packedParams
.
empty
())
{
std
::
string
strategy_val
=
""
;
std
::
vector
<
std
::
string
>
paramList
,
initList
;
for
(
auto
&&
p
:
packedParams
)
{
auto
&&
paramFields
=
p
.
getFields
();
...
...
@@ -111,6 +127,9 @@ void OpDefEmitter::emit_header() {
paramFields
.
empty
()
?
paramType
.
str
()
:
formatv
(
"{0} {1}"
,
paramType
,
paramName
));
for
(
auto
&&
i
:
paramFields
)
{
if
(
i
.
name
.
compare
(
"strategy"
)
==
0
)
{
strategy_val
=
i
.
name
;
}
initList
.
push_back
(
formatv
(
"{0}({1}.{0})"
,
i
.
name
,
paramName
));
}
}
...
...
@@ -118,9 +137,20 @@ void OpDefEmitter::emit_header() {
paramList
.
push_back
(
formatv
(
"{0} {1}_"
,
attr_to_ctype
(
i
.
attr
),
i
.
name
));
initList
.
push_back
(
formatv
(
"{0}({0}_)"
,
i
.
name
));
}
gen_ctor
(
llvm
::
join
(
paramList
,
", "
),
initList
.
empty
()
?
""
:
": "
+
llvm
::
join
(
initList
,
", "
),
" {}"
);
if
(
!
strategy_val
.
empty
())
{
gen_ctor
(
llvm
::
join
(
paramList
,
", "
),
initList
.
empty
()
?
""
:
": "
+
llvm
::
join
(
initList
,
", "
),
formatv
(
" {"
"
\n
mgb_assert(static_cast<uint32_t>({0}) <= "
"uint32_t(8));"
"
\n
}"
,
strategy_val
));
}
else
{
gen_ctor
(
llvm
::
join
(
paramList
,
", "
),
initList
.
empty
()
?
""
:
": "
+
llvm
::
join
(
initList
,
", "
),
" {}"
);
}
}
if
(
!
packedParams
.
empty
())
{
...
...
src/core/include/megbrain/ir/ops.td
浏览文件 @
409c9881
...
...
@@ -43,9 +43,19 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
def MatrixInverse: MgbHashableOp<"MatrixInverse", [EmptyParam]>;
def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> {
let extraArguments = (ins
MgbUI32Attr:$dimA,
MgbUI32Attr:$dimB
);
}
def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> {
let extraArguments = (ins
MgbUI32Attr:$dimA,
MgbUI32Attr:$dimB
);
}
def Dot: MgbHashableOp<"Dot", [EmptyParam]>;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录