From d54e8420be00353af3f3524be8ea6d039e49c6be Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 25 Sep 2017 10:31:32 -0700 Subject: [PATCH] Stabilize prelu gradient check --- python/paddle/v2/framework/tests/test_prelu_op.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/paddle/v2/framework/tests/test_prelu_op.py b/python/paddle/v2/framework/tests/test_prelu_op.py index 2b6b7db3680..676fd9f7c55 100644 --- a/python/paddle/v2/framework/tests/test_prelu_op.py +++ b/python/paddle/v2/framework/tests/test_prelu_op.py @@ -7,6 +7,14 @@ class PReluTest(OpTest): def setUp(self): self.op_type = "prelu" x_np = np.random.normal(size=(10, 10)).astype("float32") + + for pos, val in np.ndenumerate(x_np): + # Since zero point in prelu is not differentiable, avoid randomize + # zero. + while abs(val) < 1e-3: + x_np[pos] = np.random.normal() + val = x_np[pos] + x_np_sign = np.sign(x_np) x_np = x_np_sign * np.maximum(x_np, .005) alpha_np = np.array([.1]) -- GitLab