未验证 提交 fe1f9c97 编写于 作者: H HongyuJia 提交者: GitHub

clean unittest test_model_cast_to_bf16 (#48705)

上级 58f08924
...@@ -97,26 +97,26 @@ class TestModelCastBF16(unittest.TestCase): ...@@ -97,26 +97,26 @@ class TestModelCastBF16(unittest.TestCase):
with self.static_graph(): with self.static_graph():
t_bf16 = layers.data( t_bf16 = layers.data(
name='t_bf16', shape=[size, size], dtype=np.uint16 name='t_bf16', shape=[size, size], dtype=np.int32
) )
tt_bf16 = layers.data( tt_bf16 = layers.data(
name='tt_bf16', shape=[size, size], dtype=np.uint16 name='tt_bf16', shape=[size, size], dtype=np.int32
) )
t = layers.data(name='t', shape=[size, size], dtype='float32') t = layers.data(name='t', shape=[size, size], dtype='float32')
tt = layers.data(name='tt', shape=[size, size], dtype='float32') tt = layers.data(name='tt', shape=[size, size], dtype='float32')
ret = layers.elementwise_add(t, tt) ret = paddle.add(t, tt)
ret = layers.elementwise_mul(ret, t) ret = paddle.multiply(ret, t)
ret = paddle.reshape(ret, [0, 0]) ret = paddle.reshape(ret, [0, 0])
with amp.bf16.bf16_guard(): with amp.bf16.bf16_guard():
ret_bf16 = layers.elementwise_add(t_bf16, tt_bf16) ret_bf16 = paddle.add(t_bf16, tt_bf16)
ret_bf16 = layers.elementwise_mul(ret_bf16, t_bf16) ret_bf16 = paddle.multiply(ret_bf16, t_bf16)
ret_bf16 = paddle.reshape(ret_bf16, [0, 0]) ret_bf16 = paddle.reshape(ret_bf16, [0, 0])
with amp.bf16.bf16_guard(): with amp.bf16.bf16_guard():
ret_fp32bf16 = layers.elementwise_add(t, tt) ret_fp32bf16 = paddle.add(t, tt)
ret_fp32bf16 = layers.elementwise_mul(ret_fp32bf16, t) ret_fp32bf16 = paddle.multiply(ret_fp32bf16, t)
ret_fp32bf16 = paddle.reshape(ret_fp32bf16, [0, 0]) ret_fp32bf16 = paddle.reshape(ret_fp32bf16, [0, 0])
( (
...@@ -147,11 +147,11 @@ class TestModelCastBF16(unittest.TestCase): ...@@ -147,11 +147,11 @@ class TestModelCastBF16(unittest.TestCase):
tt = layers.data(name='tt', shape=[size, size], dtype='float32') tt = layers.data(name='tt', shape=[size, size], dtype='float32')
with amp.bf16.bf16_guard(): with amp.bf16.bf16_guard():
ret = layers.elementwise_add(t, tt) ret = paddle.add(t, tt)
ret = paddle.reshape(ret, [0, 0]) ret = paddle.reshape(ret, [0, 0])
ret = paddle.nn.functional.elu(ret) ret = paddle.nn.functional.elu(ret)
ret = layers.elementwise_mul(ret, t) ret = paddle.multiply(ret, t)
ret = layers.elementwise_add(ret, tt) ret = paddle.add(ret, tt)
static_ret_bf16 = self.get_static_graph_result( static_ret_bf16 = self.get_static_graph_result(
feed={'t': n, 'tt': nn}, feed={'t': n, 'tt': nn},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册