未验证 提交 0fdd3656 编写于 作者: L Leo Chen 提交者: GitHub

Add fast path for dropout when p == 0 (#29553)

* add fast path for p == 0 in dropout

* add ut
上级 917a1149
......@@ -1476,6 +1476,9 @@ class Dropout(layers.Layer):
self._is_test = is_test
def forward(self, input):
# fast return for p == 0
if self._dropout_prob == 0:
return input
prog = default_main_program()
if (self._seed is None or self._seed == 0) and prog.random_seed != 0:
self._seed = prog.random_seed
......
......@@ -1007,6 +1007,9 @@ def dropout(x,
x = fluid.data(name="data", shape=[None, 32, 32], dtype="float32")
dropped = fluid.layers.dropout(x, dropout_prob=0.5)
"""
# fast return for p == 0
if dropout_prob == 0:
return x
def get_attrs(prog, dropout_prob, is_test, seed):
if (seed is None or seed == 0) and prog.random_seed != 0:
......
......@@ -302,13 +302,16 @@ class TestDropoutFAPI(unittest.TestCase):
training=False,
mode='downscale_in_infer')
res10 = paddle.nn.functional.dropout(x=input, p=1., training=True)
res11 = paddle.fluid.layers.dropout(x=input, dropout_prob=0.)
in_np = np.random.random([40, 40]).astype("float32")
res_np = in_np
res_np2 = np.zeros_like(in_np)
exe = fluid.Executor(place)
res_list = [res1, res2, res3, res4, res5, res6, res7, res8, res9]
res_list = [
res1, res2, res3, res4, res5, res6, res7, res8, res9, res11
]
for res in res_list:
fetches = exe.run(fluid.default_main_program(),
feed={"input": in_np},
......@@ -383,8 +386,12 @@ class TestDropoutFAPI(unittest.TestCase):
mode='downscale_in_infer')
res10 = paddle.nn.functional.dropout(
x=input, p=1., training=True)
dropout = paddle.fluid.dygraph.Dropout(p=0, )
res11 = dropout(input)
res_list = [res1, res2, res3, res4, res5, res6, res7, res8, res9]
res_list = [
res1, res2, res3, res4, res5, res6, res7, res8, res9, res11
]
for res in res_list:
self.assertTrue(np.allclose(res.numpy(), res_np))
self.assertTrue(np.allclose(res10.numpy(), res_np2))
......
......@@ -887,6 +887,10 @@ def dropout(x,
print(y_01)
"""
# fast return for p == 0
if p == 0:
return x
if not isinstance(p, (float, int)):
raise TypeError("p argument should be a number")
if p < 0 or p > 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册