Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
1a1748da
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
1a1748da
编写于
9月 09, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(opr): let Argsort support empty IO
GitOrigin-RevId: 05fcac6e472e9d7e868516c210c96d8d2987b5dc
上级
7234efe1
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
72 addition
and
5 deletion
+72
-5
imperative/python/test/unit/functional/test_math.py
imperative/python/test/unit/functional/test_math.py
+30
-4
src/opr/impl/misc.cpp
src/opr/impl/misc.cpp
+36
-1
src/opr/include/megbrain/opr/misc.h
src/opr/include/megbrain/opr/misc.h
+6
-0
未找到文件。
imperative/python/test/unit/functional/test_math.py
浏览文件 @
1a1748da
...
...
@@ -110,16 +110,42 @@ def test_sort():
data2_shape
=
(
12
,
2
)
data1
=
np
.
random
.
random
(
data1_shape
).
astype
(
np
.
float32
)
data2
=
np
.
random
.
random
(
data2_shape
).
astype
(
np
.
float32
)
output
0
=
[
np
.
sort
(
data1
),
np
.
argsort
(
data1
).
astype
(
np
.
int32
)]
output
1
=
[
np
.
sort
(
data2
),
np
.
argsort
(
data2
).
astype
(
np
.
int32
)]
output
1
=
[
np
.
sort
(
data1
),
np
.
argsort
(
data1
).
astype
(
np
.
int32
)]
output
2
=
[
np
.
sort
(
data2
),
np
.
argsort
(
data2
).
astype
(
np
.
int32
)]
cases
=
[
{
"input"
:
data1
,
"output"
:
output
0
},
{
"input"
:
data2
,
"output"
:
output
1
},
{
"input"
:
data1
,
"output"
:
output
1
},
{
"input"
:
data2
,
"output"
:
output
2
},
]
opr_test
(
cases
,
F
.
sort
)
@
pytest
.
mark
.
parametrize
(
"is_symbolic"
,
[
None
,
False
,
True
])
def
test_sort_empty
(
is_symbolic
):
data_shapes
=
[
(
0
,),
(
10
,
0
),
]
def
fn
(
x
):
return
F
.
sort
(
x
)
for
shape
in
data_shapes
:
if
is_symbolic
is
not
None
:
fn_
=
jit
.
trace
(
symbolic
=
is_symbolic
)(
fn
)
else
:
fn_
=
fn
data
=
np
.
random
.
random
(
shape
).
astype
(
np
.
float32
)
for
_
in
range
(
3
):
outs
=
fn_
(
tensor
(
data
))
ref_outs
=
(
np
.
sort
(
data
),
np
.
argsort
(
data
))
assert
len
(
ref_outs
)
==
len
(
outs
)
for
i
in
range
(
len
(
outs
)):
np
.
testing
.
assert_equal
(
outs
[
i
].
numpy
(),
ref_outs
[
i
])
if
is_symbolic
is
None
:
break
def
test_normalize
():
cases
=
[
...
...
src/opr/impl/misc.cpp
浏览文件 @
1a1748da
...
...
@@ -75,7 +75,16 @@ MEGDNN_OPR_INIT1(Argmin, "argmin")
/* ================= ArgsortForward ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
ArgsortForward
);
MEGDNN_OPR_CTOR_INIT1
(
ArgsortForward
,
"argsort"
)
// MEGDNN_OPR_CTOR_INIT1(ArgsortForward, "argsort")
ArgsortForward
::
ArgsortForward
(
VarNode
*
i0
,
const
Param
&
param
,
const
OperatorNodeConfig
&
config
)
:
Super
(
OperatorNodeBaseCtorParam
{
i0
->
owner_graph
(),
config
,
"argsort"
,
{
i0
}}
)
{
init_megdnn_opr
(
*
this
,
param
);
add_input
({
i0
});
output
(
0
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
// sorted value
output
(
1
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
// sorted index
intl
::
MegDNNOprInitPostCtor
<
ArgsortForward
>::
apply
(
*
this
);
}
std
::
array
<
SymbolVar
,
2
>
ArgsortForward
::
make
(
SymbolVar
in_tensor
,
const
Param
&
param
,
...
...
@@ -87,6 +96,32 @@ std::array<SymbolVar, 2> ArgsortForward::make(
return
{
node
->
output
(
0
),
node
->
output
(
1
)};
}
void
ArgsortForward
::
scn_do_execute
()
{
if
(
input
(
0
)
->
dev_tensor
().
empty
())
{
mgb_assert
(
output
(
0
)
->
dev_tensor
().
empty
()
&&
output
(
1
)
->
dev_tensor
().
empty
());
return
;
}
mgb_assert
(
!
output
(
0
)
->
dev_tensor
().
empty
()
&&
!
output
(
1
)
->
dev_tensor
().
empty
());
Super
::
scn_do_execute
();
}
void
ArgsortForward
::
get_output_var_shape
(
const
TensorShapeArray
&
inp_shape
,
TensorShapeArray
&
out_shape
)
const
{
mgb_assert
(
inp_shape
.
size
()
==
1
&&
out_shape
.
size
()
==
2
);
out_shape
[
0
]
=
inp_shape
[
0
];
out_shape
[
1
]
=
inp_shape
[
0
];
}
ArgsortForward
::
NodeProp
*
ArgsortForward
::
do_make_node_prop
()
const
{
auto
ret
=
Super
::
do_make_node_prop
();
ret
->
add_dep_type_existing_var
(
input
(
0
),
NodeProp
::
DepType
::
VALUE_ALLOW_EMPTY
);
return
ret
;
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD
(
ArgsortForward
)
{
mgb_assert
(
out_grad
.
size
()
==
3
&&
wrt_idx
==
0
&&
!
out_grad
[
2
]);
...
...
src/opr/include/megbrain/opr/misc.h
浏览文件 @
1a1748da
...
...
@@ -55,6 +55,12 @@ MGB_DEFINE_OPR_CLASS(Argmin,
*/
MGB_DEFINE_OPR_CLASS
(
ArgsortForward
,
intl
::
MegDNNOprWrapperFwd
<
megdnn
::
ArgsortForward
>
)
// {
protected
:
NodeProp
*
do_make_node_prop
()
const
override
;
void
scn_do_execute
()
override
;
void
get_output_var_shape
(
const
TensorShapeArray
&
inp_shape
,
TensorShapeArray
&
out_shape
)
const
override
;
public
:
ArgsortForward
(
VarNode
*
in_tensor
,
const
Param
&
param
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录