提交 76ce81e8 编写于 作者: M Megvii Engine Team

fix(mge): fix F.nn.dropout train and inference bugs

GitOrigin-RevId: 9d9f246d7b759ae39a130742b52b10d3150ca5cc
上级 5431929e
......@@ -13,7 +13,14 @@ from typing import NamedTuple, Optional, Sequence, Tuple, Union
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin
-from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
+from ..core.ops.builtin import (
+    BatchNorm,
+    Elemwise,
+    GetVarShape,
+    Identity,
+    Reduce,
+    TypeCvt,
+)
from ..core.ops.special import Const
from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
......@@ -1403,9 +1410,14 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
         from megengine import tensor
         import megengine.functional as F
-        x = tensor(np.ones(10, dtype=np.float32))
-        out = F.dropout(x, 1./3.)
-        print(out.numpy())
+        # test training mode
+        data = tensor(np.ones(10000000, dtype=np.float32))
+        out = F.nn.dropout(data, 1.0 / 3.0, training=True)
+        assert not out.numpy().all()
+        # test eval mode
+        out = F.nn.dropout(data, 1.0 / 3.0, training=False)
+        assert out.numpy().all()
Outputs:
......@@ -1416,14 +1428,16 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
"""
assert 0 <= drop_prob < 1
if drop_prob == 0:
if not training or drop_prob == 0:
return inp
# model in training mode, e.g. model.train()
rv = uniform(size=inp.shape)
mask = rv > drop_prob
inp *= mask.astype(inp.dtype)
if training:
inp *= 1 / (1 - drop_prob)
return inp
ret = inp * mask.astype(inp.dtype)
ret *= 1 / (1 - drop_prob)
return ret
def one_hot(inp: Tensor, num_classes: int) -> Tensor:
......
......@@ -57,10 +57,14 @@ def test_where():
 def test_dropout():
-    data = tensor(np.ones(10, dtype=np.float32))
-    out = F.dropout(data, 1.0 / 3.0, training=False)
-    assert out.numpy().sum() >= 0.0
+    # test training mode
+    data = tensor(np.ones(10000000, dtype=np.float32))
+    out = F.nn.dropout(data, 1.0 / 3.0, training=True)
+    assert not out.numpy().all()
+    # test eval mode
+    out = F.nn.dropout(data, 1.0 / 3.0, training=False)
+    assert out.numpy().all()
def test_matinv():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册