Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
2bfee7d3
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2bfee7d3
编写于
10月 25, 2021
作者:
F
From00
提交者:
GitHub
10月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[cherry-pick] Add new API 'tensordot' (#36273) (#36454)
* Add new API tensordot cherry-pick #36273
上级
8c0bacd4
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
451 addition
and
0 deletion
+451
-0
python/paddle/__init__.py
python/paddle/__init__.py
+2
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/test_tensordot.py
python/paddle/fluid/tests/unittests/test_tensordot.py
+238
-0
python/paddle/tensor/__init__.py
python/paddle/tensor/__init__.py
+2
-0
python/paddle/tensor/manipulation.py
python/paddle/tensor/manipulation.py
+208
-0
未找到文件。
python/paddle/__init__.py
浏览文件 @
2bfee7d3
...
...
@@ -152,6 +152,7 @@ from .tensor.manipulation import unbind # noqa: F401
from
.tensor.manipulation
import
roll
# noqa: F401
from
.tensor.manipulation
import
chunk
# noqa: F401
from
.tensor.manipulation
import
tolist
# noqa: F401
from
.tensor.manipulation
import
tensordot
# noqa: F401
from
.tensor.math
import
abs
# noqa: F401
from
.tensor.math
import
acos
# noqa: F401
from
.tensor.math
import
asin
# noqa: F401
...
...
@@ -470,6 +471,7 @@ __all__ = [ # noqa
'bmm'
,
'chunk'
,
'tolist'
,
'tensordot'
,
'greater_than'
,
'shard_index'
,
'argsort'
,
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
2bfee7d3
...
...
@@ -1038,3 +1038,4 @@ if(WITH_GPU OR WITH_ROCM)
endif
()
set_tests_properties
(
test_inplace_addto_strategy PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_eigvals_op PROPERTIES TIMEOUT 400
)
set_tests_properties
(
test_tensordot PROPERTIES TIMEOUT 1000
)
python/paddle/fluid/tests/unittests/test_tensordot.py
0 → 100644
浏览文件 @
2bfee7d3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
unittest
import
paddle.fluid.core
as
core
import
numpy
as
np
import
itertools
as
it
np
.
set_printoptions
(
threshold
=
np
.
inf
)
def
tensordot_np
(
x
,
y
,
axes
):
if
isinstance
(
axes
,
paddle
.
fluid
.
framework
.
Variable
):
axes
=
axes
.
tolist
()
# np.tensordot does not support empty axes
if
not
axes
:
axes
=
0
if
(
isinstance
(
axes
,
(
tuple
,
list
))):
if
all
(
np
.
issubdtype
(
type
(
i
),
np
.
integer
)
for
i
in
axes
):
axes
=
[
axes
,
axes
]
else
:
axes_x
=
axes
[
0
]
if
len
(
axes
)
>
1
:
axes_y
=
axes
[
1
]
else
:
axes_y
=
axes_x
len_axes_x
,
len_axes_y
=
len
(
axes_x
),
len
(
axes_y
)
if
len_axes_x
<
len_axes_y
:
axes_x
=
axes_x
+
axes_y
[
len_axes_x
:]
elif
len_axes_y
<
len_axes_x
:
axes_y
=
axes_y
+
axes_x
[
len_axes_y
:]
axes
=
[
axes_x
,
axes_y
]
# np.tensordot does not support broadcast
if
(
isinstance
(
axes
,
(
tuple
,
list
))):
axes_x
,
axes_y
=
axes
else
:
axes_x
=
list
(
range
(
x
.
ndim
-
axes
,
x
.
ndim
))
axes_y
=
list
(
range
(
axes
))
shape_x
,
shape_y
=
list
(
np
.
shape
(
x
)),
list
(
np
.
shape
(
y
))
for
i
in
range
(
len
(
axes_x
)):
dim_x
,
dim_y
=
axes_x
[
i
],
axes_y
[
i
]
sx
,
sy
=
shape_x
[
dim_x
],
shape_y
[
dim_y
]
if
sx
==
1
:
shape_y
[
dim_y
]
=
1
y
=
np
.
sum
(
y
,
dim_y
)
y
=
np
.
reshape
(
y
,
shape_y
)
elif
sy
==
1
:
shape_x
[
dim_x
]
=
1
x
=
np
.
sum
(
x
,
dim_x
)
x
=
np
.
reshape
(
x
,
shape_x
)
return
np
.
tensordot
(
x
,
y
,
axes
)
class
TestTensordotAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
set_dtype
()
self
.
set_input_shape
()
self
.
set_input_data
()
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float32
def
set_input_shape
(
self
):
self
.
x_shape
=
[
5
,
5
,
5
,
5
]
self
.
y_shape
=
[
5
,
5
,
5
,
5
]
def
set_input_data
(
self
):
self
.
x
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
random
(
self
.
y_shape
).
astype
(
self
.
dtype
)
self
.
all_axes
=
[
2
]
def
run_dygraph
(
self
,
place
):
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
self
.
x
,
place
=
place
)
y
=
paddle
.
to_tensor
(
self
.
y
,
place
=
place
)
paddle_res
=
paddle
.
tensordot
(
x
,
y
,
self
.
axes
)
np_res
=
tensordot_np
(
self
.
x
,
self
.
y
,
self
.
axes
)
np
.
testing
.
assert_allclose
(
paddle_res
,
np_res
,
rtol
=
1e-6
)
def
run_static
(
self
,
place
):
paddle
.
enable_static
()
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()):
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
self
.
x_shape
,
dtype
=
self
.
dtype
)
y
=
paddle
.
static
.
data
(
name
=
'y'
,
shape
=
self
.
y_shape
,
dtype
=
self
.
dtype
)
z
=
paddle
.
tensordot
(
x
,
y
,
self
.
axes
)
exe
=
paddle
.
static
.
Executor
(
place
)
paddle_res
=
exe
.
run
(
feed
=
{
'x'
:
self
.
x
,
'y'
:
self
.
y
},
fetch_list
=
[
z
])
np_res
=
tensordot_np
(
self
.
x
,
self
.
y
,
self
.
axes
)
np
.
testing
.
assert_allclose
(
paddle_res
[
0
],
np_res
,
rtol
=
1e-6
)
def
test_cases
(
self
):
self
.
all_axes
=
[]
axial_index
=
range
(
4
)
all_permutations
=
list
(
it
.
permutations
(
axial_index
,
0
))
+
list
(
it
.
permutations
(
axial_index
,
1
))
+
list
(
it
.
permutations
(
axial_index
,
2
))
+
list
(
it
.
permutations
(
axial_index
,
3
))
+
list
(
it
.
permutations
(
axial_index
,
4
))
self
.
all_axes
.
extend
(
list
(
i
)
for
i
in
all_permutations
)
for
axes_x
in
all_permutations
:
for
axes_y
in
all_permutations
:
if
len
(
axes_x
)
<
len
(
axes_y
):
supplementary_axes_x
=
axes_x
+
axes_y
[
len
(
axes_x
):]
if
any
(
supplementary_axes_x
.
count
(
i
)
>
1
for
i
in
supplementary_axes_x
):
continue
elif
len
(
axes_y
)
<
len
(
axes_x
):
supplementary_axes_y
=
axes_y
+
axes_x
[
len
(
axes_y
):]
if
any
(
supplementary_axes_y
.
count
(
i
)
>
1
for
i
in
supplementary_axes_y
):
continue
self
.
all_axes
.
append
([
list
(
axes_x
),
list
(
axes_y
)])
self
.
all_axes
.
extend
(
range
(
5
))
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
places
.
append
(
core
.
CUDAPlace
(
0
))
for
axes
in
self
.
all_axes
:
self
.
axes
=
axes
for
place
in
places
:
self
.
run_dygraph
(
place
)
self
.
run_static
(
place
)
class
TestTensordotAPIFloat64
(
TestTensordotAPI
):
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float64
class
TestTensordotAPIAxesType
(
TestTensordotAPI
):
def
set_input_shape
(
self
):
self
.
x_shape
=
[
3
,
4
,
4
]
self
.
y_shape
=
[
4
,
4
,
5
]
def
test_cases
(
self
):
self
.
all_axes
=
[
0
,
1
,
2
,
(
1
,
),
[
1
],
((
1
,
),
),
([
1
],
),
((
2
,
1
),
(
0
,
)),
(
(
1
,
2
),
(
0
,
1
)),
([
1
,
2
],
[
0
,
1
]),
([
1
,
2
],
[
0
,
1
]),
[[
1
,
2
],
[
0
,
1
]]
]
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
places
.
append
(
core
.
CUDAPlace
(
0
))
for
axes
in
self
.
all_axes
:
self
.
axes
=
axes
for
place
in
places
:
self
.
run_dygraph
(
place
)
self
.
run_static
(
place
)
# The 'axes' with type 'Tensor' in tensordot is not available in static mode
paddle
.
disable_static
()
for
place
in
places
:
self
.
all_axes
=
[
paddle
.
to_tensor
([
1
]),
(
paddle
.
to_tensor
([
1
])),
(
paddle
.
to_tensor
([
1
,
2
]),
paddle
.
to_tensor
([
0
,
1
])),
[
paddle
.
to_tensor
([
1
,
2
]),
paddle
.
to_tensor
([
0
,
1
])],
paddle
.
to_tensor
([[
1
,
2
],
[
0
,
1
]])
]
for
axes
in
self
.
all_axes
:
self
.
axes
=
axes
for
place
in
places
:
self
.
run_dygraph
(
place
)
def
test_error
(
self
):
self
.
all_axes
=
[[[[
0
],
[
1
]]],
0.1
,
-
1
,
100
,
[[
1
,
2
],
[
0
,
0
]],
[[
1
,
2
],
[
0
,
-
1
]],
[
0
,
1
,
2
,
3
]]
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
self
.
x
)
y
=
paddle
.
to_tensor
(
self
.
y
)
for
axes
in
self
.
all_axes
:
with
self
.
assertRaises
(
BaseException
):
paddle
.
tensordot
(
x
,
y
,
axes
)
class
TestTensordotAPIAxesTypeFloat64
(
TestTensordotAPIAxesType
):
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float64
class
TestTensordotAPIBroadcastCase1
(
TestTensordotAPI
):
def
set_input_shape
(
self
):
self
.
x_shape
=
[
1
,
1
,
1
,
5
]
self
.
y_shape
=
[
1
,
5
,
1
,
1
]
class
TestTensordotAPIBroadcastCase2
(
TestTensordotAPI
):
def
set_input_shape
(
self
):
self
.
x_shape
=
[
1
,
5
,
5
,
5
]
self
.
y_shape
=
[
1
,
1
,
1
,
5
]
class
TestTensordotAPIBroadcastCase3
(
TestTensordotAPI
):
def
set_input_shape
(
self
):
self
.
x_shape
=
[
5
,
5
,
5
,
1
]
self
.
y_shape
=
[
5
,
5
,
1
,
5
]
class
TestTensordotAPIBroadcastCase4
(
TestTensordotAPI
):
def
set_input_shape
(
self
):
self
.
x_shape
=
[
5
,
5
,
5
,
1
]
self
.
y_shape
=
[
1
,
1
,
1
,
1
]
class
TestTensordotAPIBroadcastCase5
(
TestTensordotAPI
):
def
set_input_shape
(
self
):
self
.
x_shape
=
[
1
,
1
,
5
,
5
]
self
.
y_shape
=
[
5
,
5
,
1
,
5
]
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/tensor/__init__.py
浏览文件 @
2bfee7d3
...
...
@@ -105,6 +105,7 @@ from .manipulation import flip # noqa: F401
from
.manipulation
import
unbind
# noqa: F401
from
.manipulation
import
roll
# noqa: F401
from
.manipulation
import
chunk
# noqa: F401
from
.manipulation
import
tensordot
# noqa: F401
from
.math
import
abs
# noqa: F401
from
.math
import
acos
# noqa: F401
from
.math
import
asin
# noqa: F401
...
...
@@ -346,6 +347,7 @@ tensor_method_func = [ #noqa
'slice'
,
'split'
,
'chunk'
,
'tensordot'
,
'squeeze'
,
'squeeze_'
,
'stack'
,
...
...
python/paddle/tensor/manipulation.py
浏览文件 @
2bfee7d3
...
...
@@ -2173,3 +2173,211 @@ def strided_slice(x, axes, starts, ends, strides, name=None):
return
paddle
.
fluid
.
layers
.
strided_slice
(
input
=
x
,
axes
=
axes
,
starts
=
starts
,
ends
=
ends
,
strides
=
strides
)
def
tensordot
(
x
,
y
,
axes
=
2
,
name
=
None
):
r
"""
This function computes a contraction, which sum the product of elements from two tensors along the given axes.
Args:
x (Tensor): The left tensor for contraction with data type ``float32`` or ``float64``.
y (Tensor): The right tensor for contraction with the same data type as ``x``.
axes (int|tuple|list|Tensor, optional): The axes to contract for ``x`` and ``y``, defaulted to integer ``2``.
1. It could be a non-negative integer ``n``,
in which the function will sum over the last ``n`` axes of ``x`` and the first ``n`` axes of ``y`` in order.
2. It could be a 1-d tuple or list with data type ``int``, in which ``x`` and ``y`` will be contracted along the same given axes.
For example, ``axes`` =[0, 1] applies contraction along the first two axes for ``x`` and the first two axes for ``y``.
3. It could be a tuple or list containing one or two 1-d tuple|list|Tensor with data type ``int``.
When containing one tuple|list|Tensor, the data in tuple|list|Tensor specified the same axes for ``x`` and ``y`` to contract.
When containing two tuple|list|Tensor, the first will be applied to ``x`` and the second to ``y``.
When containing more than two tuple|list|Tensor, only the first two axis sequences will be used while the others will be ignored.
4. It could be a tensor, in which the ``axes`` tensor will be translated to a python list
and applied the same rules described above to determine the contraction axes.
Note that the ``axes`` with Tensor type is ONLY available in Dygraph mode.
name(str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Return:
Output (Tensor): The contraction result with the same data type as ``x`` and ``y``.
In general, :math:`output.ndim = x.ndim + y.ndim - 2 \times n_{axes}`, where :math:`n_{axes}` denotes the number of axes to be contracted.
NOTES:
1. This function supports tensor broadcast,
the size in the corresponding dimensions of ``x`` and ``y`` should be equal, or applies to the broadcast rules.
2. This function also supports axes expansion,
when the two given axis sequences for ``x`` and ``y`` are of different lengths,
the shorter sequence will expand the same axes as the longer one at the end.
For example, if ``axes`` =[[0, 1, 2, 3], [1, 0]],
the axis sequence for ``x`` is [0, 1, 2, 3],
while the corresponding axis sequences for ``y`` will be expanded from [1, 0] to [1, 0, 2, 3].
Examples:
.. code-block:: python
import paddle
data_type = 'float64'
# For two 2-d tensor x and y, the case axes=0 is equivalent to outer product.
# Note that tensordot supports empty axis sequence, so all the axes=0, axes=[], axes=[[]], and axes=[[],[]] are equivalent cases.
x = paddle.arange(4, dtype=data_type).reshape([2, 2])
y = paddle.arange(4, dtype=data_type).reshape([2, 2])
z = paddle.tensordot(x, y, axes=0)
# z = [[[[0., 0.],
# [0., 0.]],
#
# [[0., 1.],
# [2., 3.]]],
#
#
# [[[0., 2.],
# [4., 6.]],
#
# [[0., 3.],
# [6., 9.]]]]
# For two 1-d tensor x and y, the case axes=1 is equivalent to inner product.
x = paddle.arange(10, dtype=data_type)
y = paddle.arange(10, dtype=data_type)
z1 = paddle.tensordot(x, y, axes=1)
z2 = paddle.dot(x, y)
# z1 = z2 = [285.]
# For two 2-d tensor x and y, the case axes=1 is equivalent to matrix multiplication.
x = paddle.arange(6, dtype=data_type).reshape([2, 3])
y = paddle.arange(12, dtype=data_type).reshape([3, 4])
z1 = paddle.tensordot(x, y, axes=1)
z2 = paddle.matmul(x, y)
# z1 = z2 = [[20., 23., 26., 29.],
# [56., 68., 80., 92.]]
# When axes is a 1-d int list, x and y will be contracted along the same given axes.
# Note that axes=[1, 2] is equivalent to axes=[[1, 2]], axes=[[1, 2], []], axes=[[1, 2], [1]], and axes=[[1, 2], [1, 2]].
x = paddle.arange(24, dtype=data_type).reshape([2, 3, 4])
y = paddle.arange(36, dtype=data_type).reshape([3, 3, 4])
z = paddle.tensordot(x, y, axes=[1, 2])
# z = [[506. , 1298., 2090.],
# [1298., 3818., 6338.]]
# When axes is a list containing two 1-d int list, the first will be applied to x and the second to y.
x = paddle.arange(60, dtype=data_type).reshape([3, 4, 5])
y = paddle.arange(24, dtype=data_type).reshape([4, 3, 2])
z = paddle.tensordot(x, y, axes=([1, 0], [0, 1]))
# z = [[4400., 4730.],
# [4532., 4874.],
# [4664., 5018.],
# [4796., 5162.],
# [4928., 5306.]]
# Thanks to the support of axes expansion, axes=[[0, 1, 3, 4], [1, 0, 3, 4]] can be abbreviated as axes= [[0, 1, 3, 4], [1, 0]].
x = paddle.arange(720, dtype=data_type).reshape([2, 3, 4, 5, 6])
y = paddle.arange(720, dtype=data_type).reshape([3, 2, 4, 5, 6])
z = paddle.tensordot(x, y, axes=[[0, 1, 3, 4], [1, 0]])
# z = [[23217330., 24915630., 26613930., 28312230.],
# [24915630., 26775930., 28636230., 30496530.],
# [26613930., 28636230., 30658530., 32680830.],
# [28312230., 30496530., 32680830., 34865130.]]
"""
op_type
=
'tensordot'
input_dtype
=
[
'float32'
,
'float64'
]
check_variable_and_dtype
(
x
,
'x'
,
input_dtype
,
op_type
)
check_variable_and_dtype
(
y
,
'y'
,
input_dtype
,
op_type
)
check_type
(
axes
,
'axes'
,
(
int
,
tuple
,
list
,
Variable
),
op_type
)
def
_var_to_list
(
var
):
if
in_dygraph_mode
():
return
tolist
(
var
)
raise
TypeError
(
"The 'axes' with type 'Tensor' in "
+
op_type
+
" is not available in static graph mode, "
"please convert its type to int|Tuple|List, or use dynamic graph mode."
)
axes_x
=
[]
axes_y
=
[]
if
np
.
issubdtype
(
type
(
axes
),
np
.
integer
):
assert
axes
>=
0
,
(
"The 'axes' in "
+
op_type
+
f
" should not be negative, but received axes=
{
axes
}
."
)
axes_x
=
range
(
x
.
ndim
-
axes
,
x
.
ndim
)
axes_y
=
range
(
axes
)
else
:
if
isinstance
(
axes
,
Variable
):
axes
=
_var_to_list
(
axes
)
if
not
axes
or
np
.
issubdtype
(
type
(
axes
[
0
]),
np
.
integer
):
axes_x
=
axes
else
:
axes_x
=
axes
[
0
]
if
len
(
axes
)
>
1
:
axes_y
=
axes
[
1
]
if
isinstance
(
axes_x
,
Variable
):
axes_x
=
_var_to_list
(
axes_x
)
if
isinstance
(
axes_y
,
Variable
):
axes_y
=
_var_to_list
(
axes_y
)
axes_x
,
axes_y
=
list
(
axes_x
),
list
(
axes_y
)
len_axes_x
,
len_axes_y
=
len
(
axes_x
),
len
(
axes_y
)
if
len_axes_x
<
len_axes_y
:
axes_x
.
extend
(
axes_y
[
len_axes_x
:])
elif
len_axes_y
<
len_axes_x
:
axes_y
.
extend
(
axes_x
[
len_axes_y
:])
shape_x
,
shape_y
=
list
(
x
.
shape
),
list
(
y
.
shape
)
need_contracted_dim_x
=
np
.
zeros
((
x
.
ndim
),
dtype
=
bool
)
need_contracted_dim_y
=
np
.
zeros
((
y
.
ndim
),
dtype
=
bool
)
contraction_size
=
1
for
i
in
range
(
len
(
axes_x
)):
dim_x
,
dim_y
=
axes_x
[
i
],
axes_y
[
i
]
sx
,
sy
=
shape_x
[
dim_x
],
shape_y
[
dim_y
]
if
sx
==
1
:
shape_y
[
dim_y
]
=
1
y
=
y
.
sum
(
dim_y
).
reshape
(
shape_y
)
elif
sy
==
1
:
shape_x
[
dim_x
]
=
1
x
=
x
.
sum
(
dim_x
).
reshape
(
shape_x
)
else
:
assert
sx
==
sy
,
"The dimensional size for 'x' and 'y' in "
+
op_type
+
f
" should match each other, but 'x' has size
{
sx
}
in dim
{
dim_x
}
while 'y' has size
{
sy
}
in dim
{
dim_y
}
."
need_contracted_dim_x
[
dim_x
]
=
True
need_contracted_dim_y
[
dim_y
]
=
True
contraction_size
*=
shape_x
[
dim_x
]
perm_x
=
[]
perm_y
=
[]
shape_out
=
[]
not_contraction_size_x
=
1
not_contraction_size_y
=
1
for
i
in
range
(
x
.
ndim
):
if
not
need_contracted_dim_x
[
i
]:
perm_x
.
append
(
i
)
shape_out
.
append
(
shape_x
[
i
])
not_contraction_size_x
*=
shape_x
[
i
]
perm_x
.
extend
(
axes_x
)
perm_y
.
extend
(
axes_y
)
for
i
in
range
(
y
.
ndim
):
if
not
need_contracted_dim_y
[
i
]:
perm_y
.
append
(
i
)
shape_out
.
append
(
shape_y
[
i
])
not_contraction_size_y
*=
shape_y
[
i
]
if
not
shape_out
:
shape_out
=
[
1
]
x
=
x
.
transpose
(
perm
=
perm_x
).
reshape
(
[
not_contraction_size_x
,
contraction_size
])
y
=
y
.
transpose
(
perm
=
perm_y
).
reshape
(
[
contraction_size
,
not_contraction_size_y
])
out
=
x
.
matmul
(
y
).
reshape
(
shape_out
)
return
out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录