Commit 7ba641fe
Authored May 14, 2020 by Megvii Engine Team
Parent: 6fe6df28

fix(mgb/core): fix cond op with empty shape

GitOrigin-RevId: 1953d4cd2150a6ae537501d346f4ff710fcfcef9
Showing 2 changed files with 101 additions and 4 deletions:

    src/opr/impl/cond.cpp   +26  -4
    src/opr/test/cond.cpp   +75  -0
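In short, the change teaches the conditional-execution operators to tolerate zero-sized tensors: CondExecMark and CondExecMerge flag their outputs with VarNode::Flag::ALLOW_EMPTY_SHAPE, and their do_make_node_prop() overrides mark data inputs as NodeProp::DepType::VALUE_ALLOW_EMPTY so an empty value is no longer rejected at execution time. The sketch below distills that opt-in pattern from the hunks that follow; `MyOpr`, its constructor signature, and the `Super` alias are hypothetical scaffolding, while the flag and dep-type names are the ones actually used in this commit.

// A minimal sketch (not from this commit) of the empty-shape opt-in
// pattern applied below. `MyOpr` and its surroundings are hypothetical;
// ALLOW_EMPTY_SHAPE and VALUE_ALLOW_EMPTY are the real knobs.
MyOpr::MyOpr(const VarNodeArrayView& inputs) {
    for (size_t i = 0; i < inputs.size(); ++i) {
        add_input({inputs[i]});
        // permit zero-sized allocation for each forwarded output
        add_output(ssprintf("fwd%zu", i))
                ->dtype(inputs[i]->dtype())
                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
    }
}

MyOpr::NodeProp* MyOpr::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    // declare that an empty input value is acceptable, so the graph
    // executor does not raise an error when a branch yields no elements
    for (size_t i = 0; i < input().size(); ++i) {
        ret->add_dep_type_existing_var(input(i),
                                       NodeProp::DepType::VALUE_ALLOW_EMPTY);
    }
    return ret;
}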
src/opr/impl/cond.cpp

@@ -699,7 +699,9 @@ CondExecMark::CondExecMark(VarNode* ppv, const VarNodeArrayView& inputs,
     for (size_t i = 0; i < inputs.size(); ++i) {
         add_input({inputs[i]});
-        add_output(ssprintf("fwd%zu", i))->dtype(inputs[i]->dtype());
+        add_output(ssprintf("fwd%zu", i))
+                ->dtype(inputs[i]->dtype())
+                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     }
     add_input({ppv});
     add_equivalence_component<PODHash<Param>>(&m_param);
@@ -789,6 +791,10 @@ void CondExecMark::add_input_layout_constraint() {
 CondExecMark::NodeProp* CondExecMark::do_make_node_prop() const {
     auto ret = Super::do_make_node_prop();
     ret->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER;
+    for (size_t i = 0; i < input().size() - 1; ++i) {
+        ret->add_dep_type_existing_var(input(i),
+                                       NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    }
     return ret;
 }
@@ -859,7 +865,8 @@ CondExecMerge::CondExecMerge(const VarNodeArrayView& inputs,
         // 2. dynamic allocator would wait for all inputs to become ready (see
         //    VarNodeMemManager::DynamicAllocOprInfo::host_wait_input_ready),
         //    which would cause infinite waiting for unselected inputs.
-        ovar->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC);
+        ovar->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)
+                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     }
     MGB_MARK_USED_VAR(mask2str);
@@ -1056,7 +1063,9 @@ void CondExecMerge::init_output_static_infer_desc() {
         desc.infer_func = [this](DeviceTensorND& dest, const InpVal& inp) {
             auto nr_branch = m_branch_masks.size();
             bool found = false, first = true;
-            for (size_t i = 0; i < nr_branch; ++i) {
+            auto&& shape = inp.val.at(nr_branch).shape();
+            for (size_t i = 0; i < nr_branch && !shape.is_empty(); ++i) {
                 if (!inp.val[i].value().ptr<int>()[0])
                     continue;
                 auto&& cur = inp.val.at(nr_branch + i).value();
@@ -1083,7 +1092,6 @@ void CondExecMerge::init_output_static_infer_desc() {
                 }
             }
             if (!found) {
-                auto&& shape = inp.val.at(nr_branch).shape();
                 if (dest.storage().raw_storage().use_count() > 1) {
                     // likely to be assigned from some input in previous
                     // runs; we create a new tensor to avoid modifying input
@@ -1115,6 +1123,7 @@ void CondExecMerge::scn_do_execute() {
     bool first = true;
     auto&& forwarded = m_mem_forwarded;
+    std::vector<bool> is_shape_empty(nr_out, false);
     for (size_t br = 0; br < m_branch_masks.size(); ++br) {
         if (!m_branch_masks[br]->enabled()) {
             continue;
@@ -1125,6 +1134,10 @@ void CondExecMerge::scn_do_execute() {
         for (size_t oidx = 0; oidx < nr_out; ++oidx) {
             bool succ = output(oidx)->reset_dev_tensor_from_other_var(
                     inp(br, oidx));
+            if (inp(br, oidx)->shape().is_empty()) {
+                is_shape_empty[oidx] = true;
+                continue;
+            }
             if (!is_exact_one()) {
                 if (forwarded.empty()) {
                     forwarded.resize(nr_out);
@@ -1144,6 +1157,11 @@ void CondExecMerge::scn_do_execute() {
                 auto ovar = output(oidx);
                 auto&& src = inp(br, oidx)->dev_tensor().as_megdnn();
                 auto&& dest = ovar->dev_tensor().as_megdnn();
+                mgb_assert(src.layout.eq_shape(dest.layout),
+                           "shape mismatch: %s vs %s in CondExecMerge",
+                           src.layout.to_string().c_str(),
+                           dest.layout.to_string().c_str());
+                if (is_shape_empty[oidx])
+                    continue;
                 if (forwarded[oidx]) {
                     ovar->shape_alloc(ovar->shape());
                     auto&& own_dest = ovar->dev_tensor().as_megdnn();
@@ -1200,6 +1218,10 @@ CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const {
         // directly
         ret->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER;
     }
+    for (size_t i = 0; i < m_param.nr_output * m_branch_masks.size(); ++i) {
+        ret->add_dep_type_existing_var(input(i),
+                                       NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    }
     return ret;
 }
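The execution-path changes in scn_do_execute() reduce to one rule: when the selected branch delivers an empty tensor, record it and skip the copy/accumulate step, since the equally empty output has nothing to read or write. The standalone C++ sketch below (not MegBrain code; `Tensor` and `merge_sum` are stand-ins for a device tensor and the SUM merge mode) illustrates that control flow.

// Standalone illustration of the merge loop's empty-shape handling.
#include <cassert>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for a device tensor

Tensor merge_sum(const std::vector<Tensor>& branch_outputs) {
    Tensor dest;
    bool first = true;
    for (const Tensor& src : branch_outputs) {
        if (first) {
            dest = src;  // mirrors reset_dev_tensor_from_other_var()
            first = false;
            continue;
        }
        // enabled branches must agree on shape (mirrors the mgb_assert)
        assert(src.size() == dest.size());
        if (src.empty())
            continue;  // mirrors the is_shape_empty[] guard: nothing to add
        for (size_t i = 0; i < src.size(); ++i)
            dest[i] += src[i];  // accumulate into the forwarded tensor
    }
    return dest;
}

int main() {
    assert(merge_sum({{}, {}}).empty());                     // empty in, empty out
    assert(merge_sum({{1.f, 2.f}, {3.f, 4.f}})[1] == 6.f);   // elementwise sum
}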
src/opr/test/cond.cpp

@@ -14,6 +14,7 @@
 #include "megbrain/opr/basic_arith_wrapper.h"
 #include "megbrain/opr/cond.h"
 #include "megbrain/opr/io.h"
+#include "megbrain/opr/misc.h"
 #include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/utility.h"
 #include "megbrain/utils/timer.h"
@@ -1285,6 +1286,80 @@ TEST(TestCondExec, MultiShape) {
     check(host_d2);
 }

+TEST(TestCondExec, EmptyShape) {
+    HostTensorGenerator<> gen;
+    auto host_pred = gen({1});
+    host_pred->ptr<float>()[0] = 0;
+    static auto empty_in_empty_out = [](SymbolVar x) { return x; };
+    static auto empty_in_scalar_out = [](SymbolVar x) {
+        return opr::Concat::make({x, x.make_scalar(1.f)}, 0);
+    };
+    static auto scalar_in_empty_out = [](SymbolVar x) {
+        return opr::CondTake::make(x, x, {})[0];  // whether eq 0
+    };
+    {  // EXACT_ONE
+        auto graph = ComputingGraph::make();
+        auto pred = opr::Host2DeviceCopy::make(*graph, host_pred),
+             empty = opr::ImmutableTensor::make(*graph, *gen({0})),
+             scalar = pred.make_scalar(1.f),
+             y0 = empty_in_empty_out(make_one_cond(pred + 1, empty)),
+             y1 = empty_in_scalar_out(make_one_cond(pred, empty)),
+             y2 = scalar_in_empty_out(make_one_cond(pred - 1, scalar)),
+             z = merge_one_out({y0, y1, y2}, MergeMode::EXACT_ONE);
+        HostTensorND host_z;
+        auto func = graph->compile({make_callback_copy(z, host_z)});
+        func->execute();
+        ASSERT_TRUE(host_z.layout().is_empty());
+        host_pred->ptr<float>()[0] = 1;
+        func->execute();
+        ASSERT_EQ(1.f, host_z.ptr<float>()[0]);
+        host_pred->ptr<float>()[0] = 2;
+        func->execute();
+        ASSERT_TRUE(host_z.layout().is_empty());
+    }
+    {  // SUM
+        auto graph = ComputingGraph::make();
+        host_pred->ptr<float>()[0] = 1;
+        auto pred = opr::Host2DeviceCopy::make(*graph, host_pred),
+             empty = opr::ImmutableTensor::make(*graph, *gen({0})),
+             scalar = pred.make_scalar(1.f),
+             y0 = empty_in_empty_out(make_one_cond(pred, empty)),
+             y1 = scalar_in_empty_out(make_one_cond(pred, scalar)),
+             z = merge_one_out({y0, y1}, MergeMode::SUM);
+        HostTensorND host_z;
+        auto func = graph->compile({make_callback_copy(z, host_z)});
+        func->execute();
+        ASSERT_TRUE(host_z.layout().is_empty());
+    }
+    {  // TAKE GRAD
+        auto graph = ComputingGraph::make();
+        host_pred->ptr<float>()[0] = 0;
+        auto pred = opr::Host2DeviceCopy::make(*graph, host_pred),
+             x = pred.make_scalar(1.2f),
+             y0 = opr::CondTake::make(make_one_cond(pred + 1, x), pred,
+                                      {})[0],
+             y1 = make_one_cond(pred, x.make_scalar(3.4f)),
+             z = merge_one_out({y0, y1}, MergeMode::EXACT_ONE),
+             g = cg::grad(z, x);
+        HostTensorND host_z, host_g;
+        auto func = graph->compile({make_callback_copy(z, host_z),
+                                    make_callback_copy(g, host_g)});
+        func->execute();
+        ASSERT_EQ(1.2f, host_z.ptr<float>()[0]);
+        ASSERT_EQ(1.f, host_g.ptr<float>()[0]);
+        host_pred->ptr<float>()[0] = 1;
+        func->execute();
+        ASSERT_EQ(3.4f, host_z.ptr<float>()[0]);
+        ASSERT_EQ(0.f, host_g.ptr<float>()[0]);
+    }
+}
+
 #endif // MGB_ENABLE_COND_EXEC

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}