Unverified commit 780c7a1d authored by Weilong Wu, committed by GitHub

[Eager] Support test_var_base bf16 case (#41377)

* [Eager]Polish enable/disable_legacy_dygraph logic

* fix test_var_base print_tensor

* fix bug caused by arange

* Updated bf16 cast case

* BF16 astype to float32
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: pangyoki <pangyoki@126.com>
Co-authored-by: zyfncg <zhangyunfei07@baidu.com>
Parent 08811d9b
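The substance of the change is in the tensor string-formatting path: numpy has no native bfloat16 dtype, so a BF16 tensor cannot be handed to `tensor.numpy()` for printing as-is. The fix widens it to float32 first, which is lossless because bfloat16 is a truncated float32. A minimal sketch of the idea (the `format_for_print` helper is hypothetical; the real change lands in `_format_dense_tensor` further down):

```python
import paddle
from paddle.fluid import core

def format_for_print(tensor):
    # Hypothetical helper mirroring the diff below: widen BF16 to float32
    # before converting to numpy, since numpy cannot represent bfloat16.
    # BF16 -> FP32 widening is exact, so the printed values are unchanged.
    if tensor.dtype == core.VarDesc.VarType.BF16:
        tensor = tensor.astype('float32')
    return tensor.numpy()
```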
@@ -996,7 +996,7 @@ class TestVarBase(unittest.TestCase):
         self.assertListEqual(list(var_base.shape), list(static_var.shape))
 
-    def test_tensor_str(self):
+    def func_test_tensor_str(self):
         paddle.enable_static()
         paddle.disable_static(paddle.CPUPlace())
         paddle.seed(10)
@@ -1016,7 +1016,12 @@ class TestVarBase(unittest.TestCase):
         self.assertEqual(a_str, expected)
         paddle.enable_static()
 
-    def test_tensor_str2(self):
+    def test_tensor_str(self):
+        with _test_eager_guard():
+            self.func_test_tensor_str()
+        self.func_test_tensor_str()
+
+    def func_test_tensor_str2(self):
         paddle.disable_static(paddle.CPUPlace())
         a = paddle.to_tensor([[1.5111111, 1.0], [0, 0]])
         a_str = str(a)
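Every `test_tensor_str*` case below receives the same mechanical refactor: the body moves into a `func_*` method, and the public `test_*` method runs it twice, once inside `_test_eager_guard()` (eager mode) and once outside it (legacy dygraph). A sketch of the pattern, assuming `_test_eager_guard` is imported from `paddle.fluid.framework` as elsewhere in this test file:

```python
import unittest
from paddle.fluid.framework import _test_eager_guard

class TestVarBase(unittest.TestCase):
    def func_test_tensor_str(self):
        ...  # original test body, unchanged

    def test_tensor_str(self):
        with _test_eager_guard():       # exercise the eager code path
            self.func_test_tensor_str()
        self.func_test_tensor_str()     # then the legacy dygraph path
```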
@@ -1028,7 +1033,12 @@ class TestVarBase(unittest.TestCase):
         self.assertEqual(a_str, expected)
         paddle.enable_static()
 
-    def test_tensor_str3(self):
+    def test_tensor_str2(self):
+        with _test_eager_guard():
+            self.func_test_tensor_str2()
+        self.func_test_tensor_str2()
+
+    def func_test_tensor_str3(self):
         paddle.disable_static(paddle.CPUPlace())
         a = paddle.to_tensor([[-1.5111111, 1.0], [0, -0.5]])
         a_str = str(a)
@@ -1040,7 +1050,12 @@ class TestVarBase(unittest.TestCase):
         self.assertEqual(a_str, expected)
         paddle.enable_static()
 
-    def test_tensor_str_scaler(self):
+    def test_tensor_str3(self):
+        with _test_eager_guard():
+            self.func_test_tensor_str3()
+        self.func_test_tensor_str3()
+
+    def func_test_tensor_str_scaler(self):
         paddle.disable_static(paddle.CPUPlace())
         a = paddle.to_tensor(np.array(False))
         a_str = str(a)
@@ -1051,7 +1066,12 @@ class TestVarBase(unittest.TestCase):
         self.assertEqual(a_str, expected)
         paddle.enable_static()
 
-    def test_tensor_str_shape_with_zero(self):
+    def test_tensor_str_scaler(self):
+        with _test_eager_guard():
+            self.func_test_tensor_str_scaler()
+        self.func_test_tensor_str_scaler()
+
+    def func_test_tensor_str_shape_with_zero(self):
         paddle.disable_static(paddle.CPUPlace())
         x = paddle.ones((10, 10))
         y = paddle.fluid.layers.where(x == 0)
@@ -1063,7 +1083,12 @@ class TestVarBase(unittest.TestCase):
         self.assertEqual(a_str, expected)
         paddle.enable_static()
 
-    def test_tensor_str_linewidth(self):
+    def test_tensor_str_shape_with_zero(self):
+        with _test_eager_guard():
+            self.func_test_tensor_str_shape_with_zero()
+        self.func_test_tensor_str_shape_with_zero()
+
+    def func_test_tensor_str_linewidth(self):
         paddle.disable_static(paddle.CPUPlace())
         paddle.seed(2021)
         x = paddle.rand([128])
@@ -1091,7 +1116,12 @@ class TestVarBase(unittest.TestCase):
         self.assertEqual(a_str, expected)
         paddle.enable_static()
 
-    def test_tensor_str_linewidth2(self):
+    def test_tensor_str_linewidth(self):
+        with _test_eager_guard():
+            self.func_test_tensor_str_linewidth()
+        self.func_test_tensor_str_linewidth()
+
+    def func_test_tensor_str_linewidth2(self):
         paddle.disable_static(paddle.CPUPlace())
         paddle.seed(2021)
         x = paddle.rand([128])
@@ -1114,7 +1144,12 @@ class TestVarBase(unittest.TestCase):
         self.assertEqual(a_str, expected)
         paddle.enable_static()
 
-    def test_tensor_str_bf16(self):
+    def test_tensor_str_linewidth2(self):
+        with _test_eager_guard():
+            self.func_test_tensor_str_linewidth2()
+        self.func_test_tensor_str_linewidth2()
+
+    def func_tensor_str_bf16(self):
         paddle.disable_static(paddle.CPUPlace())
         a = paddle.to_tensor([[1.5, 1.0], [0, 0]])
         a = paddle.cast(a, dtype=core.VarDesc.VarType.BF16)
@@ -1128,6 +1163,11 @@ class TestVarBase(unittest.TestCase):
         self.assertEqual(a_str, expected)
         paddle.enable_static()
 
+    def test_tensor_str_bf16(self):
+        with _test_eager_guard():
+            self.func_tensor_str_bf16()
+        self.func_tensor_str_bf16()
+
     def test_print_tensor_dtype(self):
         paddle.disable_static(paddle.CPUPlace())
         a = paddle.rand([1])
@@ -264,6 +264,9 @@ def to_string(var, prefix='Tensor'):
 
 
 def _format_dense_tensor(tensor, indent):
+    if tensor.dtype == core.VarDesc.VarType.BF16:
+        tensor = tensor.astype('float32')
+
     np_tensor = tensor.numpy()
 
     if len(tensor.shape) == 0:
@@ -330,6 +333,10 @@ def sparse_tensor_to_string(tensor, prefix='Tensor'):
 def tensor_to_string(tensor, prefix='Tensor'):
     indent = len(prefix) + 1
 
+    dtype = convert_dtype(tensor.dtype)
+    if tensor.dtype == core.VarDesc.VarType.BF16:
+        dtype = 'bfloat16'
+
     _template = "{prefix}(shape={shape}, dtype={dtype}, place={place}, stop_gradient={stop_gradient},\n{indent}{data})"
 
     if tensor.is_sparse():
@@ -342,7 +349,7 @@ def tensor_to_string(tensor, prefix='Tensor'):
     return _template.format(
         prefix=prefix,
         shape=tensor.shape,
-        dtype=tensor.dtype,
+        dtype=dtype,
         place=tensor._place_str,
         stop_gradient=tensor.stop_gradient,
         indent=' ' * indent,
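With both pieces in place, a BF16 tensor prints with its logical dtype in the header. A usage sketch based on the new test case (the exact element formatting in the comment is illustrative, not byte-exact):

```python
import paddle
from paddle.fluid import core

a = paddle.to_tensor([[1.5, 1.0], [0, 0]])
a = paddle.cast(a, dtype=core.VarDesc.VarType.BF16)
print(a)
# The header now reports dtype=bfloat16 rather than the raw storage type, e.g.:
#   Tensor(shape=[2, 2], dtype=bfloat16, place=Place(cpu), stop_gradient=True, ...)
# The element values are rendered through the float32 view added above.
```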