diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index b87ded3ed6094365e92e59146eb6bfbd33f39240..72078a7b5d65666fbe12139d4524f4c4a61bd0bb 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1304,6 +1304,8 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor: """ assert 0 <= drop_prob < 1 + if drop_prob == 0: + return inp rv = uniform(size=inp.shape) mask = rv > drop_prob inp *= mask.astype(inp.dtype)