Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2bfee7d3
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2bfee7d3
编写于
10月 25, 2021
作者:
F
From00
提交者:
GitHub
10月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[cherry-pick] Add new API 'tensordot' (#36273) (#36454)
* Add new API tensordot cherry-pick #36273
上级
8c0bacd4
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
451 addition
and
0 deletion
+451
-0
python/paddle/__init__.py
python/paddle/__init__.py
+2
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/test_tensordot.py
python/paddle/fluid/tests/unittests/test_tensordot.py
+238
-0
python/paddle/tensor/__init__.py
python/paddle/tensor/__init__.py
+2
-0
python/paddle/tensor/manipulation.py
python/paddle/tensor/manipulation.py
+208
-0
未找到文件。
python/paddle/__init__.py
浏览文件 @
2bfee7d3
...
@@ -152,6 +152,7 @@ from .tensor.manipulation import unbind # noqa: F401
...
@@ -152,6 +152,7 @@ from .tensor.manipulation import unbind # noqa: F401
from
.tensor.manipulation
import
roll
# noqa: F401
from
.tensor.manipulation
import
roll
# noqa: F401
from
.tensor.manipulation
import
chunk
# noqa: F401
from
.tensor.manipulation
import
chunk
# noqa: F401
from
.tensor.manipulation
import
tolist
# noqa: F401
from
.tensor.manipulation
import
tolist
# noqa: F401
from
.tensor.manipulation
import
tensordot
# noqa: F401
from
.tensor.math
import
abs
# noqa: F401
from
.tensor.math
import
abs
# noqa: F401
from
.tensor.math
import
acos
# noqa: F401
from
.tensor.math
import
acos
# noqa: F401
from
.tensor.math
import
asin
# noqa: F401
from
.tensor.math
import
asin
# noqa: F401
...
@@ -470,6 +471,7 @@ __all__ = [ # noqa
...
@@ -470,6 +471,7 @@ __all__ = [ # noqa
'bmm'
,
'bmm'
,
'chunk'
,
'chunk'
,
'tolist'
,
'tolist'
,
'tensordot'
,
'greater_than'
,
'greater_than'
,
'shard_index'
,
'shard_index'
,
'argsort'
,
'argsort'
,
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
2bfee7d3
...
@@ -1038,3 +1038,4 @@ if(WITH_GPU OR WITH_ROCM)
...
@@ -1038,3 +1038,4 @@ if(WITH_GPU OR WITH_ROCM)
endif
()
endif
()
set_tests_properties
(
test_inplace_addto_strategy PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_inplace_addto_strategy PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_eigvals_op PROPERTIES TIMEOUT 400
)
set_tests_properties
(
test_eigvals_op PROPERTIES TIMEOUT 400
)
set_tests_properties
(
test_tensordot PROPERTIES TIMEOUT 1000
)
python/paddle/fluid/tests/unittests/test_tensordot.py
0 → 100644
浏览文件 @
2bfee7d3
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
unittest
import
paddle.fluid.core
as
core
import
numpy
as
np
import
itertools
as
it
np
.
set_printoptions
(
threshold
=
np
.
inf
)
def
tensordot_np
(
x
,
y
,
axes
):
if
isinstance
(
axes
,
paddle
.
fluid
.
framework
.
Variable
):
axes
=
axes
.
tolist
()
# np.tensordot does not support empty axes
if
not
axes
:
axes
=
0
if
(
isinstance
(
axes
,
(
tuple
,
list
))):
if
all
(
np
.
issubdtype
(
type
(
i
),
np
.
integer
)
for
i
in
axes
):
axes
=
[
axes
,
axes
]
else
:
axes_x
=
axes
[
0
]
if
len
(
axes
)
>
1
:
axes_y
=
axes
[
1
]
else
:
axes_y
=
axes_x
len_axes_x
,
len_axes_y
=
len
(
axes_x
),
len
(
axes_y
)
if
len_axes_x
<
len_axes_y
:
axes_x
=
axes_x
+
axes_y
[
len_axes_x
:]
elif
len_axes_y
<
len_axes_x
:
axes_y
=
axes_y
+
axes_x
[
len_axes_y
:]
axes
=
[
axes_x
,
axes_y
]
# np.tensordot does not support broadcast
if
(
isinstance
(
axes
,
(
tuple
,
list
))):
axes_x
,
axes_y
=
axes
else
:
axes_x
=
list
(
range
(
x
.
ndim
-
axes
,
x
.
ndim
))
axes_y
=
list
(
range
(
axes
))
shape_x
,
shape_y
=
list
(
np
.
shape
(
x
)),
list
(
np
.
shape
(
y
))
for
i
in
range
(
len
(
axes_x
)):
dim_x
,
dim_y
=
axes_x
[
i
],
axes_y
[
i
]
sx
,
sy
=
shape_x
[
dim_x
],
shape_y
[
dim_y
]
if
sx
==
1
:
shape_y
[
dim_y
]
=
1
y
=
np
.
sum
(
y
,
dim_y
)
y
=
np
.
reshape
(
y
,
shape_y
)
elif
sy
==
1
:
shape_x
[
dim_x
]
=
1
x
=
np
.
sum
(
x
,
dim_x
)
x
=
np
.
reshape
(
x
,
shape_x
)
return
np
.
tensordot
(
x
,
y
,
axes
)
class
TestTensordotAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
set_dtype
()
self
.
set_input_shape
()
self
.
set_input_data
()
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float32
def
set_input_shape
(
self
):
self
.
x_shape
=
[
5
,
5
,
5
,
5
]
self
.
y_shape
=
[
5
,
5
,
5
,
5
]
def
set_input_data
(
self
):
self
.
x
=
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
random
(
self
.
y_shape
).
astype
(
self
.
dtype
)
self
.
all_axes
=
[
2
]
def
run_dygraph
(
self
,
place
):
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
self
.
x
,
place
=
place
)
y
=
paddle
.
to_tensor
(
self
.
y
,
place
=
place
)
paddle_res
=
paddle
.
tensordot
(
x
,
y
,
self
.
axes
)
np_res
=
tensordot_np
(
self
.
x
,
self
.
y
,
self
.
axes
)
np
.
testing
.
assert_allclose
(
paddle_res
,
np_res
,
rtol
=
1e-6
)
def
run_static
(
self
,
place
):
paddle
.
enable_static
()
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()):
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
self
.
x_shape
,
dtype
=
self
.
dtype
)
y
=
paddle
.
static
.
data
(
name
=
'y'
,
shape
=
self
.
y_shape
,
dtype
=
self
.
dtype
)
z
=
paddle
.
tensordot
(
x
,
y
,
self
.
axes
)
exe
=
paddle
.
static
.
Executor
(
place
)
paddle_res
=
exe
.
run
(
feed
=
{
'x'
:
self
.
x
,
'y'
:
self
.
y
},
fetch_list
=
[
z
])
np_res
=
tensordot_np
(
self
.
x
,
self
.
y
,
self
.
axes
)
np
.
testing
.
assert_allclose
(
paddle_res
[
0
],
np_res
,
rtol
=
1e-6
)
def
test_cases
(
self
):
self
.
all_axes
=
[]
axial_index
=
range
(
4
)
all_permutations
=
list
(
it
.
permutations
(
axial_index
,
0
))
+
list
(
it
.
permutations
(
axial_index
,
1
))
+
list
(
it
.
permutations
(
axial_index
,
2
))
+
list
(
it
.
permutations
(
axial_index
,
3
))
+
list
(
it
.
permutations
(
axial_index
,
4
))
self
.
all_axes
.
extend
(
list
(
i
)
for
i
in
all_permutations
)
for
axes_x
in
all_permutations
:
for
axes_y
in
all_permutations
:
if
len
(
axes_x
)
<
len
(
axes_y
):
supplementary_axes_x
=
axes_x
+
axes_y
[
len
(
axes_x
):]
if
any
(
supplementary_axes_x
.
count
(
i
)
>
1
for
i
in
supplementary_axes_x
):
continue
elif
len
(
axes_y
)
<
len
(
axes_x
):
supplementary_axes_y
=
axes_y
+
axes_x
[
len
(
axes_y
):]
if
any
(
supplementary_axes_y
.
count
(
i
)
>
1
for
i
in
supplementary_axes_y
):
continue
self
.
all_axes
.
append
([
list
(
axes_x
),
list
(
axes_y
)])
self
.
all_axes
.
extend
(
range
(
5
))
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
places
.
append
(
core
.
CUDAPlace
(
0
))
for
axes
in
self
.
all_axes
:
self
.
axes
=
axes
for
place
in
places
:
self
.
run_dygraph
(
place
)
self
.
run_static
(
place
)
class
TestTensordotAPIFloat64
(
TestTensordotAPI
):
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float64
class
TestTensordotAPIAxesType
(
TestTensordotAPI
):
def
set_input_shape
(
self
):
self
.
x_shape
=
[
3
,
4
,
4
]
self
.
y_shape
=
[
4
,
4
,
5
]
def
test_cases
(
self
):
self
.
all_axes
=
[
0
,
1
,
2
,
(
1
,
),
[
1
],
((
1
,
),
),
([
1
],
),
((
2
,
1
),
(
0
,
)),
(
(
1
,
2
),
(
0
,
1
)),
([
1
,
2
],
[
0
,
1
]),
([
1
,
2
],
[
0
,
1
]),
[[
1
,
2
],
[
0
,
1
]]
]
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
places
.
append
(
core
.
CUDAPlace
(
0
))
for
axes
in
self
.
all_axes
:
self
.
axes
=
axes
for
place
in
places
:
self
.
run_dygraph
(
place
)
self
.
run_static
(
place
)
# The 'axes' with type 'Tensor' in tensordot is not available in static mode
paddle
.
disable_static
()
for
place
in
places
:
self
.
all_axes
=
[
paddle
.
to_tensor
([
1
]),
(
paddle
.
to_tensor
([
1
])),
(
paddle
.
to_tensor
([
1
,
2
]),
paddle
.
to_tensor
([
0
,
1
])),
[
paddle
.
to_tensor
([
1
,
2
]),
paddle
.
to_tensor
([
0
,
1
])],
paddle
.
to_tensor
([[
1
,
2
],
[
0
,
1
]])
]
for
axes
in
self
.
all_axes
:
self
.
axes
=
axes
for
place
in
places
:
self
.
run_dygraph
(
place
)
def
test_error
(
self
):
self
.
all_axes
=
[[[[
0
],
[
1
]]],
0.1
,
-
1
,
100
,
[[
1
,
2
],
[
0
,
0
]],
[[
1
,
2
],
[
0
,
-
1
]],
[
0
,
1
,
2
,
3
]]
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
self
.
x
)
y
=
paddle
.
to_tensor
(
self
.
y
)
for
axes
in
self
.
all_axes
:
with
self
.
assertRaises
(
BaseException
):
paddle
.
tensordot
(
x
,
y
,
axes
)
class
TestTensordotAPIAxesTypeFloat64
(
TestTensordotAPIAxesType
):
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float64
class
TestTensordotAPIBroadcastCase1
(
TestTensordotAPI
):
def
set_input_shape
(
self
):
self
.
x_shape
=
[
1
,
1
,
1
,
5
]
self
.
y_shape
=
[
1
,
5
,
1
,
1
]
class
TestTensordotAPIBroadcastCase2
(
TestTensordotAPI
):
def
set_input_shape
(
self
):
self
.
x_shape
=
[
1
,
5
,
5
,
5
]
self
.
y_shape
=
[
1
,
1
,
1
,
5
]
class
TestTensordotAPIBroadcastCase3
(
TestTensordotAPI
):
def
set_input_shape
(
self
):
self
.
x_shape
=
[
5
,
5
,
5
,
1
]
self
.
y_shape
=
[
5
,
5
,
1
,
5
]
class
TestTensordotAPIBroadcastCase4
(
TestTensordotAPI
):
def
set_input_shape
(
self
):
self
.
x_shape
=
[
5
,
5
,
5
,
1
]
self
.
y_shape
=
[
1
,
1
,
1
,
1
]
class
TestTensordotAPIBroadcastCase5
(
TestTensordotAPI
):
def
set_input_shape
(
self
):
self
.
x_shape
=
[
1
,
1
,
5
,
5
]
self
.
y_shape
=
[
5
,
5
,
1
,
5
]
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/tensor/__init__.py
浏览文件 @
2bfee7d3
...
@@ -105,6 +105,7 @@ from .manipulation import flip # noqa: F401
...
@@ -105,6 +105,7 @@ from .manipulation import flip # noqa: F401
from
.manipulation
import
unbind
# noqa: F401
from
.manipulation
import
unbind
# noqa: F401
from
.manipulation
import
roll
# noqa: F401
from
.manipulation
import
roll
# noqa: F401
from
.manipulation
import
chunk
# noqa: F401
from
.manipulation
import
chunk
# noqa: F401
from
.manipulation
import
tensordot
# noqa: F401
from
.math
import
abs
# noqa: F401
from
.math
import
abs
# noqa: F401
from
.math
import
acos
# noqa: F401
from
.math
import
acos
# noqa: F401
from
.math
import
asin
# noqa: F401
from
.math
import
asin
# noqa: F401
...
@@ -346,6 +347,7 @@ tensor_method_func = [ #noqa
...
@@ -346,6 +347,7 @@ tensor_method_func = [ #noqa
'slice'
,
'slice'
,
'split'
,
'split'
,
'chunk'
,
'chunk'
,
'tensordot'
,
'squeeze'
,
'squeeze'
,
'squeeze_'
,
'squeeze_'
,
'stack'
,
'stack'
,
...
...
python/paddle/tensor/manipulation.py
浏览文件 @
2bfee7d3
...
@@ -2173,3 +2173,211 @@ def strided_slice(x, axes, starts, ends, strides, name=None):
...
@@ -2173,3 +2173,211 @@ def strided_slice(x, axes, starts, ends, strides, name=None):
return
paddle
.
fluid
.
layers
.
strided_slice
(
return
paddle
.
fluid
.
layers
.
strided_slice
(
input
=
x
,
axes
=
axes
,
starts
=
starts
,
ends
=
ends
,
strides
=
strides
)
input
=
x
,
axes
=
axes
,
starts
=
starts
,
ends
=
ends
,
strides
=
strides
)
def
tensordot
(
x
,
y
,
axes
=
2
,
name
=
None
):
r
"""
This function computes a contraction, which sum the product of elements from two tensors along the given axes.
Args:
x (Tensor): The left tensor for contraction with data type ``float32`` or ``float64``.
y (Tensor): The right tensor for contraction with the same data type as ``x``.
axes (int|tuple|list|Tensor, optional): The axes to contract for ``x`` and ``y``, defaulted to integer ``2``.
1. It could be a non-negative integer ``n``,
in which the function will sum over the last ``n`` axes of ``x`` and the first ``n`` axes of ``y`` in order.
2. It could be a 1-d tuple or list with data type ``int``, in which ``x`` and ``y`` will be contracted along the same given axes.
For example, ``axes`` =[0, 1] applies contraction along the first two axes for ``x`` and the first two axes for ``y``.
3. It could be a tuple or list containing one or two 1-d tuple|list|Tensor with data type ``int``.
When containing one tuple|list|Tensor, the data in tuple|list|Tensor specified the same axes for ``x`` and ``y`` to contract.
When containing two tuple|list|Tensor, the first will be applied to ``x`` and the second to ``y``.
When containing more than two tuple|list|Tensor, only the first two axis sequences will be used while the others will be ignored.
4. It could be a tensor, in which the ``axes`` tensor will be translated to a python list
and applied the same rules described above to determine the contraction axes.
Note that the ``axes`` with Tensor type is ONLY available in Dygraph mode.
name(str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Return:
Output (Tensor): The contraction result with the same data type as ``x`` and ``y``.
In general, :math:`output.ndim = x.ndim + y.ndim - 2 \times n_{axes}`, where :math:`n_{axes}` denotes the number of axes to be contracted.
NOTES:
1. This function supports tensor broadcast,
the size in the corresponding dimensions of ``x`` and ``y`` should be equal, or applies to the broadcast rules.
2. This function also supports axes expansion,
when the two given axis sequences for ``x`` and ``y`` are of different lengths,
the shorter sequence will expand the same axes as the longer one at the end.
For example, if ``axes`` =[[0, 1, 2, 3], [1, 0]],
the axis sequence for ``x`` is [0, 1, 2, 3],
while the corresponding axis sequences for ``y`` will be expanded from [1, 0] to [1, 0, 2, 3].
Examples:
.. code-block:: python
import paddle
data_type = 'float64'
# For two 2-d tensor x and y, the case axes=0 is equivalent to outer product.
# Note that tensordot supports empty axis sequence, so all the axes=0, axes=[], axes=[[]], and axes=[[],[]] are equivalent cases.
x = paddle.arange(4, dtype=data_type).reshape([2, 2])
y = paddle.arange(4, dtype=data_type).reshape([2, 2])
z = paddle.tensordot(x, y, axes=0)
# z = [[[[0., 0.],
# [0., 0.]],
#
# [[0., 1.],
# [2., 3.]]],
#
#
# [[[0., 2.],
# [4., 6.]],
#
# [[0., 3.],
# [6., 9.]]]]
# For two 1-d tensor x and y, the case axes=1 is equivalent to inner product.
x = paddle.arange(10, dtype=data_type)
y = paddle.arange(10, dtype=data_type)
z1 = paddle.tensordot(x, y, axes=1)
z2 = paddle.dot(x, y)
# z1 = z2 = [285.]
# For two 2-d tensor x and y, the case axes=1 is equivalent to matrix multiplication.
x = paddle.arange(6, dtype=data_type).reshape([2, 3])
y = paddle.arange(12, dtype=data_type).reshape([3, 4])
z1 = paddle.tensordot(x, y, axes=1)
z2 = paddle.matmul(x, y)
# z1 = z2 = [[20., 23., 26., 29.],
# [56., 68., 80., 92.]]
# When axes is a 1-d int list, x and y will be contracted along the same given axes.
# Note that axes=[1, 2] is equivalent to axes=[[1, 2]], axes=[[1, 2], []], axes=[[1, 2], [1]], and axes=[[1, 2], [1, 2]].
x = paddle.arange(24, dtype=data_type).reshape([2, 3, 4])
y = paddle.arange(36, dtype=data_type).reshape([3, 3, 4])
z = paddle.tensordot(x, y, axes=[1, 2])
# z = [[506. , 1298., 2090.],
# [1298., 3818., 6338.]]
# When axes is a list containing two 1-d int list, the first will be applied to x and the second to y.
x = paddle.arange(60, dtype=data_type).reshape([3, 4, 5])
y = paddle.arange(24, dtype=data_type).reshape([4, 3, 2])
z = paddle.tensordot(x, y, axes=([1, 0], [0, 1]))
# z = [[4400., 4730.],
# [4532., 4874.],
# [4664., 5018.],
# [4796., 5162.],
# [4928., 5306.]]
# Thanks to the support of axes expansion, axes=[[0, 1, 3, 4], [1, 0, 3, 4]] can be abbreviated as axes= [[0, 1, 3, 4], [1, 0]].
x = paddle.arange(720, dtype=data_type).reshape([2, 3, 4, 5, 6])
y = paddle.arange(720, dtype=data_type).reshape([3, 2, 4, 5, 6])
z = paddle.tensordot(x, y, axes=[[0, 1, 3, 4], [1, 0]])
# z = [[23217330., 24915630., 26613930., 28312230.],
# [24915630., 26775930., 28636230., 30496530.],
# [26613930., 28636230., 30658530., 32680830.],
# [28312230., 30496530., 32680830., 34865130.]]
"""
op_type
=
'tensordot'
input_dtype
=
[
'float32'
,
'float64'
]
check_variable_and_dtype
(
x
,
'x'
,
input_dtype
,
op_type
)
check_variable_and_dtype
(
y
,
'y'
,
input_dtype
,
op_type
)
check_type
(
axes
,
'axes'
,
(
int
,
tuple
,
list
,
Variable
),
op_type
)
def
_var_to_list
(
var
):
if
in_dygraph_mode
():
return
tolist
(
var
)
raise
TypeError
(
"The 'axes' with type 'Tensor' in "
+
op_type
+
" is not available in static graph mode, "
"please convert its type to int|Tuple|List, or use dynamic graph mode."
)
axes_x
=
[]
axes_y
=
[]
if
np
.
issubdtype
(
type
(
axes
),
np
.
integer
):
assert
axes
>=
0
,
(
"The 'axes' in "
+
op_type
+
f
" should not be negative, but received axes=
{
axes
}
."
)
axes_x
=
range
(
x
.
ndim
-
axes
,
x
.
ndim
)
axes_y
=
range
(
axes
)
else
:
if
isinstance
(
axes
,
Variable
):
axes
=
_var_to_list
(
axes
)
if
not
axes
or
np
.
issubdtype
(
type
(
axes
[
0
]),
np
.
integer
):
axes_x
=
axes
else
:
axes_x
=
axes
[
0
]
if
len
(
axes
)
>
1
:
axes_y
=
axes
[
1
]
if
isinstance
(
axes_x
,
Variable
):
axes_x
=
_var_to_list
(
axes_x
)
if
isinstance
(
axes_y
,
Variable
):
axes_y
=
_var_to_list
(
axes_y
)
axes_x
,
axes_y
=
list
(
axes_x
),
list
(
axes_y
)
len_axes_x
,
len_axes_y
=
len
(
axes_x
),
len
(
axes_y
)
if
len_axes_x
<
len_axes_y
:
axes_x
.
extend
(
axes_y
[
len_axes_x
:])
elif
len_axes_y
<
len_axes_x
:
axes_y
.
extend
(
axes_x
[
len_axes_y
:])
shape_x
,
shape_y
=
list
(
x
.
shape
),
list
(
y
.
shape
)
need_contracted_dim_x
=
np
.
zeros
((
x
.
ndim
),
dtype
=
bool
)
need_contracted_dim_y
=
np
.
zeros
((
y
.
ndim
),
dtype
=
bool
)
contraction_size
=
1
for
i
in
range
(
len
(
axes_x
)):
dim_x
,
dim_y
=
axes_x
[
i
],
axes_y
[
i
]
sx
,
sy
=
shape_x
[
dim_x
],
shape_y
[
dim_y
]
if
sx
==
1
:
shape_y
[
dim_y
]
=
1
y
=
y
.
sum
(
dim_y
).
reshape
(
shape_y
)
elif
sy
==
1
:
shape_x
[
dim_x
]
=
1
x
=
x
.
sum
(
dim_x
).
reshape
(
shape_x
)
else
:
assert
sx
==
sy
,
"The dimensional size for 'x' and 'y' in "
+
op_type
+
f
" should match each other, but 'x' has size
{
sx
}
in dim
{
dim_x
}
while 'y' has size
{
sy
}
in dim
{
dim_y
}
."
need_contracted_dim_x
[
dim_x
]
=
True
need_contracted_dim_y
[
dim_y
]
=
True
contraction_size
*=
shape_x
[
dim_x
]
perm_x
=
[]
perm_y
=
[]
shape_out
=
[]
not_contraction_size_x
=
1
not_contraction_size_y
=
1
for
i
in
range
(
x
.
ndim
):
if
not
need_contracted_dim_x
[
i
]:
perm_x
.
append
(
i
)
shape_out
.
append
(
shape_x
[
i
])
not_contraction_size_x
*=
shape_x
[
i
]
perm_x
.
extend
(
axes_x
)
perm_y
.
extend
(
axes_y
)
for
i
in
range
(
y
.
ndim
):
if
not
need_contracted_dim_y
[
i
]:
perm_y
.
append
(
i
)
shape_out
.
append
(
shape_y
[
i
])
not_contraction_size_y
*=
shape_y
[
i
]
if
not
shape_out
:
shape_out
=
[
1
]
x
=
x
.
transpose
(
perm
=
perm_x
).
reshape
(
[
not_contraction_size_x
,
contraction_size
])
y
=
y
.
transpose
(
perm
=
perm_y
).
reshape
(
[
contraction_size
,
not_contraction_size_y
])
out
=
x
.
matmul
(
y
).
reshape
(
shape_out
)
return
out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录