Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
9779bc7f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
9779bc7f
编写于
3月 02, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(imperative): allow rng op infer shape fallible
GitOrigin-RevId: 687844500cc2cab18de576b1484215c72329e4b8
上级
8f7fa90c
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
22 addition
and
7 deletion
+22
-7
imperative/python/test/unit/functional/test_functional.py
imperative/python/test/unit/functional/test_functional.py
+3
-1
imperative/src/impl/ops/rng.cpp
imperative/src/impl/ops/rng.cpp
+19
-6
未找到文件。
imperative/python/test/unit/functional/test_functional.py
浏览文件 @
9779bc7f
...
...
@@ -71,7 +71,8 @@ def test_dropout():
with
gm
:
out
=
F
.
nn
.
dropout
(
data
,
rate
,
training
=
True
)
gm
.
backward
(
out
,
tensor
(
np
.
ones
(
shape
,
dtype
=
np
.
float32
)))
assert
not
out
.
numpy
().
all
()
if
len
(
shape
)
!=
0
:
assert
not
out
.
numpy
().
all
()
np
.
testing
.
assert_allclose
(
out
.
numpy
(),
data
.
grad
.
numpy
(),
1e-7
,
1e-7
)
def
test_multiple_dropout
(
shape
,
rate
):
...
...
@@ -99,6 +100,7 @@ def test_dropout():
out4
=
F
.
nn
.
dropout
(
data
,
rate
,
training
=
True
)
assert
not
(
out1
.
numpy
()
==
out4
.
numpy
()).
all
()
test_dropout_with_shape
([],
0.4
)
test_dropout_with_shape
([
13
,
17
,
63
,
21
],
0.4
)
test_dropout_with_shape
([
16
,
32
,
64
],
0.3
)
test_multiple_dropout
([
1024
],
0.2
)
...
...
imperative/src/impl/ops/rng.cpp
浏览文件 @
9779bc7f
...
...
@@ -559,25 +559,33 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
}
dest
.
comp_node
=
inputs
[
0
].
comp_node
;
dest
.
layout
=
_InferLayout
<
rng_with_shape
>::
do_infer
(
inputs
[
0
],
xxx_rng_def
);
return
{{
dest
},
true
};
return
{{
dest
},
inputs
[
0
].
layout
.
ndim
!=
0
};
}
template
<
>
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
<
ShuffleRNG
>
(
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
bool
success
=
inputs
[
0
].
layout
.
ndim
!=
0
;
SmallVector
<
LogicalTensorDesc
>
dests
(
2
);
dests
[
0
].
comp_node
=
inputs
[
0
].
comp_node
;
dests
[
0
].
layout
=
TensorLayout
(
inputs
[
0
].
layout
);
dests
[
0
].
layout
.
dtype
=
inputs
[
0
].
layout
.
dtype
;
dests
[
1
].
comp_node
=
inputs
[
0
].
comp_node
;
dests
[
1
].
layout
=
TensorLayout
(
TensorShape
({
inputs
[
0
].
layout
.
shape
[
0
]}),
dtype
::
Int32
());
return
{
dests
,
true
};
if
(
success
)
{
dests
[
1
].
layout
=
TensorLayout
(
TensorShape
({
inputs
[
0
].
layout
.
shape
[
0
]}),
dtype
::
Int32
());
}
else
{
dests
[
1
].
layout
=
TensorLayout
(
dtype
::
Int32
());
}
return
{
dests
,
success
};
}
template
<
>
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
<
Dropout
>
(
const
OpDef
&
op
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
{
bool
success
=
inputs
[
0
].
layout
.
ndim
!=
0
;
SmallVector
<
LogicalTensorDesc
>
dests
(
2
);
auto
cn
=
inputs
[
0
].
comp_node
;
dests
[
0
].
comp_node
=
cn
;
...
...
@@ -590,8 +598,13 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dro
inputs
[
0
].
layout
);
};
dests
[
1
].
comp_node
=
cn
;
dests
[
1
].
layout
=
TensorLayout
(
TensorShape
({
get_mask_size
()}),
dtype
::
Byte
());
return
{
dests
,
true
};
if
(
success
)
{
dests
[
1
].
layout
=
TensorLayout
(
TensorShape
({
get_mask_size
()}),
dtype
::
Byte
());
}
else
{
dests
[
1
].
layout
=
TensorLayout
(
dtype
::
Byte
());
}
return
{
dests
,
success
};
}
template
<
typename
Op
>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录