diff --git a/python/paddle/fluid/tests/unittests/test_gelu_op.py b/python/paddle/fluid/tests/unittests/test_gelu_op.py
index de34b63c9398e70ea97adf61fbee0183d8dc6468..abfb65c27a951a9b34a8de0b8b164e47dc314700 100644
--- a/python/paddle/fluid/tests/unittests/test_gelu_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gelu_op.py
@@ -21,6 +21,7 @@ import paddle.fluid as fluid
 import paddle.fluid.dygraph as dg
 import paddle
 import paddle.nn.functional as F
+from paddle.fluid.framework import _test_eager_guard
 
 
 def gelu(x, approximate):
@@ -91,6 +92,10 @@ class TestGeluOp(unittest.TestCase):
             np.allclose(
                 x_g_ref, x_g_fast_math, rtol=1e-5, atol=5e-4))
 
+    def test_fast_math_eager(self):
+        with _test_eager_guard():
+            self.test_fast_math()
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py
index 56b32d41a9bd189efaaca6c4440f8a97341f3d5a..73c423a23e6baceb372a915b595b96061f994cc7 100644
--- a/python/paddle/fluid/tests/unittests/test_prelu_op.py
+++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py
@@ -23,6 +23,7 @@ from paddle.fluid import Program, program_guard
 from op_test import OpTest, skip_check_grad_ci
 import paddle
 import paddle.nn.functional as F
+from paddle.fluid.framework import _test_eager_guard
 
 
 def ref_prelu(x, weight):
@@ -76,6 +77,10 @@ class TestFunctionalPReluAPI(unittest.TestCase):
         self.dygraph_check(self.weight_np_0)
         self.dygraph_check(self.weight_np_1)
 
+    def test_dygraph_api_eager(self):
+        with _test_eager_guard():
+            self.test_dygraph_api()
+
     def test_error(self):
         with paddle.static.program_guard(paddle.static.Program()):
             weight_fp32 = paddle.fluid.data(
@@ -151,13 +156,19 @@ class TestNNPReluAPI(unittest.TestCase):
         paddle.enable_static()
 
 
+def prelu_api_wrapper(x, weight, data_format="NCHW"):
+    weight = weight.reshape([-1])
+    return paddle.nn.functional.prelu(x, weight, data_format, name=None)
+
+
 class PReluTest(OpTest):
     def setUp(self):
         self.init_dtype()
         self.init_input_shape()
+        self.eager_mode = True
         self.init_attr()
         self.op_type = "prelu"
-        self.python_api = paddle.nn.functional.prelu
+        self.python_api = prelu_api_wrapper
 
         x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype)
         # Since zero point in prelu is not differentiable, avoid randomize
@@ -178,6 +189,8 @@ class PReluTest(OpTest):
             alpha_np = np.random.uniform(-1, -0.5, [1, 1, 1, self.x_shape[-1]])
         else:
             alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:])
+            # eager check don't support mode = 'all'
+            self.eager_mode = False
         alpha_np = alpha_np.astype(self.dtype)
 
         self.inputs = {'X': x_np, 'Alpha': alpha_np}
@@ -208,10 +221,10 @@ class PReluTest(OpTest):
         self.attrs = {'mode': "channel", "data_format": "NCHW"}
 
     def test_check_output(self):
-        self.check_output(check_eager=False)
+        self.check_output(check_eager=self.eager_mode)
 
     def test_check_grad(self):
-        self.check_grad(['X', 'Alpha'], 'Out', check_eager=False)
+        self.check_grad(['X', 'Alpha'], 'Out', check_eager=self.eager_mode)
 
 
 @skip_check_grad_ci(
@@ -375,7 +388,7 @@ def create_test_fp16_class(parent,
             place = core.CUDAPlace(0)
             if core.is_float16_supported(place):
                 self.check_output_with_place(
-                    place, atol=atol, check_eager=False)
+                    place, atol=atol, check_eager=self.eager_mode)
 
         def test_check_grad(self):
             place = core.CUDAPlace(0)
@@ -384,7 +397,7 @@ def create_test_fp16_class(parent,
                     place, ['X', 'Alpha'],
                     'Out',
                     max_relative_error=max_relative_error,
-                    check_eager=False)
+                    check_eager=self.eager_mode)
 
     cls_name = "{0}_{1}".format(parent.__name__, "Fp16Op")
     TestPReluFp16Case.__name__ = cls_name
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index ce82b10701b3c8f52b332d578c3a5b585eb75cc9..551f4c7b29d6b20d60e0d15bf956fd9494f6ad2f 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -175,7 +175,10 @@ def gelu(x, approximate=False, name=None):
             # [ 0.84119201,  1.39957154]]
     """
 
-    if in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_gelu(x, approximate)
+
+    if _in_legacy_dygraph():
         return _C_ops.gelu(x, 'approximate', approximate)
 
     check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'gelu')