From 76ce81e82802a30e37400bb3fb1893ad76dc2376 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 3 Sep 2021 23:43:41 +0800
Subject: [PATCH] fix(mge): fix F.nn.dropout train and inference bugs

GitOrigin-RevId: 9d9f246d7b759ae39a130742b52b10d3150ca5cc
---
 imperative/python/megengine/functional/nn.py | 32 +++++++++++++------
 .../test/unit/functional/test_functional.py  | 12 ++++---
 2 files changed, 31 insertions(+), 13 deletions(-)

diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 47d474fe7..b5f4dbde9 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -13,7 +13,14 @@ from typing import NamedTuple, Optional, Sequence, Tuple, Union
 from ..core._imperative_rt.core2 import apply, dtype_promotion
 from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
 from ..core.ops import builtin
-from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
+from ..core.ops.builtin import (
+    BatchNorm,
+    Elemwise,
+    GetVarShape,
+    Identity,
+    Reduce,
+    TypeCvt,
+)
 from ..core.ops.special import Const
 from ..core.tensor import amp, megbrain_graph
 from ..core.tensor.array_method import _elwise_apply
@@ -1403,9 +1410,14 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
         from megengine import tensor
         import megengine.functional as F
 
-        x = tensor(np.ones(10, dtype=np.float32))
-        out = F.dropout(x, 1./3.)
-        print(out.numpy())
+        # test training mode
+        data = tensor(np.ones(10000000, dtype=np.float32))
+        out = F.nn.dropout(data, 1.0 / 3.0, training=True)
+        assert not out.numpy().all()
+
+        # test eval mode
+        out = F.nn.dropout(data, 1.0 / 3.0, training=False)
+        assert out.numpy().all()
 
     Outputs:
 
@@ -1416,14 +1428,16 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
 
     """
     assert 0 <= drop_prob < 1
-    if drop_prob == 0:
+    if not training or drop_prob == 0:
         return inp
+
+    # model in training mode, e.g. model.train()
     rv = uniform(size=inp.shape)
     mask = rv > drop_prob
-    inp *= mask.astype(inp.dtype)
-    if training:
-        inp *= 1 / (1 - drop_prob)
-    return inp
+    ret = inp * mask.astype(inp.dtype)
+    ret *= 1 / (1 - drop_prob)
+
+    return ret
 
 
 def one_hot(inp: Tensor, num_classes: int) -> Tensor:
diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index 36c7beb1f..2e7d7ab5e 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -57,10 +57,14 @@ def test_where():
 
 
 def test_dropout():
-    data = tensor(np.ones(10, dtype=np.float32))
-    out = F.dropout(data, 1.0 / 3.0, training=False)
-
-    assert out.numpy().sum() >= 0.0
+    # test training mode
+    data = tensor(np.ones(10000000, dtype=np.float32))
+    out = F.nn.dropout(data, 1.0 / 3.0, training=True)
+    assert not out.numpy().all()
+
+    # test eval mode
+    out = F.nn.dropout(data, 1.0 / 3.0, training=False)
+    assert out.numpy().all()
 
 
 def test_matinv():
-- 
GitLab
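
Note: the behavior the patch settles on is standard inverted dropout: in eval
mode (training=False) the op is the identity, and in training mode surviving
elements are rescaled by 1 / (1 - drop_prob) so the expected activation is
unchanged. Below is a minimal NumPy sketch of those semantics, written only to
illustrate the patched behavior; the `dropout` function here is a hypothetical
stand-in, not the MegEngine implementation.

    import numpy as np

    def dropout(inp, drop_prob, training=True):
        assert 0 <= drop_prob < 1
        if not training or drop_prob == 0:
            # eval mode (or drop_prob == 0): dropout is the identity
            return inp
        # training mode: zero each element with probability drop_prob,
        # then rescale survivors by 1 / (1 - drop_prob) so the expected
        # value of every element matches the input
        mask = np.random.uniform(size=inp.shape) > drop_prob
        return inp * mask.astype(inp.dtype) / (1 - drop_prob)

    data = np.ones(10000000, dtype=np.float32)
    assert not dropout(data, 1.0 / 3.0, training=True).all()  # some zeros
    assert dropout(data, 1.0 / 3.0, training=False).all()     # unchanged

This mirrors the new tests above: on a large all-ones input, training mode must
produce at least one zero, while eval mode must leave every element nonzero.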