Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3206af9d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3206af9d
编写于
8月 06, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(functional/matmul): reimplement matmul with subgraph
GitOrigin-RevId: 456b2a51d35852152c46d31548dac96f977d5b41
上级
8c47c1f1
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
243 addition
and
75 deletion
+243
-75
imperative/python/megengine/functional/math.py
imperative/python/megengine/functional/math.py
+243
-75
未找到文件。
imperative/python/megengine/functional/math.py
浏览文件 @
3206af9d
...
@@ -8,17 +8,21 @@
...
@@ -8,17 +8,21 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
collections
import
collections
import
math
import
math
from
functools
import
lru_cache
from
typing
import
Optional
,
Sequence
,
Tuple
,
Union
from
typing
import
Optional
,
Sequence
,
Tuple
,
Union
from
..core._imperative_rt.core2
import
apply
,
dtype_promotion
from
..core._imperative_rt.core2
import
apply
,
dtype_promotion
from
..core._imperative_rt.ops
import
SubgraphBuilder
as
_SubgraphBuilder
from
..core._trace_option
import
use_symbolic_shape
from
..core._trace_option
import
use_symbolic_shape
from
..core.ops
import
builtin
from
..core.ops
import
builtin
from
..core.ops.builtin
import
BatchNorm
,
Elemwise
,
GetVarShape
,
Reduce
,
TypeCvt
from
..core.ops.special
import
Const
from
..core.ops.special
import
Const
from
..core.tensor
import
amp
from
..core.tensor
import
amp
from
..core.tensor.utils
import
_normalize_axis
,
cast_tensors
,
setscalar
from
..core.tensor.utils
import
_normalize_axis
,
cast_tensors
,
setscalar
,
subgraph
from
..jit
import
exclude_from_trace
from
..tensor
import
Tensor
from
..tensor
import
Tensor
from
.debug_param
import
get_execution_strategy
from
.debug_param
import
get_execution_strategy
from
.elemwise
import
clip
from
.elemwise
import
clip
,
minimum
from
.tensor
import
broadcast_to
,
concat
,
expand_dims
,
squeeze
from
.tensor
import
broadcast_to
,
concat
,
expand_dims
,
squeeze
__all__
=
[
__all__
=
[
...
@@ -763,6 +767,216 @@ def matinv(inp: Tensor) -> Tensor:
...
@@ -763,6 +767,216 @@ def matinv(inp: Tensor) -> Tensor:
return
result
return
result
@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
    device,
    dtype,
    dim1,
    dim2,
    transpose_a,
    transpose_b,
    compute_mode,
    format,
    strategy,
):
    """Build (and cache) a fused subgraph op that extends ``MatrixMul`` to
    inputs of arbitrary rank.

    The returned op takes two tensors and emulates matmul broadcasting on
    top of the rank-2 ``builtin.MatrixMul`` primitive: 1-d inputs are
    temporarily promoted to 2-d, and inputs with more than 2 dims have
    their leading dims flattened before the multiply and restored after.

    All arguments are hashable configuration values (not tensors), which is
    what makes ``lru_cache`` safe here — the subgraph is rebuilt only once
    per distinct (device, dtype, rank, flags) combination.

    NOTE(review): the semantics of the ``subgraph`` decorator and of the
    ``f``/``c`` builder callbacks are defined elsewhere in the project —
    ``f`` appears to apply a builtin op to symbolic vars and ``c`` to embed
    a constant; confirm against ``..core.tensor.utils.subgraph``.
    """

    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=3)
    def extentedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        # Work on copies: dim1/dim2 come from the cached closure and must
        # not be mutated across subgraph builds.
        _dim1, _dim2 = dim1, dim2

        def build_shape_head(shape, idx=-1):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-1):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        # Promote 1-d operands to 2-d (row vector for inp1, column vector
        # for inp2); remember to strip the inserted axis from the result.
        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(GetVarShape(), inp1)
        shape2 = f(GetVarShape(), inp2)
        # Collapse leading dims so the rank-2 MatrixMul primitive applies:
        # new shape = (prod(shape[:-1]), shape[-1]).
        if _dim1 > 2:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(
                        builtin.Reduce(mode="product", axis=0),
                        build_shape_head(shape1),
                    ),
                    build_shape_tail(shape1),
                ),
            )
        if _dim2 > 2:
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(
                        builtin.Reduce(mode="product", axis=0),
                        build_shape_head(shape2),
                    ),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy,
        )
        result = f(op, inp1, inp2)
        result_shape = f(GetVarShape(), result)
        # Restore the collapsed leading dims on the result:
        # result shape = original_head + result_shape[-1:].
        if _dim1 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape1),
                    build_shape_tail(result_shape),
                ),
            )
        if _dim2 > 2:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    build_shape_head(shape2),
                    build_shape_tail(result_shape),
                ),
            )
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        # Undo the 1-d promotion: drop the row/column axis that was added.
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        # (outputs,), (output-is-differentiable flags,) — NOTE(review):
        # inferred from the subgraph protocol; confirm against `subgraph`.
        return (result,), (True,)

    return extentedMatrixMulOp
@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
    device,
    dtype,
    dim1,
    dim2,
    transpose_a,
    transpose_b,
    compute_mode,
    format,
    strategy,
):
    """Build (and cache) a fused subgraph op that extends
    ``BatchedMatrixMul`` to inputs of arbitrary rank.

    The returned op emulates NumPy-style matmul batching on top of the
    rank-3 ``builtin.BatchedMatrixMul`` primitive: 1-d inputs are promoted
    to 2-d, the lower-rank operand is broadcast against the other's leading
    batch dims, and batch dims beyond one are flattened into a single batch
    axis before the multiply and restored afterwards.

    All arguments are hashable configuration values (not tensors), so
    ``lru_cache`` builds the subgraph once per distinct configuration.

    NOTE(review): the semantics of the ``subgraph`` decorator and of the
    ``f``/``c`` builder callbacks are defined elsewhere in the project;
    confirm against ``..core.tensor.utils.subgraph``.
    """

    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=3)
    def extentedBatchedMatrixMulOp(inputs, f, c):
        assert len(inputs) == 2
        inp1, inp2 = inputs
        # Work on copies: dim1/dim2 come from the cached closure.
        _dim1, _dim2 = dim1, dim2

        # Unlike the plain-MatrixMul variant, the split point defaults to
        # -2: the last TWO dims are the matrix dims, everything before is
        # the batch.
        def build_shape_head(shape, idx=-2):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def build_shape_tail(shape, idx=-2):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        # Promote 1-d operands to 2-d (row vector for inp1, column vector
        # for inp2); the inserted axis is stripped from the result below.
        remove_row, remove_col = False, False
        if _dim1 == 1:
            _dim1 = 2
            remove_row = True
        if _dim2 == 1:
            _dim2 = 2
            remove_col = True
        if remove_row:
            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
        if remove_col:
            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
        shape1 = f(GetVarShape(), inp1)
        shape2 = f(GetVarShape(), inp2)
        maxdim = _dim1 if _dim1 > _dim2 else _dim2
        # Broadcast the lower-rank operand so both share the same leading
        # batch dims; `batch_shape` is the common batch prefix either way.
        if _dim1 > _dim2:
            # broadcast
            shape2 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
                shape2,
            )
            inp2 = f(builtin.Broadcast(), inp2, shape2)
            batch_shape = build_shape_head(shape1)
        if _dim2 > _dim1:
            # broadcast
            shape1 = f(
                builtin.Concat(axis=0, comp_node=device),
                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
                shape1,
            )
            inp1 = f(builtin.Broadcast(), inp1, shape1)
            batch_shape = build_shape_head(shape2)
        if _dim1 == _dim2:
            batch_shape = build_shape_head(shape1)
        # compress inputs to 3d
        # Flatten all batch dims into one: (prod(batch), M, K) / (prod(batch), K, N).
        if maxdim > 3:
            inp1 = f(
                builtin.Reshape(),
                inp1,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape1),
                ),
            )
            inp2 = f(
                builtin.Reshape(),
                inp2,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
                    build_shape_tail(shape2),
                ),
            )
        op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy,
        )
        result = f(op, inp1, inp2)
        # Restore the flattened batch dims: result = batch_shape + (M, N).
        if maxdim > 3:
            result = f(
                builtin.Reshape(),
                result,
                f(
                    builtin.Concat(axis=0, comp_node=device),
                    batch_shape,
                    build_shape_tail(f(GetVarShape(), result)),
                ),
            )
        # Undo the 1-d promotion: drop the row/column axis that was added.
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        # (outputs,), (output-is-differentiable flags,) — NOTE(review):
        # inferred from the subgraph protocol; confirm against `subgraph`.
        return (result,), (True,)

    return extentedBatchedMatrixMulOp
def
matmul
(
def
matmul
(
inp1
:
Tensor
,
inp1
:
Tensor
,
inp2
:
Tensor
,
inp2
:
Tensor
,
...
@@ -822,85 +1036,39 @@ def matmul(
...
@@ -822,85 +1036,39 @@ def matmul(
if
inp2
.
dtype
!=
dtype
:
if
inp2
.
dtype
!=
dtype
:
inp2
=
inp2
.
astype
(
dtype
)
inp2
=
inp2
.
astype
(
dtype
)
remove_row
,
remove_col
=
False
,
False
dim1
,
dim2
=
inp1
.
ndim
,
inp2
.
ndim
dim1
,
dim2
=
inp1
.
ndim
,
inp2
.
ndim
# handle dim=1 cases, dot and matrix-vector multiplication
assert
dim1
>
0
and
dim2
>
0
if
dim1
==
1
and
dim2
==
1
:
return
dot
(
inp1
,
inp2
)
# the underlying matmul op requires input dims to be at least 2
if
dim1
==
1
:
inp1
=
expand_dims
(
inp1
,
0
)
dim1
=
2
remove_row
=
True
if
dim2
==
1
:
inp2
=
expand_dims
(
inp2
,
1
)
dim2
=
2
remove_col
=
True
batch_shape
=
None
shape1
=
inp1
.
shape
shape2
=
inp2
.
shape
maxdim
=
dim1
if
dim1
>
dim2
else
dim2
maxdim
=
dim1
if
dim1
>
dim2
else
dim2
if
dim1
>=
3
or
dim2
>=
3
:
if
dim1
==
1
and
dim2
==
1
:
# dispatch to Dot
if
use_symbolic_shape
():
return
dot
(
inp1
,
inp2
)
if
dim1
>
dim2
:
elif
maxdim
<=
2
or
dim2
<=
2
:
# dispatch to MatrixMul
shape2
=
concat
([
shape1
[:
-
2
],
shape2
[
-
2
:]])
extentedMatrixMulOp
=
_get_extentedMatrixMulOp
(
inp2
=
broadcast_to
(
inp2
,
shape2
)
inp1
.
device
,
if
dim1
<
dim2
:
inp1
.
dtype
,
shape1
=
concat
([
shape2
[:
-
2
],
shape1
[
-
2
:]])
dim1
,
inp1
=
broadcast_to
(
inp1
,
shape1
)
dim2
,
if
maxdim
>
3
:
transpose_a
,
batch_shape
=
shape1
[:
-
2
]
transpose_b
,
# compress inputs to 3d
compute_mode
,
(
inp1
,)
=
apply
(
format
,
builtin
.
Reshape
(),
inp1
,
concat
([
prod
(
shape1
[:
-
2
]),
shape1
[
-
2
:]])
)
(
inp2
,)
=
apply
(
builtin
.
Reshape
(),
inp2
,
concat
([
prod
(
shape2
[:
-
2
]),
shape2
[
-
2
:]])
)
else
:
if
dim1
>
dim2
:
shape2
=
shape1
[:
-
2
]
+
shape2
[
-
2
:]
inp2
=
broadcast_to
(
inp2
,
shape2
)
if
dim1
<
dim2
:
shape1
=
shape2
[:
-
2
]
+
shape1
[
-
2
:]
inp1
=
broadcast_to
(
inp1
,
shape1
)
if
maxdim
>
3
:
batch_shape
=
shape1
[:
-
2
]
# compress inputs to 3d
inp1
=
inp1
.
reshape
((
-
1
,
shape1
[
-
2
],
shape1
[
-
1
]))
inp2
=
inp2
.
reshape
((
-
1
,
shape2
[
-
2
],
shape2
[
-
1
]))
op
=
builtin
.
BatchedMatrixMul
(
transposeA
=
transpose_a
,
transposeB
=
transpose_b
,
compute_mode
=
compute_mode
,
format
=
format
,
strategy
=
get_execution_strategy
(),
strategy
=
get_execution_strategy
(),
)
)
else
:
(
result
,)
=
apply
(
extentedMatrixMulOp
,
inp1
,
inp2
)
op
=
builtin
.
MatrixMul
(
return
result
transposeA
=
transpose_a
,
else
:
# dispatch to BatchedMatrixMul
transposeB
=
transpose_b
,
extentedBatchedMatrixMulOp
=
_get_extentedBatchedMatrixMulOp
(
compute_mode
=
compute_mode
,
inp1
.
device
,
format
=
format
,
inp1
.
dtype
,
dim1
,
dim2
,
transpose_a
,
transpose_b
,
compute_mode
,
format
,
strategy
=
get_execution_strategy
(),
strategy
=
get_execution_strategy
(),
)
)
(
result
,)
=
apply
(
extentedBatchedMatrixMulOp
,
inp1
,
inp2
)
(
result
,)
=
apply
(
op
,
inp1
,
inp2
)
return
result
if
maxdim
>
3
:
if
use_symbolic_shape
():
(
result
,)
=
apply
(
builtin
.
Reshape
(),
result
,
concat
([
batch_shape
,
result
.
shape
[
-
2
:]])
)
else
:
result
=
result
.
reshape
(
batch_shape
+
result
.
shape
[
-
2
:])
if
remove_row
:
result
=
squeeze
(
result
,
axis
=-
2
)
if
remove_col
:
result
=
squeeze
(
result
,
axis
=-
1
)
return
result
def
dot
(
inp1
:
Tensor
,
inp2
:
Tensor
)
->
Tensor
:
def
dot
(
inp1
:
Tensor
,
inp2
:
Tensor
)
->
Tensor
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录