diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 47d474fe76c49701bd0ed04578294d2bc85d63a5..b5f4dbde9df7e7c051c151e914af193fa614152f 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -13,7 +13,14 @@ from typing import NamedTuple, Optional, Sequence, Tuple, Union from ..core._imperative_rt.core2 import apply, dtype_promotion from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder from ..core.ops import builtin -from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt +from ..core.ops.builtin import ( + BatchNorm, + Elemwise, + GetVarShape, + Identity, + Reduce, + TypeCvt, +) from ..core.ops.special import Const from ..core.tensor import amp, megbrain_graph from ..core.tensor.array_method import _elwise_apply @@ -1403,9 +1410,14 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor: from megengine import tensor import megengine.functional as F - x = tensor(np.ones(10, dtype=np.float32)) - out = F.dropout(x, 1./3.) - print(out.numpy()) + # test training mode + data = tensor(np.ones(10000000, dtype=np.float32)) + out = F.nn.dropout(data, 1.0 / 3.0, training=True) + assert not out.numpy().all() + + # test eval mode + out = F.nn.dropout(data, 1.0 / 3.0, training=False) + assert out.numpy().all() Outputs: @@ -1416,14 +1428,16 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor: """ assert 0 <= drop_prob < 1 - if drop_prob == 0: + if not training or drop_prob == 0: return inp + + # model in training mode, e.g. model.train() rv = uniform(size=inp.shape) mask = rv > drop_prob - inp *= mask.astype(inp.dtype) - if training: - inp *= 1 / (1 - drop_prob) - return inp + ret = inp * mask.astype(inp.dtype) + ret *= 1 / (1 - drop_prob) + + return ret def one_hot(inp: Tensor, num_classes: int) -> Tensor: diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 36c7beb1f2514575895ee3ba08db93cfb208eb39..2e7d7ab5e5ba41e095d38083b5e5660212f945eb 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -57,10 +57,14 @@ def test_where(): def test_dropout(): - data = tensor(np.ones(10, dtype=np.float32)) - out = F.dropout(data, 1.0 / 3.0, training=False) - - assert out.numpy().sum() >= 0.0 + # test training mode + data = tensor(np.ones(10000000, dtype=np.float32)) + out = F.nn.dropout(data, 1.0 / 3.0, training=True) + assert not out.numpy().all() + + # test eval mode + out = F.nn.dropout(data, 1.0 / 3.0, training=False) + assert out.numpy().all() def test_matinv():