未验证 提交 c72e2a15 编写于 作者: A Aurelius84 提交者: GitHub

[API]Support is_tensor() static branch (#50520)

上级 b8008580
......@@ -20,24 +20,43 @@ DELTA = 0.00001
class TestIsTensorApi(unittest.TestCase):
def setUp(self):
paddle.disable_static()
def tearDown(self):
paddle.enable_static()
def test_is_tensor_real(self, dtype="float32"):
"""Test is_tensor api with a real tensor"""
paddle.disable_static()
x = paddle.rand([3, 2, 4], dtype=dtype)
self.assertTrue(paddle.is_tensor(x))
def test_is_tensor_list(self, dtype="float32"):
"""Test is_tensor api with a list"""
paddle.disable_static()
x = [1, 2, 3]
self.assertFalse(paddle.is_tensor(x))
def test_is_tensor_number(self, dtype="float32"):
"""Test is_tensor api with a number"""
paddle.disable_static()
x = 5
self.assertFalse(paddle.is_tensor(x))
class TestIsTensorStatic(unittest.TestCase):
def setUp(self):
paddle.enable_static()
def tearDown(self):
paddle.disable_static()
def test_is_tensor(self):
x = paddle.rand([3, 2, 4], dtype='float32')
self.assertTrue(paddle.is_tensor(x))
def test_is_tensor_array(self):
x = paddle.tensor.create_array('float32')
self.assertTrue(paddle.is_tensor(x))
if __name__ == '__main__':
unittest.main()
......@@ -779,7 +779,10 @@ def is_tensor(x):
print(check) #False
"""
return isinstance(x, (Tensor, paddle.fluid.core.eager.Tensor))
if in_dygraph_mode():
return isinstance(x, (Tensor, paddle.fluid.core.eager.Tensor))
else:
return isinstance(x, Variable)
def _bitwise_op(op_name, x, y, out=None, name=None, binary_op=True):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册