未验证 提交 8c9c81cc 编写于 作者: F Feiyu Chan 提交者: GitHub

add doc for is_complex and is_integer and expose them as public APIs (#38158)

上级 a615002a
......@@ -66,6 +66,8 @@ import paddle.vision # noqa: F401
from .tensor.random import bernoulli # noqa: F401
from .tensor.attribute import is_complex # noqa: F401
from .tensor.attribute import is_integer # noqa: F401
from .tensor.attribute import rank # noqa: F401
from .tensor.attribute import shape # noqa: F401
from .tensor.attribute import real # noqa: F401
......@@ -379,6 +381,8 @@ __all__ = [ # noqa
'equal',
'equal_all',
'is_tensor',
'is_complex',
'is_integer',
'cross',
'where',
'log1p',
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
class TestIsComplex(unittest.TestCase):
def test_for_integer(self):
x = paddle.arange(10)
self.assertFalse(paddle.is_complex(x))
def test_for_floating_point(self):
x = paddle.randn([2, 3])
self.assertFalse(paddle.is_complex(x))
def test_for_complex(self):
x = paddle.randn([2, 3]) + 1j * paddle.randn([2, 3])
self.assertTrue(paddle.is_complex(x))
def test_for_exception(self):
with self.assertRaises(TypeError):
paddle.is_complex(np.array([1, 2]))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import unittest
class TestIsInteger(unittest.TestCase):
def test_for_integer(self):
x = paddle.arange(10)
self.assertTrue(paddle.is_integer(x))
def test_for_floating_point(self):
x = paddle.randn([2, 3])
self.assertFalse(paddle.is_integer(x))
def test_for_complex(self):
x = paddle.randn([2, 3]) + 1j * paddle.randn([2, 3])
self.assertFalse(paddle.is_integer(x))
def test_for_exception(self):
with self.assertRaises(TypeError):
paddle.is_integer(np.array([1, 2]))
if __name__ == '__main__':
unittest.main()
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .attribute import is_complex # noqa: F401
from .attribute import is_integer # noqa: F401
from .attribute import rank # noqa: F401
from .attribute import shape # noqa: F401
from .attribute import real # noqa: F401
......@@ -408,6 +410,8 @@ tensor_method_func = [ #noqa
'var',
'numel',
'median',
'is_complex',
'is_integer',
'rank',
'shape',
'real',
......
......@@ -21,6 +21,7 @@ from ..fluid.data_feeder import check_variable_and_dtype
# TODO: define functions to get tensor attributes
from ..fluid.layers import rank # noqa: F401
from ..fluid.layers import shape # noqa: F401
import paddle
from paddle import _C_ops
__all__ = []
......@@ -45,6 +46,34 @@ def _real_to_complex_dtype(dtype):
def is_complex(x):
"""Return whether x is a tensor of complex data type(complex64 or complex128).
Args:
x (Tensor): The input tensor.
Returns:
bool: True if the data type of the input is complex data type, otherwise false.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1 + 2j, 3 + 4j])
print(paddle.is_complex(x))
# True
x = paddle.to_tensor([1.1, 1.2])
print(paddle.is_complex(x))
# False
x = paddle.to_tensor([1, 2, 3])
print(paddle.is_complex(x))
# False
"""
if not isinstance(x, (paddle.Tensor, paddle.static.Variable)):
raise TypeError("Expected Tensor, but received type of x: {}".format(
type(x)))
dtype = x.dtype
is_complex_dtype = (dtype == core.VarDesc.VarType.COMPLEX64 or
dtype == core.VarDesc.VarType.COMPLEX128)
......@@ -61,6 +90,34 @@ def is_floating_point(x):
def is_integer(x):
"""Return whether x is a tensor of integeral data type.
Args:
x (Tensor): The input tensor.
Returns:
bool: True if the data type of the input is integer data type, otherwise false.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1 + 2j, 3 + 4j])
print(paddle.is_integer(x))
# False
x = paddle.to_tensor([1.1, 1.2])
print(paddle.is_integer(x))
# False
x = paddle.to_tensor([1, 2, 3])
print(paddle.is_integer(x))
# True
"""
if not isinstance(x, (paddle.Tensor, paddle.static.Variable)):
raise TypeError("Expected Tensor, but received type of x: {}".format(
type(x)))
dtype = x.dtype
is_int_dtype = (dtype == core.VarDesc.VarType.UINT8 or
dtype == core.VarDesc.VarType.INT8 or
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册