未验证 提交 446d184e 编写于 作者: Z zhulei 提交者: GitHub

Add new api: is_tensor (#28111)

* Add new api: is_tensor

* Add new api: is_tensor

* Add new api: is_tensor

* Add new api: is_tensor
上级 cd372447
......@@ -108,6 +108,7 @@ from .tensor.logic import not_equal #DEFINE_ALIAS
from .tensor.logic import allclose #DEFINE_ALIAS
from .tensor.logic import equal_all #DEFINE_ALIAS
# from .tensor.logic import isnan #DEFINE_ALIAS
from .tensor.logic import is_tensor #DEFINE_ALIAS
from .tensor.manipulation import cast #DEFINE_ALIAS
from .tensor.manipulation import concat #DEFINE_ALIAS
from .tensor.manipulation import expand #DEFINE_ALIAS
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
DELTA = 0.00001
class TestIsTensorApi(unittest.TestCase):
def test_is_tensor_real(self, dtype="float32"):
"""Test is_tensor api with a real tensor
"""
paddle.disable_static()
x = paddle.rand([3, 2, 4], dtype=dtype)
self.assertTrue(paddle.is_tensor(x))
def test_is_tensor_complex(self, dtype="float32"):
"""Test is_tensor api with a complex tensor
"""
paddle.disable_static()
r = paddle.to_tensor(1)
i = paddle.to_tensor(2)
x = paddle.ComplexTensor(r, i)
self.assertTrue(paddle.is_tensor(x))
def test_is_tensor_list(self, dtype="float32"):
"""Test is_tensor api with a list
"""
paddle.disable_static()
x = [1, 2, 3]
self.assertFalse(paddle.is_tensor(x))
def test_is_tensor_number(self, dtype="float32"):
"""Test is_tensor api with a number
"""
paddle.disable_static()
x = 5
self.assertFalse(paddle.is_tensor(x))
if __name__ == '__main__':
unittest.main()
......@@ -71,6 +71,7 @@ from .logic import not_equal #DEFINE_ALIAS
from .logic import allclose #DEFINE_ALIAS
from .logic import equal_all #DEFINE_ALIAS
# from .logic import isnan #DEFINE_ALIAS
from .logic import is_tensor #DEFINE_ALIAS
from .manipulation import cast #DEFINE_ALIAS
from .manipulation import concat #DEFINE_ALIAS
from .manipulation import expand #DEFINE_ALIAS
......
......@@ -19,6 +19,8 @@ from ..fluid.layers.layer_function_generator import templatedoc
from .. import fluid
from ..fluid.framework import in_dygraph_mode
from paddle.common_ops_import import *
from ..framework import VarBase as Tensor
from ..framework import ComplexVariable as ComplexTensor
# TODO: define logic functions of a tensor
from ..fluid.layers import is_empty #DEFINE_ALIAS
......@@ -43,6 +45,7 @@ __all__ = [
'logical_xor',
'not_equal',
'allclose',
'is_tensor'
# 'isnan'
]
......@@ -372,3 +375,35 @@ def not_equal(x, y, name=None):
"""
out = fluid.layers.not_equal(x, y, name=name, cond=None)
return out
def is_tensor(x):
"""
This function tests whether input object is a paddle.Tensor or a paddle.ComplexTensor.
Args:
x (object): Object to test.
Returns:
A boolean value. True if 'x' is a paddle.Tensor or a paddle.ComplexTensor, otherwise False.
Examples:
.. code-block:: python
import paddle
input1 = paddle.rand(shape=[2, 3, 5], dtype='float32')
check = paddle.is_tensor(input1)
print(check) #True
input2 = paddle.ComplexTensor(input1, input1)
check = paddle.is_tensor(input2)
print(check) #True
input3 = [1, 4]
check = paddle.is_tensor(input3)
print(check) #False
"""
return isinstance(x, Tensor) or isinstance(x, ComplexTensor)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册