未验证 提交 3231ce9f 编写于 作者: H huangjun12 提交者: GitHub

fix alpha dropout bug when p=1, test=develop (#26977) (#27010)

上级 0a9f9f93
......@@ -43,7 +43,7 @@ class TestDropoutOp(OpTest):
class TestDropoutOpInput1d(OpTest):
def setUp(self):
self.op_type = "dropout"
self.inputs = {'X': np.random.random((2000)).astype("float32")}
self.inputs = {'X': np.random.random((2000, )).astype("float32")}
self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
self.outputs = {
'Out': self.inputs['X'],
......@@ -672,9 +672,11 @@ class TestAlphaDropoutFAPI(unittest.TestCase):
res1 = paddle.nn.functional.alpha_dropout(x=input, p=0.)
res2 = paddle.nn.functional.alpha_dropout(
x=input, p=0., training=False)
res3 = paddle.nn.functional.alpha_dropout(x=input, p=1.)
in_np = np.random.random([40, 40]).astype("float32")
res_np = in_np
res_np3 = np.zeros_like(in_np)
exe = fluid.Executor(place)
res_list = [res1, res2]
......@@ -683,6 +685,10 @@ class TestAlphaDropoutFAPI(unittest.TestCase):
feed={"input": in_np},
fetch_list=[res])
self.assertTrue(np.allclose(fetches[0], res_np))
fetches = exe.run(fluid.default_main_program(),
feed={"input": in_np},
fetch_list=[res3])
self.assertTrue(np.allclose(fetches[0], res_np3))
def test_static(self):
for place in self.places:
......@@ -693,15 +699,18 @@ class TestAlphaDropoutFAPI(unittest.TestCase):
with fluid.dygraph.guard(place):
in_np = np.random.random([40, 40]).astype("float32")
res_np = in_np
res_np3 = np.zeros_like(in_np)
input = fluid.dygraph.to_variable(in_np)
res1 = paddle.nn.functional.alpha_dropout(x=input, p=0.)
res2 = paddle.nn.functional.alpha_dropout(
x=input, p=0., training=False)
res3 = paddle.nn.functional.alpha_dropout(x=input, p=1.)
res_list = [res1, res2]
for res in res_list:
self.assertTrue(np.allclose(res.numpy(), res_np))
self.assertTrue(np.allclose(res3.numpy(), res_np3))
class TestAlphaDropoutFAPIError(unittest.TestCase):
......
......@@ -1091,6 +1091,8 @@ def alpha_dropout(x, p=0.5, training=True, name=None):
'alpha_dropout')
if training:
if p == 1:
return layers.scale(x, scale=0.)
#get transformation params
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册