Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
76ce81e8
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
76ce81e8
编写于
9月 03, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge): fix F.nn.dropout train and inference bugs
GitOrigin-RevId: 9d9f246d7b759ae39a130742b52b10d3150ca5cc
上级
5431929e
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
31 addition
and
13 deletion
+31
-13
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+23
-9
imperative/python/test/unit/functional/test_functional.py
imperative/python/test/unit/functional/test_functional.py
+8
-4
未找到文件。
imperative/python/megengine/functional/nn.py
浏览文件 @
76ce81e8
...
...
@@ -13,7 +13,14 @@ from typing import NamedTuple, Optional, Sequence, Tuple, Union
from
..core._imperative_rt.core2
import
apply
,
dtype_promotion
from
..core._imperative_rt.ops
import
SubgraphBuilder
as
_SubgraphBuilder
from
..core.ops
import
builtin
from
..core.ops.builtin
import
BatchNorm
,
Elemwise
,
GetVarShape
,
Reduce
,
TypeCvt
from
..core.ops.builtin
import
(
BatchNorm
,
Elemwise
,
GetVarShape
,
Identity
,
Reduce
,
TypeCvt
,
)
from
..core.ops.special
import
Const
from
..core.tensor
import
amp
,
megbrain_graph
from
..core.tensor.array_method
import
_elwise_apply
...
...
@@ -1403,9 +1410,14 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
from megengine import tensor
import megengine.functional as F
x = tensor(np.ones(10, dtype=np.float32))
out = F.dropout(x, 1./3.)
print(out.numpy())
# test training mode
data = tensor(np.ones(10000000, dtype=np.float32))
out = F.nn.dropout(data, 1.0 / 3.0, training=True)
assert not out.numpy().all()
# test eval mode
out = F.nn.dropout(data, 1.0 / 3.0, training=False)
assert out.numpy().all()
Outputs:
...
...
@@ -1416,14 +1428,16 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
"""
assert
0
<=
drop_prob
<
1
if
drop_prob
==
0
:
if
not
training
or
drop_prob
==
0
:
return
inp
# model in training mode, e.g. model.train()
rv
=
uniform
(
size
=
inp
.
shape
)
mask
=
rv
>
drop_prob
inp
*=
mask
.
astype
(
inp
.
dtype
)
if
training
:
inp
*=
1
/
(
1
-
drop_prob
)
return
inp
ret
=
inp
*
mask
.
astype
(
inp
.
dtype
)
ret
*=
1
/
(
1
-
drop_prob
)
return
ret
def
one_hot
(
inp
:
Tensor
,
num_classes
:
int
)
->
Tensor
:
...
...
imperative/python/test/unit/functional/test_functional.py
浏览文件 @
76ce81e8
...
...
@@ -57,10 +57,14 @@ def test_where():
def
test_dropout
():
data
=
tensor
(
np
.
ones
(
10
,
dtype
=
np
.
float32
))
out
=
F
.
dropout
(
data
,
1.0
/
3.0
,
training
=
False
)
assert
out
.
numpy
().
sum
()
>=
0.0
# test training mode
data
=
tensor
(
np
.
ones
(
10000000
,
dtype
=
np
.
float32
))
out
=
F
.
nn
.
dropout
(
data
,
1.0
/
3.0
,
training
=
True
)
assert
not
out
.
numpy
().
all
()
# test eval mode
out
=
F
.
nn
.
dropout
(
data
,
1.0
/
3.0
,
training
=
False
)
assert
out
.
numpy
().
all
()
def
test_matinv
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录