MegEngine 天元 / MegEngine

Commit ca2deebc
Authored Feb 11, 2022 by Megvii Engine Team
fix(imperative/tensor): make @ operator has the same functionality as matmul functional
GitOrigin-RevId: bf6136cc1a7b5cc00103fd0eee27cb8bca8c6f99
Parent e860a083
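In effect, after this commit `Tensor.__matmul__` routes through the same `_matmul` helper that backs `megengine.functional.matmul`, so the two should agree for 1-D, 2-D, and batched operands alike. A minimal check, assuming a working MegEngine install (shapes are illustrative and mirror the test cases below):

```python
import numpy as np
import megengine
import megengine.functional as F

# After this fix, a @ b and F.matmul(a, b) take the same dispatch path,
# so their results should match for every supported shape combination.
for shape_a, shape_b in [((4,), (4,)), ((10, 4), (4, 10)), ((3, 10, 4), (3, 4, 10))]:
    a = megengine.tensor(np.random.rand(*shape_a).astype("float32"))
    b = megengine.tensor(np.random.rand(*shape_b).astype("float32"))
    np.testing.assert_allclose((a @ b).numpy(), F.matmul(a, b).numpy(), rtol=1e-5)
```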
Showing 3 changed files with 289 additions and 288 deletions
imperative/python/megengine/core/tensor/array_method.py (+278 −8)
imperative/python/megengine/functional/math.py (+5 −277)
imperative/python/test/unit/core/test_tensor_wrapper.py (+6 −3)
imperative/python/megengine/core/tensor/array_method.py

@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
abc
import
collections
from
functools
import
lru_cache
from
typing
import
Union
import
numpy
as
np
@@ -24,8 +25,8 @@ from .utils import (
     astype,
+    cast_tensors,
     convert_inputs,
     isscalar,
     make_shape_tuple,
+    subgraph,
 )

 _ElwMod = builtin.Elemwise.Mode
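The new `lru_cache` and `subgraph` imports support the helpers added below: each extended matmul variant is built once per distinct `(device, dtype, dim1, dim2, ...)` signature and then reused, which is also why the non-hashable execution `Strategy` value gets wrapped in the small `_Hashable` class before being passed to the cached factories. A minimal sketch of that caching pattern, stdlib only (the device/dtype strings are illustrative):

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def _get_op(device, dtype, dim1, dim2):
    # Expensive graph construction runs once per distinct argument tuple;
    # later calls with equal (hashable) arguments reuse the cached result.
    print("building op for", device, dtype, dim1, dim2)
    return (device, dtype, dim1, dim2)

_get_op("xpux", "float32", 2, 2)  # builds
_get_op("xpux", "float32", 2, 2)  # cache hit: nothing printed
```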
@@ -73,23 +74,292 @@ def _elwise(*args, mode):
     return _elwise_apply(args, mode)


-def _matmul(inp1, inp2):
+@lru_cache(maxsize=None)
+def _get_extentedMatrixMulOp(
+    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
+):
+    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
+    def extentedMatrixMulOp(inputs, f, c):
+        assert len(inputs) == 2
+        inp1, inp2 = inputs
+        _dim1, _dim2 = dim1, dim2
+
+        def build_shape_head(shape, idx=-1):
+            # shape[:idx]
+            return f(
+                builtin.Subtensor(items=[[0, False, True, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        def build_shape_tail(shape, idx=-1):
+            # shape[idx:]
+            return f(
+                builtin.Subtensor(items=[[0, True, False, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        remove_row, remove_col = False, False
+        if _dim1 == 1:
+            _dim1 = 2
+            remove_row = True
+        if _dim2 == 1:
+            _dim2 = 2
+            remove_col = True
+        if remove_row:
+            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
+        if remove_col:
+            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
+        shape1 = f(builtin.GetVarShape(), inp1)
+        shape2 = f(builtin.GetVarShape(), inp2)
+        if _dim1 > 2:
+            inp1 = f(
+                builtin.Reshape(),
+                inp1,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
+                    build_shape_tail(shape1),
+                ),
+            )
+        if _dim2 > 2:
+            inp2 = f(
+                builtin.Reshape(),
+                inp2,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
+                    build_shape_tail(shape2),
+                ),
+            )
+        op = builtin.MatrixMul(
+            transposeA=transpose_a,
+            transposeB=transpose_b,
+            compute_mode=compute_mode,
+            format=format,
+            strategy=strategy.value,
+        )
+        result = f(op, inp1, inp2)
+        result_shape = f(builtin.GetVarShape(), result)
+        if _dim1 > 2:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    build_shape_head(shape1),
+                    build_shape_tail(result_shape),
+                ),
+            )
+        if _dim2 > 2:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    build_shape_head(shape2),
+                    build_shape_tail(result_shape),
+                ),
+            )
+        maxdim = _dim1 if _dim1 > _dim2 else _dim2
+        if remove_row:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
+        if remove_col:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
+        return (result,), (True,)
+
+    return extentedMatrixMulOp
+
+
+@lru_cache(maxsize=None)
+def _get_extentedBatchedMatrixMulOp(
+    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
+):
+    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
+    def extentedBatchedMatrixMulOp(inputs, f, c):
+        assert len(inputs) == 2
+        inp1, inp2 = inputs
+        _dim1, _dim2 = dim1, dim2
+
+        def build_shape_head(shape, idx=-2):
+            # shape[:idx]
+            return f(
+                builtin.Subtensor(items=[[0, False, True, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        def build_shape_tail(shape, idx=-2):
+            # shape[idx:]
+            return f(
+                builtin.Subtensor(items=[[0, True, False, False, False]]),
+                shape,
+                c(idx, "int32"),
+            )
+
+        remove_row, remove_col = False, False
+        if _dim1 == 1:
+            _dim1 = 2
+            remove_row = True
+        if _dim2 == 1:
+            _dim2 = 2
+            remove_col = True
+        if remove_row:
+            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
+        if remove_col:
+            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
+        shape1 = f(builtin.GetVarShape(), inp1)
+        shape2 = f(builtin.GetVarShape(), inp2)
+        maxdim = _dim1 if _dim1 > _dim2 else _dim2
+        if _dim1 > _dim2:
+            # broadcast
+            shape2 = f(
+                builtin.Concat(axis=0, comp_node=device),
+                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
+                shape2,
+            )
+            inp2 = f(builtin.Broadcast(), inp2, shape2)
+            batch_shape = build_shape_head(shape1)
+        if _dim2 > _dim1:
+            # broadcast
+            shape1 = f(
+                builtin.Concat(axis=0, comp_node=device),
+                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
+                shape1,
+            )
+            inp1 = f(builtin.Broadcast(), inp1, shape1)
+            batch_shape = build_shape_head(shape2)
+        if _dim1 == _dim2:
+            batch_shape = build_shape_head(shape1)
+        # compress inputs to 3d
+        if maxdim > 3:
+            inp1 = f(
+                builtin.Reshape(),
+                inp1,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
+                    build_shape_tail(shape1),
+                ),
+            )
+            inp2 = f(
+                builtin.Reshape(),
+                inp2,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
+                    build_shape_tail(shape2),
+                ),
+            )
+        op = builtin.BatchedMatrixMul(
+            transposeA=transpose_a,
+            transposeB=transpose_b,
+            compute_mode=compute_mode,
+            format=format,
+            strategy=strategy.value,
+        )
+        result = f(op, inp1, inp2)
+        if maxdim > 3:
+            result = f(
+                builtin.Reshape(),
+                result,
+                f(
+                    builtin.Concat(axis=0, comp_node=device),
+                    batch_shape,
+                    build_shape_tail(f(builtin.GetVarShape(), result)),
+                ),
+            )
+        if remove_row:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
+        if remove_col:
+            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
+        return (result,), (True,)
+
+    return extentedBatchedMatrixMulOp
+
+
+class _Hashable:
+    def __init__(self, value) -> None:
+        self.value = value
+
+    def __hash__(self) -> int:
+        return hash(str(self.value))
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, _Hashable):
+            return False
+        return self.value == o.value
+
+
+def _matmul(
+    inp1,
+    inp2,
+    transpose_a=False,
+    transpose_b=False,
+    compute_mode="default",
+    format="default",
+):
     if amp._enabled:
         compute_mode = "float32"
         inp1, inp2 = cast_tensors(inp1, inp2)
     else:
-        compute_mode = "default"
         dtype = dtype_promotion(inp1, inp2)
         if inp1.dtype != dtype:
             inp1 = inp1.astype(dtype)
         if inp2.dtype != dtype:
             inp2 = inp2.astype(dtype)
+    dim1, dim2 = inp1.ndim, inp2.ndim
+    assert dim1 > 0 and dim2 > 0
+    maxdim = dim1 if dim1 > dim2 else dim2
+    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
-    op = builtin.MatrixMul(
-        transposeA=False, transposeB=False, compute_mode=compute_mode, format="default"
-    )
-    (result,) = apply(op, inp1, inp2)
-    return result
+    Strategy = builtin.ops.MatrixMul.Strategy
+    strategy = Strategy(0)
+    if _config._benchmark_kernel:
+        strategy |= Strategy.PROFILE
+    else:
+        strategy |= Strategy.HEURISTIC
+    if _config._deterministic_kernel:
+        strategy |= Strategy.REPRODUCIBLE
+    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
+        (result,) = apply(builtin.Dot(), inp1, inp2)
+        return result
+    elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
+        extentedMatrixMulOp = _get_extentedMatrixMulOp(
+            inp1.device,
+            inp1.dtype,
+            dim1,
+            dim2,
+            transpose_a,
+            transpose_b,
+            compute_mode,
+            format,
+            strategy=_Hashable(strategy),
+        )
+        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
+        return result
+    else:  # dispatch to BatchedMatrixMul
+        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
+            inp1.device,
+            inp1.dtype,
+            dim1,
+            dim2,
+            transpose_a,
+            transpose_b,
+            compute_mode,
+            format,
+            strategy=_Hashable(strategy),
+        )
+        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
+        return result


 def _transpose(data, axes):
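The AddAxis/RemoveAxis bookkeeping in the helpers above reproduces NumPy's matmul convention for 1-D operands: a 1-D left operand gains a leading row axis, a 1-D right operand gains a trailing column axis, and the inserted axis is squeezed from the result. A NumPy sketch of just that shape logic (a reference for the semantics, not MegEngine internals):

```python
import numpy as np

def matmul_shape_reference(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    if a.ndim == 1 and b.ndim == 1:
        return np.dot(a, b)  # Dot path: 1-D x 1-D -> scalar
    remove_row, remove_col = a.ndim == 1, b.ndim == 1
    if remove_row:
        a = a[np.newaxis, :]  # AddAxis(axis=[0]) analogue
    if remove_col:
        b = b[:, np.newaxis]  # AddAxis(axis=[1]) analogue
    out = np.matmul(a, b)  # MatrixMul / BatchedMatrixMul analogue
    maxdim = max(a.ndim, b.ndim)
    if remove_row:
        out = np.squeeze(out, axis=maxdim - 2)  # RemoveAxis(axis=[maxdim - 2])
    if remove_col:
        out = np.squeeze(out, axis=out.ndim - 1)  # RemoveAxis(axis=[maxdim - 1])
    return out

assert matmul_shape_reference(np.ones(4), np.ones((4, 10))).shape == (10,)
assert matmul_shape_reference(np.ones((3, 10, 4)), np.ones(4)).shape == (3, 10)
```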
imperative/python/megengine/functional/math.py
@@ -8,24 +8,18 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
collections
import
math
from
functools
import
lru_cache
from
typing
import
Iterable
,
Optional
,
Sequence
,
Tuple
,
Union
from
..core
import
_config
from
..core._imperative_rt.core2
import
apply
,
dtype_promotion
from
..core._imperative_rt.ops
import
SubgraphBuilder
as
_SubgraphBuilder
from
..core._trace_option
import
use_symbolic_shape
from
..core.ops
import
builtin
from
..core.ops.builtin
import
BatchNorm
,
Elemwise
,
GetVarShape
,
Reduce
,
TypeCvt
from
..core.ops.special
import
Const
from
..core.tensor
import
amp
from
..core.tensor.utils
import
_normalize_axis
,
cast_tensors
,
subgraph
from
..jit
import
exclude_from_trace
from
..core.tensor.array_method
import
_matmul
from
..core.tensor.utils
import
_normalize_axis
from
..tensor
import
Tensor
from
..utils.deprecation
import
deprecated_kwargs_default
from
.debug_param
import
get_execution_strategy
from
.elemwise
import
clip
,
minimum
from
.tensor
import
broadcast_to
,
concat
,
expand_dims
,
squeeze
from
.elemwise
import
clip
from
.tensor
import
expand_dims
,
squeeze
__all__
=
[
"argmax"
,
@@ -794,229 +788,6 @@ def matinv(inp: Tensor) -> Tensor:
     return result


-class _Hashable:
-    def __init__(self, value) -> None:
-        self.value = value
-
-    def __hash__(self) -> int:
-        return hash(str(self.value))
-
-    def __eq__(self, o: object) -> bool:
-        if not isinstance(o, _Hashable):
-            return False
-        return self.value == o.value
-
-
-@lru_cache(maxsize=None)
-def _get_extentedMatrixMulOp(
-    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
-):
-    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
-    def extentedMatrixMulOp(inputs, f, c):
-        assert len(inputs) == 2
-        inp1, inp2 = inputs
-        _dim1, _dim2 = dim1, dim2
-
-        def build_shape_head(shape, idx=-1):
-            # shape[:idx]
-            return f(
-                builtin.Subtensor(items=[[0, False, True, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        def build_shape_tail(shape, idx=-1):
-            # shape[idx:]
-            return f(
-                builtin.Subtensor(items=[[0, True, False, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        remove_row, remove_col = False, False
-        if _dim1 == 1:
-            _dim1 = 2
-            remove_row = True
-        if _dim2 == 1:
-            _dim2 = 2
-            remove_col = True
-        if remove_row:
-            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
-        if remove_col:
-            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
-        shape1 = f(GetVarShape(), inp1)
-        shape2 = f(GetVarShape(), inp2)
-        if _dim1 > 2:
-            inp1 = f(
-                builtin.Reshape(),
-                inp1,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
-                    build_shape_tail(shape1),
-                ),
-            )
-        if _dim2 > 2:
-            inp2 = f(
-                builtin.Reshape(),
-                inp2,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
-                    build_shape_tail(shape2),
-                ),
-            )
-        op = builtin.MatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-            strategy=strategy.value,
-        )
-        result = f(op, inp1, inp2)
-        result_shape = f(GetVarShape(), result)
-        if _dim1 > 2:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    build_shape_head(shape1),
-                    build_shape_tail(result_shape),
-                ),
-            )
-        if _dim2 > 2:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    build_shape_head(shape2),
-                    build_shape_tail(result_shape),
-                ),
-            )
-        maxdim = _dim1 if _dim1 > _dim2 else _dim2
-        if remove_row:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
-        if remove_col:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
-        return (result,), (True,)
-
-    return extentedMatrixMulOp
-
-
-@lru_cache(maxsize=None)
-def _get_extentedBatchedMatrixMulOp(
-    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
-):
-    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
-    def extentedBatchedMatrixMulOp(inputs, f, c):
-        assert len(inputs) == 2
-        inp1, inp2 = inputs
-        _dim1, _dim2 = dim1, dim2
-
-        def build_shape_head(shape, idx=-2):
-            # shape[:idx]
-            return f(
-                builtin.Subtensor(items=[[0, False, True, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        def build_shape_tail(shape, idx=-2):
-            # shape[idx:]
-            return f(
-                builtin.Subtensor(items=[[0, True, False, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        remove_row, remove_col = False, False
-        if _dim1 == 1:
-            _dim1 = 2
-            remove_row = True
-        if _dim2 == 1:
-            _dim2 = 2
-            remove_col = True
-        if remove_row:
-            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
-        if remove_col:
-            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
-        shape1 = f(GetVarShape(), inp1)
-        shape2 = f(GetVarShape(), inp2)
-        maxdim = _dim1 if _dim1 > _dim2 else _dim2
-        if _dim1 > _dim2:
-            # broadcast
-            shape2 = f(
-                builtin.Concat(axis=0, comp_node=device),
-                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
-                shape2,
-            )
-            inp2 = f(builtin.Broadcast(), inp2, shape2)
-            batch_shape = build_shape_head(shape1)
-        if _dim2 > _dim1:
-            # broadcast
-            shape1 = f(
-                builtin.Concat(axis=0, comp_node=device),
-                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
-                shape1,
-            )
-            inp1 = f(builtin.Broadcast(), inp1, shape1)
-            batch_shape = build_shape_head(shape2)
-        if _dim1 == _dim2:
-            batch_shape = build_shape_head(shape1)
-        # compress inputs to 3d
-        if maxdim > 3:
-            inp1 = f(
-                builtin.Reshape(),
-                inp1,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
-                    build_shape_tail(shape1),
-                ),
-            )
-            inp2 = f(
-                builtin.Reshape(),
-                inp2,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
-                    build_shape_tail(shape2),
-                ),
-            )
-        op = builtin.BatchedMatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-            strategy=strategy.value,
-        )
-        result = f(op, inp1, inp2)
-        if maxdim > 3:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    batch_shape,
-                    build_shape_tail(f(GetVarShape(), result)),
-                ),
-            )
-        if remove_row:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
-        if remove_col:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
-        return (result,), (True,)
-
-    return extentedBatchedMatrixMulOp


 def matmul(
     inp1: Tensor,
     inp2: Tensor,
@@ -1067,50 +838,7 @@ def matmul(
     [[10. 13.]
      [28. 40.]]
     """
-    if amp._enabled:
-        compute_mode = "float32"
-        inp1, inp2 = cast_tensors(inp1, inp2)
-    else:
-        dtype = dtype_promotion(inp1, inp2)
-        if inp1.dtype != dtype:
-            inp1 = inp1.astype(dtype)
-        if inp2.dtype != dtype:
-            inp2 = inp2.astype(dtype)
-    dim1, dim2 = inp1.ndim, inp2.ndim
-    assert dim1 > 0 and dim2 > 0
-    maxdim = dim1 if dim1 > dim2 else dim2
-    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
-    if dim1 == 1 and dim2 == 1:  # dispatch to Dot
-        return dot(inp1, inp2)
-    elif maxdim <= 2 or dim2 <= 2:  # dispatch to MatrixMul
-        extentedMatrixMulOp = _get_extentedMatrixMulOp(
-            inp1.device,
-            inp1.dtype,
-            dim1,
-            dim2,
-            transpose_a,
-            transpose_b,
-            compute_mode,
-            format,
-            strategy=_Hashable(get_execution_strategy()),
-        )
-        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
-        return result
-    else:  # dispatch to BatchedMatrixMul
-        extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
-            inp1.device,
-            inp1.dtype,
-            dim1,
-            dim2,
-            transpose_a,
-            transpose_b,
-            compute_mode,
-            format,
-            strategy=_Hashable(get_execution_strategy()),
-        )
-        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
-        return result
+    return _matmul(inp1, inp2, transpose_a, transpose_b, compute_mode, format)


 def dot(inp1: Tensor, inp2: Tensor) -> Tensor:
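With the duplicated helpers gone, `matmul` is a thin wrapper over `_matmul` and still accepts the transpose flags. A small usage sketch, assuming a working MegEngine install (shapes illustrative):

```python
import numpy as np
import megengine
import megengine.functional as F

a = megengine.tensor(np.random.rand(2, 5, 3).astype("float32"))
b = megengine.tensor(np.random.rand(2, 5, 4).astype("float32"))
# transpose_a=True multiplies the transposed trailing matrices of a,
# i.e. each (3, 5) @ (5, 4) across the shared batch dimension.
out = F.matmul(a, b, transpose_a=True)
print(out.numpy().shape)  # (2, 3, 4)
```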
imperative/python/test/unit/core/test_tensor_wrapper.py
@@ -46,14 +46,17 @@ def test_literal_arith(is_varnode):
 @pytest.mark.parametrize("is_varnode", [True, False])
-def test_matmul(is_varnode):
+@pytest.mark.parametrize(
+    "shape_a, shape_b", [((4,), (4,)), ((10, 4), (4, 10)), ((3, 10, 4), (3, 4, 10)),],
+)
+def test_matmul(is_varnode, shape_a, shape_b):
     if is_varnode:
         network = Network()
     else:
         network = None

-    A = make_tensor(np.random.rand(5, 7).astype("float32"), network)
-    B = make_tensor(np.random.rand(7, 10).astype("float32"), network)
+    A = make_tensor(np.random.rand(*shape_a).astype("float32"), network)
+    B = make_tensor(np.random.rand(*shape_b).astype("float32"), network)
     C = A @ B

     if is_varnode:
         np.testing.assert_almost_equal(
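The three parametrized shape pairs appear chosen to exercise each dispatch path in `_matmul`; the mapping below is an inference from the dispatch conditions, checked here against NumPy's matmul semantics:

```python
import numpy as np

# dim1 == dim2 == 1            -> Dot
# maxdim <= 2 (or dim2 <= 2)   -> MatrixMul
# both operands above 2-D      -> BatchedMatrixMul
for shape_a, shape_b, path in [
    ((4,), (4,), "Dot"),
    ((10, 4), (4, 10), "MatrixMul"),
    ((3, 10, 4), (3, 4, 10), "BatchedMatrixMul"),
]:
    out = np.matmul(np.ones(shape_a), np.ones(shape_b))
    print(path, np.shape(out))  # (), (10, 10), (3, 10, 10)
```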