diff --git a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py index e04bf0b2b7058f70d63c8290692fe66a3c3384be..7254dd9df31821c2c8653676f22c98030943aa2b 100644 --- a/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py +++ b/python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py @@ -97,26 +97,26 @@ class TestModelCastBF16(unittest.TestCase): with self.static_graph(): t_bf16 = layers.data( - name='t_bf16', shape=[size, size], dtype=np.uint16 + name='t_bf16', shape=[size, size], dtype=np.int32 ) tt_bf16 = layers.data( - name='tt_bf16', shape=[size, size], dtype=np.uint16 + name='tt_bf16', shape=[size, size], dtype=np.int32 ) t = layers.data(name='t', shape=[size, size], dtype='float32') tt = layers.data(name='tt', shape=[size, size], dtype='float32') - ret = layers.elementwise_add(t, tt) - ret = layers.elementwise_mul(ret, t) + ret = paddle.add(t, tt) + ret = paddle.multiply(ret, t) ret = paddle.reshape(ret, [0, 0]) with amp.bf16.bf16_guard(): - ret_bf16 = layers.elementwise_add(t_bf16, tt_bf16) - ret_bf16 = layers.elementwise_mul(ret_bf16, t_bf16) + ret_bf16 = paddle.add(t_bf16, tt_bf16) + ret_bf16 = paddle.multiply(ret_bf16, t_bf16) ret_bf16 = paddle.reshape(ret_bf16, [0, 0]) with amp.bf16.bf16_guard(): - ret_fp32bf16 = layers.elementwise_add(t, tt) - ret_fp32bf16 = layers.elementwise_mul(ret_fp32bf16, t) + ret_fp32bf16 = paddle.add(t, tt) + ret_fp32bf16 = paddle.multiply(ret_fp32bf16, t) ret_fp32bf16 = paddle.reshape(ret_fp32bf16, [0, 0]) ( @@ -147,11 +147,11 @@ class TestModelCastBF16(unittest.TestCase): tt = layers.data(name='tt', shape=[size, size], dtype='float32') with amp.bf16.bf16_guard(): - ret = layers.elementwise_add(t, tt) + ret = paddle.add(t, tt) ret = paddle.reshape(ret, [0, 0]) ret = paddle.nn.functional.elu(ret) - ret = layers.elementwise_mul(ret, t) - ret = layers.elementwise_add(ret, tt) + ret = paddle.multiply(ret, t) + ret = paddle.add(ret, tt) static_ret_bf16 = self.get_static_graph_result( feed={'t': n, 'tt': nn},