diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 21827166d188275fc5d189adb02252e92734acd3..3640dd22bb0cd7a45bc129471d0ecd28823197dd 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -108,6 +108,7 @@ from .tensor.logic import not_equal #DEFINE_ALIAS from .tensor.logic import allclose #DEFINE_ALIAS from .tensor.logic import equal_all #DEFINE_ALIAS # from .tensor.logic import isnan #DEFINE_ALIAS +from .tensor.logic import is_tensor #DEFINE_ALIAS from .tensor.manipulation import cast #DEFINE_ALIAS from .tensor.manipulation import concat #DEFINE_ALIAS from .tensor.manipulation import expand #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_is_tensor.py b/python/paddle/fluid/tests/unittests/test_is_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..97d6c60d631d3d4a0da343d10f8d05dea625aaf5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_is_tensor.py @@ -0,0 +1,56 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle + +DELTA = 0.00001 + + +class TestIsTensorApi(unittest.TestCase): + def test_is_tensor_real(self, dtype="float32"): + """Test is_tensor api with a real tensor + """ + paddle.disable_static() + x = paddle.rand([3, 2, 4], dtype=dtype) + self.assertTrue(paddle.is_tensor(x)) + + def test_is_tensor_complex(self, dtype="float32"): + """Test is_tensor api with a complex tensor + """ + paddle.disable_static() + r = paddle.to_tensor(1) + i = paddle.to_tensor(2) + x = paddle.ComplexTensor(r, i) + self.assertTrue(paddle.is_tensor(x)) + + def test_is_tensor_list(self, dtype="float32"): + """Test is_tensor api with a list + """ + paddle.disable_static() + x = [1, 2, 3] + self.assertFalse(paddle.is_tensor(x)) + + def test_is_tensor_number(self, dtype="float32"): + """Test is_tensor api with a number + """ + paddle.disable_static() + x = 5 + self.assertFalse(paddle.is_tensor(x)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 773e6ebc7af2ecd5ee453068b30bdc6de0bf6967..958bfb304fb149c36454163ece08f365b1021981 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -71,6 +71,7 @@ from .logic import not_equal #DEFINE_ALIAS from .logic import allclose #DEFINE_ALIAS from .logic import equal_all #DEFINE_ALIAS # from .logic import isnan #DEFINE_ALIAS +from .logic import is_tensor #DEFINE_ALIAS from .manipulation import cast #DEFINE_ALIAS from .manipulation import concat #DEFINE_ALIAS from .manipulation import expand #DEFINE_ALIAS diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 1fc1c17d2edb269bdad6822cbe124ac176a468fd..27671a4f157475e04250ba56fd378103f7d0d829 100644 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -19,6 +19,8 @@ from ..fluid.layers.layer_function_generator import templatedoc from .. import fluid from ..fluid.framework import in_dygraph_mode from paddle.common_ops_import import * +from ..framework import VarBase as Tensor +from ..framework import ComplexVariable as ComplexTensor # TODO: define logic functions of a tensor from ..fluid.layers import is_empty #DEFINE_ALIAS @@ -43,6 +45,7 @@ __all__ = [ 'logical_xor', 'not_equal', 'allclose', + 'is_tensor' # 'isnan' ] @@ -372,3 +375,35 @@ def not_equal(x, y, name=None): """ out = fluid.layers.not_equal(x, y, name=name, cond=None) return out + + +def is_tensor(x): + """ + + This function tests whether input object is a paddle.Tensor or a paddle.ComplexTensor. + + Args: + x (object): Object to test. + + Returns: + A boolean value. True if 'x' is a paddle.Tensor or a paddle.ComplexTensor, otherwise False. + + Examples: + .. code-block:: python + + import paddle + + input1 = paddle.rand(shape=[2, 3, 5], dtype='float32') + check = paddle.is_tensor(input1) + print(check) #True + + input2 = paddle.ComplexTensor(input1, input1) + check = paddle.is_tensor(input2) + print(check) #True + + input3 = [1, 4] + check = paddle.is_tensor(input3) + print(check) #False + + """ + return isinstance(x, Tensor) or isinstance(x, ComplexTensor)