Unverified commit b8b0712c, authored by YuhangLi, committed by GitHub

[AMP OP&Test]numel op fp/bf 16 support (#52172)

* [AMP OP&Test]numel op fp/bf 16 support

* dtype update

* remove err case
Parent f2c96bc2
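The change is test-only: the `size` op (exposed as paddle.numel) is exercised with float16 inputs directly and with bfloat16 inputs encoded through convert_float_to_uint16. As a rough usage sketch of the API under test (not part of the commit; the shape is borrowed from the test cases, and float16 tensor support is assumed in the build):

import numpy as np
import paddle

# numel only reports the element count, so fp16/bf16 support mostly means the
# dtype is accepted end to end; the result itself does not depend on the dtype.
x_fp16 = paddle.to_tensor(np.random.random((6, 56, 8, 55)), dtype='float16')
print(paddle.numel(x_fp16))  # Tensor holding 6 * 56 * 8 * 55 = 147840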
@@ -15,10 +15,11 @@
 import unittest

 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
 from paddle import fluid
+from paddle.fluid import core


 class TestNumelOp(OpTest):
@@ -26,7 +27,7 @@ class TestNumelOp(OpTest):
         self.op_type = "size"
         self.python_api = paddle.numel
         self.init()
-        x = np.random.random((self.shape)).astype("float64")
+        x = np.random.random((self.shape)).astype(self.dtype)
         self.inputs = {
             'Input': x,
         }
@@ -37,16 +38,65 @@

     def init(self):
         self.shape = (6, 56, 8, 55)
+        self.dtype = np.float64


 class TestNumelOp1(TestNumelOp):
     def init(self):
         self.shape = (11, 66)
+        self.dtype = np.float64


 class TestNumelOp2(TestNumelOp):
     def init(self):
         self.shape = (0,)
+        self.dtype = np.float64


+class TestNumelOpFP16(TestNumelOp):
+    def init(self):
+        self.dtype = np.float16
+        self.shape = (6, 56, 8, 55)
+
+
+class TestNumelOp1FP16(TestNumelOp):
+    def init(self):
+        self.dtype = np.float16
+        self.shape = (11, 66)
+
+
+class TestNumelOp2FP16(TestNumelOp):
+    def init(self):
+        self.dtype = np.float16
+        self.shape = (0,)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and do not support bfloat16",
+)
+class TestNumelOpBF16(OpTest):
+    def setUp(self):
+        self.op_type = "size"
+        self.python_api = paddle.numel
+        self.dtype = np.uint16
+        self.init()
+        x = np.random.random((self.shape)).astype(np.float32)
+        self.inputs = {'Input': convert_float_to_uint16(x)}
+        self.outputs = {'Out': np.array([np.size(x)])}
+
+    def test_check_output(self):
+        place = paddle.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def init(self):
+        self.shape = (6, 56, 8, 55)
+
+
+class TestNumelOp1BF16(TestNumelOpBF16):
+    def init(self):
+        self.shape = (11, 66)
+
+
 class TestNumelAPI(unittest.TestCase):
......
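A note on the bfloat16 pattern in TestNumelOpBF16 above: numpy has no native bfloat16 dtype, so the test feeds the op uint16 data holding the upper half of each float32 bit pattern, which is what convert_float_to_uint16 produces. A simplified, truncating illustration of that encoding (the real helper in eager_op_test may round as well; this sketch is not the commit's code):

import numpy as np

def float32_to_bf16_bits(x):
    # Keep the high 16 bits of each float32 value: sign, 8 exponent bits, 7 mantissa bits.
    bits = x.astype(np.float32).view(np.uint32)
    return (bits >> 16).astype(np.uint16)

x = np.random.random((11, 66)).astype(np.float32)
x_bf16 = float32_to_bf16_bits(x)
assert x_bf16.dtype == np.uint16 and x_bf16.shape == (11, 66)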