提交 77309609 编写于 作者: M Megvii Engine Team

perf(functional/dropout): add fastpath for dropout

GitOrigin-RevId: 3bf8546908c2cd41a9d33c6236107a15e53f9fb4
上级 cc07b96f
......@@ -1304,6 +1304,8 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
"""
assert 0 <= drop_prob < 1
if drop_prob == 0:
return inp
rv = uniform(size=inp.shape)
mask = rv > drop_prob
inp *= mask.astype(inp.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册