未验证 提交 632a0064 编写于 作者: Z zhangkaihuo 提交者: GitHub

[cherry-pick] update multi_dot exposure rules (#36018) (#36131)

根据线性代数库的API暴露规则修改multi_dot的API暴露规则:
1、在python/paddle/tensor/linalg.py 路径下实现
2、在python/paddle/linalg.py 下import并加入__all__列表
3、在python/paddle/tensor/init.py下引入并加入tensor_method_func列表
4、删除了pythonpaddle/init.py的import
上级 c576169b
...@@ -387,6 +387,7 @@ tensor_method_func = [ #noqa ...@@ -387,6 +387,7 @@ tensor_method_func = [ #noqa
'bitwise_not', 'bitwise_not',
'broadcast_tensors', 'broadcast_tensors',
'uniform_', 'uniform_',
'multi_dot',
'solve', 'solve',
] ]
......
...@@ -975,8 +975,8 @@ def t(input, name=None): ...@@ -975,8 +975,8 @@ def t(input, name=None):
return out return out
check_variable_and_dtype( check_variable_and_dtype(
input, 'input', ['float16', 'float32', 'float64', 'int32', 'int64'], input, 'input', ['float16', 'float32', 'float64', 'int32',
'transpose') 'int64'], 'transpose')
helper = LayerHelper('t', **locals()) helper = LayerHelper('t', **locals())
out = helper.create_variable_for_type_inference(input.dtype) out = helper.create_variable_for_type_inference(input.dtype)
...@@ -1360,7 +1360,7 @@ def det(x, name=None): ...@@ -1360,7 +1360,7 @@ def det(x, name=None):
Returns: Returns:
y (Tensor):the determinant value of a square matrix or batches of square matrices. y (Tensor):the determinant value of a square matrix or batches of square matrices.
Example: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
...@@ -1415,7 +1415,7 @@ def slogdet(x, name=None): ...@@ -1415,7 +1415,7 @@ def slogdet(x, name=None):
y (Tensor): A tensor containing the sign of the determinant and the natural logarithm y (Tensor): A tensor containing the sign of the determinant and the natural logarithm
of the absolute value of determinant, respectively. of the absolute value of determinant, respectively.
Example: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
...@@ -1630,8 +1630,8 @@ def eigvals(x, name=None): ...@@ -1630,8 +1630,8 @@ def eigvals(x, name=None):
""" """
check_variable_and_dtype(x, 'dtype', check_variable_and_dtype(x, 'dtype',
['float32', 'float64', 'complex64', 'complex128'], ['float32', 'float64', 'complex64',
'eigvals') 'complex128'], 'eigvals')
x_shape = list(x.shape) x_shape = list(x.shape)
if len(x_shape) < 2: if len(x_shape) < 2:
...@@ -1657,7 +1657,7 @@ def multi_dot(x, name=None): ...@@ -1657,7 +1657,7 @@ def multi_dot(x, name=None):
""" """
Multi_dot is an operator that calculates multiple matrix multiplications. Multi_dot is an operator that calculates multiple matrix multiplications.
Supports inputs of float, double and float16 dtypes. This function does not Supports inputs of float16(only GPU support), float32 and float64 dtypes. This function does not
support batched inputs. support batched inputs.
The input tensor in [x] must be 2-D except for the first and last can be 1-D. The input tensor in [x] must be 2-D except for the first and last can be 1-D.
...@@ -1998,8 +1998,8 @@ def pinv(x, rcond=1e-15, hermitian=False, name=None): ...@@ -1998,8 +1998,8 @@ def pinv(x, rcond=1e-15, hermitian=False, name=None):
helper = LayerHelper('pinv', **locals()) helper = LayerHelper('pinv', **locals())
dtype = x.dtype dtype = x.dtype
check_variable_and_dtype( check_variable_and_dtype(
x, 'dtype', ['float32', 'float64', 'complex64', 'complex128'], x, 'dtype', ['float32', 'float64', 'complex64',
'pinv') 'complex128'], 'pinv')
if dtype == paddle.complex128: if dtype == paddle.complex128:
s_type = 'float64' s_type = 'float64'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册