Commit a430c912
Authored Aug 18, 2021 by Megvii Engine Team

feat(mgb/opr): let CondTake support empty input
GitOrigin-RevId: dfb401a945d5d75909f7b78448b3713623c28a2c
Parent: 432fdb7e

Showing 5 changed files with 90 additions and 30 deletions:
imperative/python/test/unit/functional/test_functional.py   +30  -0
imperative/src/impl/ops/cond_take.cpp                        +19  -14
src/opr/impl/misc.cpp                                        +23  -4
src/opr/include/megbrain/opr/misc.h                          +1   -0
src/opr/test/misc.cpp                                        +17  -12
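
Before the per-file diffs, a minimal sketch of what "support empty input" means at the Python API level. This is an illustration written against the public F.cond_take API and assumes a MegEngine build that already contains this commit; it is not part of the change itself.

import numpy as np

import megengine.functional as F
from megengine import tensor

# A zero-element input: shape (3, 0, 3) contains no values at all.
x = tensor(np.random.randn(3, 0, 3).astype("float32"))
mask = x > 0                           # boolean mask with the same (empty) shape
val, idx = F.cond_take(mask, x)        # with this commit, this accepts empty input
print(val.numpy().shape, idx.numpy().shape)   # expected: (0,) (0,)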
imperative/python/test/unit/functional/test_functional.py
@@ -18,6 +18,7 @@ import megengine.amp as amp
 import megengine.core.ops.builtin as builtin
 import megengine.core.tensor.dtype as dtype
 import megengine.functional as F
+import megengine.jit as jit
 from megengine import Parameter, Tensor, is_cuda_available, tensor
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.autodiff.grad import Grad

@@ -859,6 +860,35 @@ def test_condtake():
     np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])


+# @pytest.mark.parametrize("is_symbolic", [None, False, True])
+def test_condtake(is_symbolic=None):
+    shapes = [
+        (3, 3, 3),
+        (0,),
+        (3, 0, 3),
+    ]
+
+    def fn(mask, data):
+        return F.cond_take(mask, data)
+
+    if is_symbolic is not None:
+        fn = jit.trace(symbolic=is_symbolic)(fn)
+    for shp in shapes:
+        x_np = np.random.randn(*shp).astype("float32")
+        mask_np = x_np > 0
+        x = tensor(x_np)
+        mask = tensor(mask_np)
+        ref_out = x_np[mask_np]
+        ref_idx = mask_np.flatten().nonzero()[0]
+        for i in range(3):
+            out, idx = fn(mask, x)
+            np.testing.assert_equal(out.numpy(), ref_out)
+            np.testing.assert_equal(idx.numpy(), ref_idx)
+            if is_symbolic is None:
+                break
+
+
 def test_condtake_is_same():
     op1 = builtin.CondTake()
     op2 = builtin.CondTake()
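For reference, the semantics the test above checks against can be written in plain NumPy. cond_take_ref below is a hypothetical helper used only for illustration; it mirrors the ref_out / ref_idx expressions in the test and is not MegEngine code.

import numpy as np

def cond_take_ref(mask, data):
    # Kept values come from boolean masking; indices are the positions of the
    # True entries in the flattened mask (and hence in the flattened data).
    return data[mask], mask.flatten().nonzero()[0]

data = np.random.randn(3, 0, 3).astype("float32")   # zero-element input
mask = data > 0
vals, idx = cond_take_ref(mask, data)
assert vals.size == 0 and idx.size == 0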
imperative/src/impl/ops/cond_take.cpp
@@ -45,25 +45,30 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     auto&& inp = inputs[0];
     auto&& msk = inputs[1];
+    SmallVector<TensorPtr> out;
     mgb_assert(inp->layout().eq_shape(msk->layout()),
                "input shape does not match mask shape");
     mgb_assert(msk->get_value().dtype().enumv() == DTypeEnum::Bool,
                "mask dtype must be bool");
-    DnnOprCaller<megdnn::CondTake> dnn_op(inp->comp_node());
-    dnn_op.op->param().val = 1;
-    TensorLayout m_layout({dnn_op.op->get_workspace_in_bytes(inp->layout())},
-                          dtype::Byte());
-    auto dnn_workspace = dnn_op.create_workspace(m_layout);
     MegDNNDynOutMallocImpl<2> policy{inp->comp_node()};
-    dnn_op.op->exec(inp->dev_tensor().as_megdnn(),
-                    msk->dev_tensor().as_megdnn(),
-                    dnn_workspace,
-                    &policy);
-    SmallVector<TensorPtr> out;
+    if (inp->layout().is_empty()) {
+        // empty tensor
+        policy.alloc_output(0, inp->layout().dtype, {0}, nullptr);
+        policy.alloc_output(1, dtype::Int32(), {0}, nullptr);
+    } else {
+        DnnOprCaller<megdnn::CondTake> dnn_op(inp->comp_node());
+        dnn_op.op->param().val = 1;
+        TensorLayout m_layout({dnn_op.op->get_workspace_in_bytes(inp->layout())},
+                              dtype::Byte());
+        auto dnn_workspace = dnn_op.create_workspace(m_layout);
+        dnn_op.op->exec(inp->dev_tensor().as_megdnn(),
+                        msk->dev_tensor().as_megdnn(),
+                        dnn_workspace,
+                        &policy);
+    }
     out.push_back(policy.at(0));
     out.push_back(policy.at(1));
     return out;
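The branch added above skips the MegDNN kernel when the input layout is empty and directly allocates two zero-length outputs (values with the input dtype, indices as Int32). A rough Python/NumPy schematic of that control flow, for illustration only:

import numpy as np

def apply_cond_take(data: np.ndarray, mask: np.ndarray):
    # Shape/dtype checks mirror the mgb_assert calls above.
    assert data.shape == mask.shape, "input shape does not match mask shape"
    assert mask.dtype == np.bool_, "mask dtype must be bool"
    if data.size == 0:
        # Empty input: skip the kernel and return two zero-length outputs.
        return (np.empty((0,), dtype=data.dtype),
                np.empty((0,), dtype=np.int32))
    # Non-empty input: the normal take path.
    idx = np.flatnonzero(mask).astype(np.int32)
    return data.reshape(-1)[idx], idx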
src/opr/impl/misc.cpp
@@ -264,6 +264,15 @@ CondTake::CondTake(VarNode *data, VarNode *mask,
     }
 }

+CondTake::NodeProp* CondTake::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 #if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondTake) {
     mgb_assert(out_grad.size() == 3 && !out_grad[2]);

@@ -305,11 +314,21 @@ void CondTake::add_input_layout_constraint() {
 }

 void CondTake::scn_do_execute() {
+    auto&& data = input(0)->dev_tensor();
+    auto&& mask = input(1)->dev_tensor();
     intl::MegDNNDynOutMallocImpl dyn_malloc{this, comp_node()};
-    megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
-                       input(1)->dev_tensor().as_megdnn(),
-                       intl::get_megdnn_workspace_from_var(output().back()),
-                       &dyn_malloc);
+    if (data.layout().is_empty()) {
+        mgb_assert(data.layout().eq_shape(mask.layout()),
+                   "CondTake shape differs: data=%s mask=%s",
+                   data.layout().TensorShape::to_string().c_str(),
+                   mask.layout().TensorShape::to_string().c_str());
+        dyn_malloc.alloc_output(0, data.layout().dtype, {0}, nullptr);
+        dyn_malloc.alloc_output(1, dtype::Int32(), {0}, nullptr);
+    } else {
+        megdnn_opr()->exec(data.as_megdnn(),
+                           mask.as_megdnn(),
+                           intl::get_megdnn_workspace_from_var(output().back()),
+                           &dyn_malloc);
+    }
 }

 /* ================= TopK ================= */
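CondTake::scn_do_execute is the static-graph execution path, which is what a traced (symbolic) Python function goes through. A hedged sketch of exercising it from Python, mirroring test_condtake(is_symbolic=True) in the test file above and again assuming a build that includes this commit:

import numpy as np

import megengine.functional as F
import megengine.jit as jit
from megengine import tensor

@jit.trace(symbolic=True)
def take(mask, data):
    return F.cond_take(mask, data)

x = tensor(np.random.randn(3, 0, 3).astype("float32"))
val, idx = take(x > 0, x)                    # empty input through the graph path
print(val.numpy().shape, idx.numpy().shape)  # expected: (0,) (0,)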
src/opr/include/megbrain/opr/misc.h
@@ -151,6 +151,7 @@ MGB_DEFINE_OPR_CLASS(CondTake, intl::CondTakeBase) // {
         void init_output_static_infer_desc() override;
         void scn_do_execute() override;
         void add_input_layout_constraint() override;
+        NodeProp* do_make_node_prop() const override;

     public:
         CondTake(VarNode *data, VarNode *mask,
src/opr/test/misc.cpp
@@ -256,20 +256,25 @@ TEST(TestOprMisc, CondTake) {
     run(mki({100}));
 }

-TEST(TestOprMisc, CondTakeEmptyOut) {
+TEST(TestOprMisc, CondTakeEmptyIO) {
     using Param = opr::CondTake::Param;
     HostTensorGenerator<> gen;
-    auto host_x = gen({1});
-    host_x->ptr<float>()[0] = 1;
-    auto graph = ComputingGraph::make();
-    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
-    auto out = opr::CondTake::make(x, x, {Param::Mode::LT});
-    HostTensorND host_out0, host_out1;
-    auto func = graph->compile({make_callback_copy(out[0], host_out0),
-                                make_callback_copy(out[1], host_out1)});
-    func->execute();
-    ASSERT_EQ(TensorShape{0}, host_out0.shape());
-    ASSERT_EQ(TensorShape{0}, host_out1.shape());
+    auto check = [&](const TensorShape& shp) {
+        auto host_x = gen(shp);
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+        auto y = x + 1;
+        auto out = opr::CondTake::make(x, y, {Param::Mode::EQ});
+        HostTensorND host_out0, host_out1;
+        auto func = graph->compile({make_callback_copy(out[0], host_out0),
+                                    make_callback_copy(out[1], host_out1)});
+        func->execute();
+        ASSERT_EQ(TensorShape{0}, host_out0.shape());
+        ASSERT_EQ(TensorShape{0}, host_out1.shape());
+    };
+    check({1});
+    check({0});
+    check({1, 0});
 }

 TEST(TestOprMisc, TopKValueOnly) {