未验证 提交 a4644c50 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[BUG] fix bug of float/int/long/index Tensor (#55568)

上级 ee506c2f
......@@ -114,21 +114,27 @@ def monkey_patch_math_tensor():
), "only one element variable can be converted to float."
tensor = var.value().get_tensor()
assert tensor._is_initialized(), "variable's tensor is not initialized"
return float(np.array(var).flatten()[0])
if var.dtype == core.VarDesc.VarType.BF16:
var = var.astype('float32')
return float(np.array(var))
def _long_(var):
numel = np.prod(var.shape)
assert numel == 1, "only one element variable can be converted to long."
tensor = var.value().get_tensor()
assert tensor._is_initialized(), "variable's tensor is not initialized"
return int(np.array(var).flatten()[0])
if var.dtype == core.VarDesc.VarType.BF16:
var = var.astype('float32')
return int(np.array(var))
def _int_(var):
numel = np.prod(var.shape)
assert numel == 1, "only one element variable can be converted to int."
tensor = var.value().get_tensor()
assert tensor._is_initialized(), "variable's tensor is not initialized"
return int(np.array(var).flatten()[0])
if var.dtype == core.VarDesc.VarType.BF16:
var = var.astype('float32')
return int(np.array(var))
def _len_(var):
assert var.ndim > 0, "len() of a 0-D tensor is wrong"
......@@ -146,7 +152,9 @@ def monkey_patch_math_tensor():
), "only one element variable can be converted to python index."
tensor = var.value().get_tensor()
assert tensor._is_initialized(), "variable's tensor is not initialized"
return int(np.array(var).flatten()[0])
if var.dtype == core.VarDesc.VarType.BF16:
var = var.astype('float32')
return int(np.array(var))
@property
def _ndim_(var):
......
......@@ -242,6 +242,11 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
self.assertTrue(int(a) == 100)
self.assertTrue(int(a) == 100)
a = paddle.to_tensor(1000000.0, dtype='bfloat16')
self.assertTrue(float(a) == 999424.0)
self.assertTrue(int(a) == 999424)
self.assertTrue(int(a) == 999424)
def test_len(self):
a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
with fluid.dygraph.guard():
......@@ -260,6 +265,16 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
str1 = "just test"
self.assertTrue(str1[var1] == 's')
var1 = paddle.to_tensor(2.0, dtype='bfloat16')
i_tmp = 0
for i in range(var1):
self.assertTrue(i == i_tmp)
i_tmp = i_tmp + 1
list1 = [1, 2, 3, 4, 5]
self.assertTrue(list1[var1] == 3)
str1 = "just test"
self.assertTrue(str1[var1] == 's')
def test_np_left_mul(self):
with fluid.dygraph.guard():
t = np.sqrt(2.0 * np.pi)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册