Commit 1fa143ce
Authored Oct 19, 2020 by Megvii Engine Team
refactor(mge/functional): matmul supports symbolic shape, batched mv multiply
GitOrigin-RevId: c4d8cf3306cd833828eca0fc7372397cbf2cc36f
Parent: d47cf332
Showing 2 changed files with 68 additions and 49 deletions (+68, -49)
imperative/python/megengine/functional/nn.py (+38, -36)
imperative/python/test/unit/functional/test_functional.py (+30, -13)
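The commit message names two user-visible changes: matmul's shape arithmetic now stays symbolic (shapes are handled as tensors rather than Python ints) and batched matrix-vector operands are accepted. As a hedged usage sketch only (the tensors and shapes below are invented for illustration, not taken from this commit), the batched matrix-vector case is expected to behave like np.matmul:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    m = tensor(np.random.random((10, 2, 3)).astype("float32"))  # batch of matrices (B, M, K)
    v = tensor(np.random.random((3,)).astype("float32"))        # vector (K,)
    out = F.matmul(m, v)   # v is treated as (K, 1); the trailing 1 is squeezed away
    print(out.shape)       # expected (10, 2), same as np.matmul on the underlying arrays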
imperative/python/megengine/functional/nn.py @ 1fa143ce
@@ -23,7 +23,7 @@ from ..tensor import Tensor
 from .debug_param import get_conv_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
-from .math import argsort, max, sum
+from .math import argsort, max, prod, sum
 from .tensor import (
     broadcast_to,
     concat,
@@ -972,38 +972,42 @@ def matmul(
         [28. 40.]]

     """
+    remove_row, remove_col = False, False
     inp1, inp2 = utils.convert_inputs(inp1, inp2)
     dim1, dim2 = inp1.ndim, inp2.ndim
+    # handle dim=1 cases, dot and matrix-vector multiplication
     if dim1 == 1 and dim2 == 1:
         return dot(inp1, inp2)
-    shp = None
-    if dim1 > 3 or dim2 > 3:
-        shape1, shape2 = list(inp1.shape), list(inp2.shape)
-        if dim1 != dim2:
-            if dim1 < dim2:
-                shape1 = shape2[: dim2 - dim1] + shape1
-                inp1 = broadcast_to(inp1, shape1)
-            else:
-                shape2 = shape1[: dim1 - dim2] + shape2
-                inp2 = broadcast_to(inp2, shape2)
-        reshaped_batch_size = 1
-        for i in shape1[:-2]:
-            reshaped_batch_size *= i
-        inp1 = inp1.reshape(*([reshaped_batch_size] + shape1[-2:]))
-        inp2 = inp2.reshape(*([reshaped_batch_size] + shape2[-2:]))
-        op = builtin.BatchedMatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-        )
-        shp = shape1[:-1] + shape2[-1:]
-    elif dim1 == 3 or dim2 == 3:
-        if dim2 < 3:
-            inp2 = broadcast_to(inp2, inp1.shape[:1] + inp2.shape)
-        elif dim1 < 3:
-            inp1 = broadcast_to(inp1, inp2.shape[:1] + inp1.shape)
+    # the underlying matmul op requires input dims to be at least 2
+    if dim1 == 1:
+        inp1 = expand_dims(inp1, 0)
+        dim1 = 2
+        remove_row = True
+    if dim2 == 1:
+        inp2 = expand_dims(inp2, 1)
+        dim2 = 2
+        remove_col = True
+
+    batch_shape = None
+    shape1 = astensor1d(inp1.shape, inp1, dtype="int32", device=inp1.device)
+    shape2 = astensor1d(inp2.shape, inp2, dtype="int32", device=inp2.device)
+    if dim1 >= 3 or dim2 >= 3:
+        if dim1 == dim2:
+            assert (
+                shape1[:-2] == shape2[:-2]
+            ).min(), "operands could not be broadcasted together."
+        if dim1 > dim2:
+            shape2 = concat([shape1[:-2], shape2[-2:]])
+            inp2 = broadcast_to(inp2, shape2)
+        if dim1 < dim2:
+            shape1 = concat([shape2[:-2], shape1[-2:]])
+            inp1 = broadcast_to(inp1, shape1)
+        batch_shape = shape1[:-2]
+        # compress inputs to 3d
+        inp1 = inp1.reshape(concat([prod(shape1[:-2]), shape1[-2:]]))
+        inp2 = inp2.reshape(concat([prod(shape2[:-2]), shape2[-2:]]))
         op = builtin.BatchedMatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
@@ -1011,12 +1015,6 @@ def matmul(
             format=format,
         )
     else:
-        if dim1 == 1:
-            shp = (inp2.shape[1],)
-            inp1 = expand_dims(inp1, 0)
-        if dim2 == 1:
-            shp = (inp1.shape[0],)
-            inp2 = expand_dims(inp2, 1)
         op = builtin.MatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
@@ -1025,8 +1023,12 @@ def matmul(
         )

     (result,) = apply(op, inp1, inp2)
-    if shp is not None:
-        result = result.reshape(shp)
+    if batch_shape is not None:
+        result = result.reshape(concat([batch_shape, result.shape[-2:]]))
+    if remove_row:
+        result = squeeze(result, axis=-2)
+    if remove_col:
+        result = squeeze(result, axis=-1)
     return result
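The rewritten branch above keeps every shape as a tensor (astensor1d, concat, prod) instead of a Python list, folds all leading batch dimensions into a single one before dispatching to BatchedMatrixMul, and unfolds them again after the op runs. A plain-NumPy sketch of that fold/unfold bookkeeping, with concrete shapes invented purely for illustration:

    import numpy as np

    a = np.random.random((4, 5, 2, 3))               # arbitrary leading batch dims
    b = np.random.random((4, 5, 3, 6))
    batch_shape = a.shape[:-2]                       # (4, 5), restored after the matmul
    a3 = a.reshape((-1,) + a.shape[-2:])             # fold batch dims -> (20, 2, 3)
    b3 = b.reshape((-1,) + b.shape[-2:])             # fold batch dims -> (20, 3, 6)
    out = a3 @ b3                                    # one 3-d batched matmul -> (20, 2, 6)
    out = out.reshape(batch_shape + out.shape[-2:])  # unfold -> (4, 5, 2, 6)
    assert out.shape == (4, 5, 2, 6)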
imperative/python/test/unit/functional/test_functional.py @ 1fa143ce
@@ -77,24 +77,41 @@ def test_matmul():
     opr_test(cases, F.matmul, ref_fn=np.matmul)

     batch_size = 10
-    shape1 = (batch_size, 2, 3)
-    shape2 = (batch_size, 3, 4)
-    shape3 = (batch_size, 10, 4, 5)
+    shape1 = (2,)
+    shape2 = (batch_size, 2, 3)
+    shape3 = (batch_size, 3, 4)
+    shape4 = (batch_size, 10, 4, 2)
+    shape5 = (batch_size, 10, 2, 4)
     data1 = np.random.random(shape1).astype("float32")
     data2 = np.random.random(shape2).astype("float32")
     data3 = np.random.random(shape3).astype("float32")
+    data4 = np.random.random(shape4).astype("float32")
+    data5 = np.random.random(shape5).astype("float32")

-    cases = [{"input": [data1, data2]}, {"input": [data2, data3]}]
-    for i in range(0, batch_size):
-
-        def compare_fn(x, y):
-            x.numpy()[i, ...] == y
-
-        opr_test(
-            cases,
-            F.matmul,
-            compare_fn=compare_fn,
-            ref_fn=lambda x, y: np.matmul(x[i, ...], y[i, ...]),
-        )
+    cases = [
+        {"input": [data1, data2]},
+        {"input": [data2, data3]},
+        {"input": [data3, data4]},
+        {"input": [data4, data5]},
+    ]
+    for _ in range(0, batch_size):
+        opr_test(
+            cases, F.matmul, ref_fn=np.matmul,
+        )
+
+    opr_test(
+        [{"input": [data1, data4]}],
+        F.matmul,
+        ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
+        transpose_b=True,
+    )
+
+    opr_test(
+        [{"input": [data3, data2]}],
+        F.matmul,
+        ref_fn=lambda x, y: np.matmul(x.transpose(0, 2, 1), y.transpose(0, 2, 1)),
+        transpose_a=True,
+        transpose_b=True,
+    )
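For reference, a minimal sketch of the semantics the new transpose_b case exercises, with shapes invented for illustration rather than taken from the test data above:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = np.random.random((10, 10, 2, 4)).astype("float32")
    y = np.random.random((10, 10, 3, 4)).astype("float32")
    ref = np.matmul(x, y.transpose(0, 1, 3, 2))             # transpose y's last two axes
    out = F.matmul(tensor(x), tensor(y), transpose_b=True)  # expected to match ref
    np.testing.assert_allclose(out.numpy(), ref, rtol=1e-5)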