diff --git a/python/paddle/fluid/tests/unittests/test_numel_op.py b/python/paddle/fluid/tests/unittests/test_numel_op.py index e15ffdfcd682dd7a5a2ff31fcf0eaea3531aad80..e4122bc0fb1b7d63ba998631d48374e480a24784 100644 --- a/python/paddle/fluid/tests/unittests/test_numel_op.py +++ b/python/paddle/fluid/tests/unittests/test_numel_op.py @@ -15,10 +15,11 @@ import unittest import numpy as np -from eager_op_test import OpTest +from eager_op_test import OpTest, convert_float_to_uint16 import paddle from paddle import fluid +from paddle.fluid import core class TestNumelOp(OpTest): @@ -26,7 +27,7 @@ class TestNumelOp(OpTest): self.op_type = "size" self.python_api = paddle.numel self.init() - x = np.random.random((self.shape)).astype("float64") + x = np.random.random((self.shape)).astype(self.dtype) self.inputs = { 'Input': x, } @@ -37,16 +38,65 @@ class TestNumelOp(OpTest): def init(self): self.shape = (6, 56, 8, 55) + self.dtype = np.float64 class TestNumelOp1(TestNumelOp): def init(self): self.shape = (11, 66) + self.dtype = np.float64 class TestNumelOp2(TestNumelOp): def init(self): self.shape = (0,) + self.dtype = np.float64 + + +class TestNumelOpFP16(TestNumelOp): + def init(self): + self.dtype = np.float16 + self.shape = (6, 56, 8, 55) + + +class TestNumelOp1FP16(TestNumelOp): + def init(self): + self.dtype = np.float16 + self.shape = (11, 66) + + +class TestNumelOp2FP16(TestNumelOp): + def init(self): + self.dtype = np.float16 + self.shape = (0,) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA and do not support bfloat16", +) +class TestNumelOpBF16(OpTest): + def setUp(self): + self.op_type = "size" + self.python_api = paddle.numel + self.dtype = np.uint16 + self.init() + x = np.random.random((self.shape)).astype(np.float32) + self.inputs = {'Input': convert_float_to_uint16(x)} + self.outputs = {'Out': np.array([np.size(x)])} + + def test_check_output(self): + place = paddle.CUDAPlace(0) + self.check_output_with_place(place) + + def init(self): + self.shape = (6, 56, 8, 55) + + +class TestNumelOp1BF16(TestNumelOpBF16): + def init(self): + self.shape = (11, 66) class TestNumelAPI(unittest.TestCase):