Unverified commit 6133ca4e authored by Zhang Zheng, committed by GitHub

[AMP OP&Test] Support float16 in selu (#54030)

* [AMP OP&Test] Support float16 in selu

* fix
Parent 97fe79a9
@@ -274,4 +274,5 @@ PD_REGISTER_KERNEL(selu,
                    phi::SeluKernel,
                    float,
                    double,
+                   phi::dtype::float16,
                    phi::dtype::bfloat16) {}
@@ -24,4 +24,5 @@ PD_REGISTER_KERNEL(selu_grad,
                    phi::SeluGradKernel,
                    float,
                    double,
+                   phi::dtype::float16,
                    phi::dtype::bfloat16) {}
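
Note: with float16 registered for both the forward and backward GPU kernels above, a half-precision tensor dispatches to selu directly. A minimal smoke check, assuming a CUDA build of Paddle (not part of this commit):

import paddle
import paddle.nn.functional as F

# Assumes a CUDA device; float16 selu is only registered for the GPU kernels above.
paddle.set_device('gpu')
x = paddle.randn([3, 5, 5, 10]).astype('float16')
x.stop_gradient = False

y = F.selu(x)           # forward pass runs the float16 SeluKernel
y.sum().backward()      # backward pass runs the float16 SeluGradKernel

print(y.dtype, x.grad.dtype)  # paddle.float16 paddle.float16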
@@ -15,7 +15,7 @@
 import unittest
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 import paddle
 import paddle.nn.functional as F
@@ -43,13 +43,15 @@ class SeluTest(OpTest):
         self.op_type = "selu"
         self.python_api = paddle.nn.functional.selu
         self.x_shape = [3, 5, 5, 10]
-        self.dtype = np.float64
         self.init_x_shape()
         self.init_dtype()

         alpha = 1.6732632423543772848170429916717
         scale = 1.0507009873554804934193349852946

-        x = np.random.normal(size=self.x_shape).astype(self.dtype)
+        if self.dtype == np.uint16:
+            x = np.random.normal(size=self.x_shape).astype(np.float32)
+        else:
+            x = np.random.normal(size=self.x_shape).astype(self.dtype)

         # Since zero point in selu is not differentiable, avoid randomize
@@ -58,6 +60,10 @@ class SeluTest(OpTest):

         out = ref_selu(x, scale, alpha)

-        self.inputs = {'X': x}
-        self.outputs = {'Out': out}
+        if self.dtype == np.uint16:
+            self.inputs = {'X': convert_float_to_uint16(x)}
+            self.outputs = {'Out': convert_float_to_uint16(out)}
+        else:
+            self.inputs = {'X': x}
+            self.outputs = {'Out': out}
@@ -70,7 +76,7 @@ class SeluTest(OpTest):
         pass

     def init_dtype(self):
-        pass
+        self.dtype = np.float64

     def test_check_output(self):
         self.check_output()
@@ -79,6 +85,27 @@ class SeluTest(OpTest):
         self.check_grad(['X'], 'Out')


+class SeluTestFP16OP(SeluTest):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and do not support bfloat16",
+)
+class SeluTestBF16OP(SeluTest):
+    def init_dtype(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CUDAPlace(0))
+
+    def test_check_grad(self):
+        self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Out')
+
+
 class TestSeluAPI(unittest.TestCase):
     # test paddle.nn.SELU, paddle.nn.functional.selu
     def setUp(self):
...
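
Note on the BF16 test: NumPy has no native bfloat16 dtype, so the OpTest convention is to store BF16 tensors as np.uint16 bit patterns and to build them from float32 data with convert_float_to_uint16. A rough, illustrative sketch of that conversion (the helper's exact rounding behaviour may differ; this is not copied from eager_op_test):

import numpy as np

def float32_to_bfloat16_bits(a):
    # bfloat16 reuses the upper 16 bits of an IEEE-754 float32, so the
    # conversion can be emulated on the raw uint32 bit pattern.
    bits = np.asarray(a, dtype=np.float32).view(np.uint32)
    # Round to nearest even before dropping the low 16 bits (illustrative;
    # simple truncation of the low bits is another common choice).
    rounding_bias = ((bits >> 16) & 1) + np.uint32(0x7FFF)
    return ((bits + rounding_bias) >> 16).astype(np.uint16)

x = np.random.normal(size=[3, 5, 5, 10]).astype(np.float32)
x_bf16 = float32_to_bfloat16_bits(x)  # dtype uint16, matching what the BF16 test feeds to OpTest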