提交 154d88c2 编写于 作者: Z zchen0211

fix gradient not stable

上级 3c3a6d90
...@@ -7,6 +7,8 @@ class PReluTest(OpTest): ...@@ -7,6 +7,8 @@ class PReluTest(OpTest):
def setUp(self): def setUp(self):
self.op_type = "prelu" self.op_type = "prelu"
x_np = np.random.normal(size=(10, 10)).astype("float32") x_np = np.random.normal(size=(10, 10)).astype("float32")
x_np_sign = np.sign(x_np)
x_np = x_np_sign * np.maximum(x_np, .005)
alpha_np = np.array([.1]) alpha_np = np.array([.1])
self.inputs = {'X': x_np, 'Alpha': alpha_np} self.inputs = {'X': x_np, 'Alpha': alpha_np}
out_np = np.maximum(self.inputs['X'], 0.) out_np = np.maximum(self.inputs['X'], 0.)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册