未验证 提交 73533b9b 编写于 作者: X xiongkun 提交者: GitHub

[Yaml] add unittest for prelu, gelu. (#41444)

* add gelu pythonapi and unittest

* fix prelu
上级 eea85814
...@@ -21,6 +21,7 @@ import paddle.fluid as fluid ...@@ -21,6 +21,7 @@ import paddle.fluid as fluid
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.fluid.framework import _test_eager_guard
def gelu(x, approximate): def gelu(x, approximate):
...@@ -91,6 +92,10 @@ class TestGeluOp(unittest.TestCase): ...@@ -91,6 +92,10 @@ class TestGeluOp(unittest.TestCase):
np.allclose( np.allclose(
x_g_ref, x_g_fast_math, rtol=1e-5, atol=5e-4)) x_g_ref, x_g_fast_math, rtol=1e-5, atol=5e-4))
def test_fast_math_eager(self):
with _test_eager_guard():
self.test_fast_math()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -23,6 +23,7 @@ from paddle.fluid import Program, program_guard ...@@ -23,6 +23,7 @@ from paddle.fluid import Program, program_guard
from op_test import OpTest, skip_check_grad_ci from op_test import OpTest, skip_check_grad_ci
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.fluid.framework import _test_eager_guard
def ref_prelu(x, weight): def ref_prelu(x, weight):
...@@ -76,6 +77,10 @@ class TestFunctionalPReluAPI(unittest.TestCase): ...@@ -76,6 +77,10 @@ class TestFunctionalPReluAPI(unittest.TestCase):
self.dygraph_check(self.weight_np_0) self.dygraph_check(self.weight_np_0)
self.dygraph_check(self.weight_np_1) self.dygraph_check(self.weight_np_1)
def test_dygraph_api_eager(self):
with _test_eager_guard():
self.test_dygraph_api()
def test_error(self): def test_error(self):
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
weight_fp32 = paddle.fluid.data( weight_fp32 = paddle.fluid.data(
...@@ -151,13 +156,19 @@ class TestNNPReluAPI(unittest.TestCase): ...@@ -151,13 +156,19 @@ class TestNNPReluAPI(unittest.TestCase):
paddle.enable_static() paddle.enable_static()
def prelu_api_wrapper(x, weight, data_format="NCHW"):
weight = weight.reshape([-1])
return paddle.nn.functional.prelu(x, weight, data_format, name=None)
class PReluTest(OpTest): class PReluTest(OpTest):
def setUp(self): def setUp(self):
self.init_dtype() self.init_dtype()
self.init_input_shape() self.init_input_shape()
self.eager_mode = True
self.init_attr() self.init_attr()
self.op_type = "prelu" self.op_type = "prelu"
self.python_api = paddle.nn.functional.prelu self.python_api = prelu_api_wrapper
x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype) x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype)
# Since zero point in prelu is not differentiable, avoid randomize # Since zero point in prelu is not differentiable, avoid randomize
...@@ -178,6 +189,8 @@ class PReluTest(OpTest): ...@@ -178,6 +189,8 @@ class PReluTest(OpTest):
alpha_np = np.random.uniform(-1, -0.5, [1, 1, 1, self.x_shape[-1]]) alpha_np = np.random.uniform(-1, -0.5, [1, 1, 1, self.x_shape[-1]])
else: else:
alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:]) alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:])
# eager check don't support mode = 'all'
self.eager_mode = False
alpha_np = alpha_np.astype(self.dtype) alpha_np = alpha_np.astype(self.dtype)
self.inputs = {'X': x_np, 'Alpha': alpha_np} self.inputs = {'X': x_np, 'Alpha': alpha_np}
...@@ -208,10 +221,10 @@ class PReluTest(OpTest): ...@@ -208,10 +221,10 @@ class PReluTest(OpTest):
self.attrs = {'mode': "channel", "data_format": "NCHW"} self.attrs = {'mode': "channel", "data_format": "NCHW"}
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=False) self.check_output(check_eager=self.eager_mode)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X', 'Alpha'], 'Out', check_eager=False) self.check_grad(['X', 'Alpha'], 'Out', check_eager=self.eager_mode)
@skip_check_grad_ci( @skip_check_grad_ci(
...@@ -375,7 +388,7 @@ def create_test_fp16_class(parent, ...@@ -375,7 +388,7 @@ def create_test_fp16_class(parent,
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
if core.is_float16_supported(place): if core.is_float16_supported(place):
self.check_output_with_place( self.check_output_with_place(
place, atol=atol, check_eager=False) place, atol=atol, check_eager=self.eager_mode)
def test_check_grad(self): def test_check_grad(self):
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
...@@ -384,7 +397,7 @@ def create_test_fp16_class(parent, ...@@ -384,7 +397,7 @@ def create_test_fp16_class(parent,
place, ['X', 'Alpha'], place, ['X', 'Alpha'],
'Out', 'Out',
max_relative_error=max_relative_error, max_relative_error=max_relative_error,
check_eager=False) check_eager=self.eager_mode)
cls_name = "{0}_{1}".format(parent.__name__, "Fp16Op") cls_name = "{0}_{1}".format(parent.__name__, "Fp16Op")
TestPReluFp16Case.__name__ = cls_name TestPReluFp16Case.__name__ = cls_name
......
...@@ -175,7 +175,10 @@ def gelu(x, approximate=False, name=None): ...@@ -175,7 +175,10 @@ def gelu(x, approximate=False, name=None):
# [ 0.84119201, 1.39957154]] # [ 0.84119201, 1.39957154]]
""" """
if in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_gelu(x, approximate)
if _in_legacy_dygraph():
return _C_ops.gelu(x, 'approximate', approximate) return _C_ops.gelu(x, 'approximate', approximate)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'gelu') check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'gelu')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册