Commit e5a3c1d2 authored by Zhuoyuan, committed by GitHub

Merge pull request #4372 from reyoung/feature/stable_prelu_grad_test

Stabilize prelu gradient check
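Why the gradient check needs stabilizing: PReLU has a kink at zero (slope 1 on the positive side, slope alpha on the negative side), so a central-difference numeric gradient evaluated at an input near zero straddles the kink and matches neither analytic branch. A minimal sketch, using a local `prelu` helper rather than the PaddlePaddle op:

```python
# Sketch of why a finite-difference check fails near x == 0.
# `prelu` and `numeric_grad` are illustrative helpers, not PaddlePaddle APIs.

alpha = 0.1   # negative-side slope
h = 1e-3      # finite-difference step

def prelu(x):
    # PReLU forward: x if x > 0 else alpha * x
    return x if x > 0 else alpha * x

def numeric_grad(x):
    # Central difference: (f(x + h) - f(x - h)) / (2h)
    return (prelu(x + h) - prelu(x - h)) / (2 * h)

# Far from zero the numeric gradient matches the analytic slope:
assert abs(numeric_grad(1.0) - 1.0) < 1e-6
assert abs(numeric_grad(-1.0) - alpha) < 1e-6

# Near zero, x + h and x - h land on different branches, so the
# numeric gradient is a blend of 1 and alpha and matches neither:
g = numeric_grad(1e-4)
assert alpha < g < 1.0
```

This is exactly the failure mode the patch below avoids by resampling inputs away from zero.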
@@ -7,6 +7,14 @@ class PReluTest(OpTest):
     def setUp(self):
         self.op_type = "prelu"
         x_np = np.random.normal(size=(10, 10)).astype("float32")
+
+        for pos, val in np.ndenumerate(x_np):
+            # Since zero point in prelu is not differentiable, avoid randomize
+            # zero.
+            while abs(val) < 1e-3:
+                x_np[pos] = np.random.normal()
+                val = x_np[pos]
+
         x_np_sign = np.sign(x_np)
         x_np = x_np_sign * np.maximum(x_np, .005)
         alpha_np = np.array([.1])
...