Unverified commit 6133ca4e authored by Zhang Zheng, committed by GitHub

[AMP OP&Test] Support float16 in selu (#54030)

* [AMP OP&Test] Support float16 in selu

* fix
Parent: 97fe79a9
@@ -274,4 +274,5 @@ PD_REGISTER_KERNEL(selu,
                    phi::SeluKernel,
                    float,
                    double,
+                   phi::dtype::float16,
                    phi::dtype::bfloat16) {}
@@ -24,4 +24,5 @@ PD_REGISTER_KERNEL(selu_grad,
                    phi::SeluGradKernel,
                    float,
                    double,
+                   phi::dtype::float16,
                    phi::dtype::bfloat16) {}
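With float16 registered for both the forward and backward kernels above, selu can now run end to end in half precision. A minimal sketch of exercising the new kernels from Python (assumes a CUDA build of Paddle; shapes and values are illustrative, not taken from this commit):

import paddle
import paddle.nn.functional as F

# float16 kernels are GPU-only, so this assumes a CUDA build.
paddle.set_device('gpu')

x = paddle.randn([3, 5, 5, 10]).astype('float16')
x.stop_gradient = False

y = F.selu(x)       # forward dispatches to the float16 SeluKernel
y.sum().backward()  # backward dispatches to the float16 SeluGradKernel
print(y.dtype, x.grad.dtype)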
@@ -15,7 +15,7 @@
 import unittest

 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
 import paddle.nn.functional as F
@@ -43,14 +43,16 @@ class SeluTest(OpTest):
         self.op_type = "selu"
         self.python_api = paddle.nn.functional.selu
         self.x_shape = [3, 5, 5, 10]
-        self.dtype = np.float64
         self.init_x_shape()
+        self.init_dtype()

         alpha = 1.6732632423543772848170429916717
         scale = 1.0507009873554804934193349852946
-        x = np.random.normal(size=self.x_shape).astype(self.dtype)
+        if self.dtype == np.uint16:
+            x = np.random.normal(size=self.x_shape).astype(np.float32)
+        else:
+            x = np.random.normal(size=self.x_shape).astype(self.dtype)

         # Since zero point in selu is not differentiable, avoid randomize
         # zero.
@@ -58,8 +60,12 @@ class SeluTest(OpTest):
         out = ref_selu(x, scale, alpha)

-        self.inputs = {'X': x}
-        self.outputs = {'Out': out}
+        if self.dtype == np.uint16:
+            self.inputs = {'X': convert_float_to_uint16(x)}
+            self.outputs = {'Out': convert_float_to_uint16(out)}
+        else:
+            self.inputs = {'X': x}
+            self.outputs = {'Out': out}
         self.attrs = {
             'alpha': alpha,
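NumPy has no native bfloat16 dtype, so OpTest-based tests conventionally carry bf16 data as np.uint16 bit patterns; that is why the bfloat16 branch keys on np.uint16 and routes the float32 data through convert_float_to_uint16. A rough standalone sketch of that bit-level conversion (truncation shown for simplicity; the real helper's rounding behavior may differ):

import numpy as np

def float32_to_bf16_bits(x):
    # bfloat16 is the high 16 bits of an IEEE float32, so view the
    # float32 array as uint32 and keep the upper half (truncating).
    return (np.asarray(x, np.float32).view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(bits):
    # Shift the 16-bit pattern back into the float32 exponent/mantissa.
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.random.normal(size=[3, 5]).astype(np.float32)
roundtrip = bf16_bits_to_float32(float32_to_bf16_bits(x))
print(np.abs(roundtrip - x).max())  # small truncation error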
@@ -70,7 +76,7 @@ class SeluTest(OpTest):
         pass

     def init_dtype(self):
-        pass
+        self.dtype = np.float64

     def test_check_output(self):
         self.check_output()
@@ -79,6 +85,27 @@ class SeluTest(OpTest):
         self.check_grad(['X'], 'Out')


+class SeluTestFP16OP(SeluTest):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and do not support bfloat16",
+)
+class SeluTestBF16OP(SeluTest):
+    def init_dtype(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Out')
+
+
 class TestSeluAPI(unittest.TestCase):
     # test paddle.nn.SELU, paddle.nn.functional.selu
     def setUp(self):
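For reference, ref_selu (defined earlier in the test file, outside the hunks shown) is the NumPy oracle the test compares against. Presumably it implements the standard SELU definition, scale * x for x > 0 and scale * alpha * (exp(x) - 1) otherwise, using the alpha and scale constants from setUp; a minimal sketch under that assumption:

import numpy as np

def ref_selu(x, scale, alpha):
    # Sketch of the reference: elementwise SELU over a float array.
    out = x.copy()
    pos = out > 0
    out[pos] = scale * out[pos]
    out[~pos] = scale * alpha * (np.exp(out[~pos]) - 1)
    return out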