Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f2e1bb41
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
f2e1bb41
编写于
5月 18, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/opr): let more indexing ops support empty shape
GitOrigin-RevId: db4eba5877293cb1865801e315f53026527f6c6f
上级
a4879fc6
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
145 addition
and
29 deletion
+145
-29
src/opr/impl/indexing.cpp
src/opr/impl/indexing.cpp
+34
-12
src/opr/include/megbrain/opr/indexing.h
src/opr/include/megbrain/opr/indexing.h
+4
-1
src/opr/test/basic_arith/elemwise.cpp
src/opr/test/basic_arith/elemwise.cpp
+5
-16
src/opr/test/indexing.cpp
src/opr/test/indexing.cpp
+87
-0
test/src/helper.cpp
test/src/helper.cpp
+10
-0
test/src/include/megbrain/test/helper.h
test/src/include/megbrain/test/helper.h
+5
-0
未找到文件。
src/opr/impl/indexing.cpp
浏览文件 @
f2e1bb41
...
@@ -226,17 +226,19 @@ void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::record_megdnn_opr(
...
@@ -226,17 +226,19 @@ void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::record_megdnn_opr(
}
}
/* ==================== MultiAxisVecFancyIndexingHelper ==================== */
/* ==================== MultiAxisVecFancyIndexingHelper ==================== */
const
megdnn
::
IndexingMultiAxisVec
::
IndexDesc
&
std
::
pair
<
const
megdnn
::
IndexingMultiAxisVec
::
IndexDesc
&
,
bool
>
intl
::
MultiAxisVecFancyIndexingHelper
::
make_megdnn_index_desc
(
intl
::
MultiAxisVecFancyIndexingHelper
::
make_megdnn_index_desc
(
size_t
inp_ndim
,
bool
warn_all_scalar
)
{
size_t
inp_ndim
,
bool
warn_all_scalar
)
{
auto
&&
index
=
m_megdnn_index_cache
;
auto
&&
index
=
m_megdnn_index_cache
;
index
.
clear
();
index
.
clear
();
bool
is_empty_shape
=
false
;
for
(
auto
i
:
reverse_adaptor
(
m_input2idxonly_axis_indexer
))
{
for
(
auto
i
:
reverse_adaptor
(
m_input2idxonly_axis_indexer
))
{
if
(
i
)
{
if
(
i
)
{
index
.
push_back
({
index
.
push_back
({
i
->
axis
.
get
(
inp_ndim
),
i
->
axis
.
get
(
inp_ndim
),
i
->
idx
.
node
()
->
dev_tensor
().
as_megdnn
()});
i
->
idx
.
node
()
->
dev_tensor
().
as_megdnn
()});
is_empty_shape
|=
index
.
back
().
vec
.
layout
.
is_empty
();
}
}
}
}
...
@@ -264,7 +266,7 @@ intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc(
...
@@ -264,7 +266,7 @@ intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc(
m_scalar_idx_warn_printed
=
true
;
m_scalar_idx_warn_printed
=
true
;
}
}
return
index
;
return
{
index
,
is_empty_shape
}
;
}
}
/* ==================== IndexingMultiAxisVecBase ==================== */
/* ==================== IndexingMultiAxisVecBase ==================== */
...
@@ -272,6 +274,8 @@ template<class Opr>
...
@@ -272,6 +274,8 @@ template<class Opr>
cg
::
OperatorNodeBase
::
NodeProp
*
cg
::
OperatorNodeBase
::
NodeProp
*
IndexingMultiAxisVecBase
<
Opr
>::
do_make_node_prop
()
const
{
IndexingMultiAxisVecBase
<
Opr
>::
do_make_node_prop
()
const
{
auto
prop
=
Super
::
do_make_node_prop
();
auto
prop
=
Super
::
do_make_node_prop
();
// TODO: should also allow input shape is empty if any
// indexer's shape is empty
for
(
auto
i
:
m_input2idxonly_axis_indexer
)
{
for
(
auto
i
:
m_input2idxonly_axis_indexer
)
{
if
(
i
)
{
if
(
i
)
{
prop
->
add_dep_type_existing_var
(
prop
->
add_dep_type_existing_var
(
...
@@ -360,13 +364,13 @@ void IndexingMultiAxisVecBase<Opr>::scn_do_execute() {
...
@@ -360,13 +364,13 @@ void IndexingMultiAxisVecBase<Opr>::scn_do_execute() {
auto
&&
index_desc
=
make_megdnn_index_desc
(
auto
&&
index_desc
=
make_megdnn_index_desc
(
inp
.
layout
().
ndim
,
ShouldWarnOnScalarIndexer
<
Opr
>::
val
);
inp
.
layout
().
ndim
,
ShouldWarnOnScalarIndexer
<
Opr
>::
val
);
auto
&&
odev
=
output
(
0
)
->
dev_tensor
();
auto
&&
odev
=
output
(
0
)
->
dev_tensor
();
if
(
index_desc
.
empty
())
{
if
(
index_desc
.
first
.
empty
())
{
odev
.
copy_from_fixlayout
(
inp
);
odev
.
copy_from_fixlayout
(
inp
);
}
else
{
}
else
{
if
(
index_desc
[
0
].
vec
.
layout
[
0
]
)
{
if
(
!
index_desc
.
second
)
{
// only call megdnn exec if result is not empty
// only call megdnn exec if result is not empty
this
->
megdnn_opr
(
*
this
).
exec
(
this
->
megdnn_opr
(
*
this
).
exec
(
inp
.
as_megdnn
(),
index_desc
,
odev
.
as_megdnn
(),
inp
.
as_megdnn
(),
index_desc
.
first
,
odev
.
as_megdnn
(),
intl
::
get_megdnn_workspace_from_var
(
output
(
1
)));
intl
::
get_megdnn_workspace_from_var
(
output
(
1
)));
}
else
{
}
else
{
mgb_assert
(
odev
.
empty
());
mgb_assert
(
odev
.
empty
());
...
@@ -391,7 +395,11 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() {
...
@@ -391,7 +395,11 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() {
auto
inp
=
this
->
fancy_indexing_get_tensors_for_modify_in_scn_do_execute
();
auto
inp
=
this
->
fancy_indexing_get_tensors_for_modify_in_scn_do_execute
();
auto
index_desc
=
this
->
make_megdnn_index_desc
(
auto
index_desc
=
this
->
make_megdnn_index_desc
(
inp
.
first
.
layout
().
ndim
,
ShouldWarnOnScalarIndexer
<
Opr
>::
val
);
inp
.
first
.
layout
().
ndim
,
ShouldWarnOnScalarIndexer
<
Opr
>::
val
);
if
(
index_desc
.
empty
())
{
if
(
index_desc
.
second
){
mgb_assert
(
inp
.
second
.
shape
().
is_empty
());
return
;
}
if
(
index_desc
.
first
.
empty
())
{
using
IMT
=
IndexingModifyType
;
using
IMT
=
IndexingModifyType
;
static
constexpr
auto
modify_type
=
static
constexpr
auto
modify_type
=
IndexingModifyTypeGetter
<
Opr
>::
value
;
IndexingModifyTypeGetter
<
Opr
>::
value
;
...
@@ -410,11 +418,28 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() {
...
@@ -410,11 +418,28 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() {
}
else
{
}
else
{
this
->
megdnn_opr
(
*
this
).
exec
(
this
->
megdnn_opr
(
*
this
).
exec
(
inp
.
first
.
as_megdnn
(),
inp
.
second
.
as_megdnn
(),
inp
.
first
.
as_megdnn
(),
inp
.
second
.
as_megdnn
(),
index_desc
,
index_desc
.
first
,
intl
::
get_megdnn_workspace_from_var
(
output
(
1
)));
intl
::
get_megdnn_workspace_from_var
(
output
(
1
)));
}
}
}
}
template
<
class
Opr
>
cg
::
OperatorNodeBase
::
NodeProp
*
intl
::
IndexingModifyMultiAxisVecHelper
<
Opr
>::
do_make_node_prop
()
const
{
auto
prop
=
Super
::
do_make_node_prop
();
using
DT
=
NodeProp
::
DepType
;
// TODO: should also allow input shape is empty if any
// indexer's shape is empty
prop
->
add_dep_type_existing_var
(
input
(
1
),
DT
::
VALUE_ALLOW_EMPTY
);
for
(
auto
i
:
m_input2idxonly_axis_indexer
)
{
if
(
i
)
{
prop
->
add_dep_type_existing_var
(
i
->
idx
.
node
(),
DT
::
VALUE_ALLOW_EMPTY
);
}
}
return
prop
;
}
template
<
class
Opr
>
template
<
class
Opr
>
void
intl
::
IndexingModifyMultiAxisVecHelper
<
Opr
>::
void
intl
::
IndexingModifyMultiAxisVecHelper
<
Opr
>::
add_input_layout_constraint
()
{
add_input_layout_constraint
()
{
...
@@ -429,7 +454,6 @@ add_input_layout_constraint() {
...
@@ -429,7 +454,6 @@ add_input_layout_constraint() {
MGB_IMPL_FANCY_INDEXING_OPR_GET
(
MGB_IMPL_FANCY_INDEXING_OPR_GET
(
IndexingMultiAxisVec
,
"indexing_multi_axis_vec"
,
false
,
IndexingMultiAxisVec
,
"indexing_multi_axis_vec"
,
false
,
output
(
0
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
output
(
0
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
output
(
1
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
);
);
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY
(
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY
(
IndexingSetMultiAxisVec
,
"indexing_set_multi_axis_vec"
,
false
);
IndexingSetMultiAxisVec
,
"indexing_set_multi_axis_vec"
,
false
);
...
@@ -469,12 +493,10 @@ MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) {
...
@@ -469,12 +493,10 @@ MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) {
MGB_IMPL_FANCY_INDEXING_OPR_GET
(
MGB_IMPL_FANCY_INDEXING_OPR_GET
(
MeshIndexing
,
"mesh_indexing"
,
false
,
MeshIndexing
,
"mesh_indexing"
,
false
,
output
(
0
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
output
(
0
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
););
output
(
1
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
););
MGB_IMPL_FANCY_INDEXING_OPR_GET
(
MGB_IMPL_FANCY_INDEXING_OPR_GET
(
BatchedMeshIndexing
,
"batched_mesh_indexing"
,
false
,
BatchedMeshIndexing
,
"batched_mesh_indexing"
,
false
,
output
(
0
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
output
(
0
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
););
output
(
1
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
););
MGB_IMPL_OPR_GRAD
(
MeshIndexing
)
{
MGB_IMPL_OPR_GRAD
(
MeshIndexing
)
{
if
(
wrt_idx
!=
0
)
{
if
(
wrt_idx
!=
0
)
{
...
...
src/opr/include/megbrain/opr/indexing.h
浏览文件 @
f2e1bb41
...
@@ -117,7 +117,9 @@ namespace intl {
...
@@ -117,7 +117,9 @@ namespace intl {
protected:
protected:
using
Super
::
Super
;
using
Super
::
Super
;
const
megdnn
::
IndexingMultiAxisVec
::
IndexDesc
&
//! return IndexDesc and whether it has an AxisIndexer with
//! empty shape
std
::
pair
<
const
megdnn
::
IndexingMultiAxisVec
::
IndexDesc
&
,
bool
>
make_megdnn_index_desc
(
make_megdnn_index_desc
(
size_t
inp_ndim
,
bool
warn_all_scalar
=
true
);
size_t
inp_ndim
,
bool
warn_all_scalar
=
true
);
};
};
...
@@ -130,6 +132,7 @@ namespace intl {
...
@@ -130,6 +132,7 @@ namespace intl {
void
init_output_static_infer_desc
()
override
final
;
void
init_output_static_infer_desc
()
override
final
;
void
scn_do_execute
()
override
final
;
void
scn_do_execute
()
override
final
;
NodeProp
*
do_make_node_prop
()
const
override
;
void
add_input_layout_constraint
()
override
final
;
void
add_input_layout_constraint
()
override
final
;
protected
:
protected
:
...
...
src/opr/test/basic_arith/elemwise.cpp
浏览文件 @
f2e1bb41
...
@@ -649,17 +649,6 @@ namespace {
...
@@ -649,17 +649,6 @@ namespace {
>
TernaryTraitTypes
;
>
TernaryTraitTypes
;
TYPED_TEST_CASE
(
TestOprBasicArithTernaryElemwise
,
TernaryTraitTypes
);
TYPED_TEST_CASE
(
TestOprBasicArithTernaryElemwise
,
TernaryTraitTypes
);
::
testing
::
AssertionResult
assert_shape_equal
(
const
TensorShape
&
v0
,
const
TensorShape
&
v1
)
{
if
(
v0
.
eq_shape
(
v1
))
return
::
testing
::
AssertionSuccess
()
<<
v0
.
to_string
()
<<
" == "
<<
v1
.
to_string
();
else
return
::
testing
::
AssertionFailure
()
<<
v0
.
to_string
()
<<
" != "
<<
v1
.
to_string
();
}
#define ASSERT_SHAPE_EQ(v0, v1) ASSERT_TRUE(assert_shape_equal(v0, v1))
}
// anonymous namespace
}
// anonymous namespace
template
<
typename
Trait
,
typename
dtype
>
template
<
typename
Trait
,
typename
dtype
>
...
@@ -974,7 +963,7 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputUnary) {
...
@@ -974,7 +963,7 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputUnary) {
ASSERT_NO_THROW
(
func
->
execute
().
wait
());
ASSERT_NO_THROW
(
func
->
execute
().
wait
());
ASSERT_TRUE
(
host_y
.
empty
());
ASSERT_TRUE
(
host_y
.
empty
());
ASSERT_TRUE
(
host_y
.
shape
().
is_empty
());
ASSERT_TRUE
(
host_y
.
shape
().
is_empty
());
ASSERT_SHAPE_EQ
(
host_y
.
shape
(),
TensorShape
({
3
,
0
,
1
,
3
}));
MGB_
ASSERT_SHAPE_EQ
(
host_y
.
shape
(),
TensorShape
({
3
,
0
,
1
,
3
}));
}
}
TEST
(
TestOprBasicArithElemwise
,
EmptyInputOutputBinary
)
{
TEST
(
TestOprBasicArithElemwise
,
EmptyInputOutputBinary
)
{
...
@@ -997,14 +986,14 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) {
...
@@ -997,14 +986,14 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) {
ASSERT_NO_THROW
(
func
->
execute
().
wait
());
ASSERT_NO_THROW
(
func
->
execute
().
wait
());
ASSERT_TRUE
(
host_z
.
empty
());
ASSERT_TRUE
(
host_z
.
empty
());
ASSERT_TRUE
(
host_z
.
shape
().
is_empty
());
ASSERT_TRUE
(
host_z
.
shape
().
is_empty
());
ASSERT_SHAPE_EQ
(
host_z
.
shape
(),
TensorShape
({
0
,
8
,
0
,
7
}));
MGB_
ASSERT_SHAPE_EQ
(
host_z
.
shape
(),
TensorShape
({
0
,
8
,
0
,
7
}));
// Broadcast to 0 (2)
// Broadcast to 0 (2)
host_y
->
resize
({
2
,
8
,
1
,
7
});
host_y
->
resize
({
2
,
8
,
1
,
7
});
ASSERT_NO_THROW
(
func
->
execute
().
wait
());
ASSERT_NO_THROW
(
func
->
execute
().
wait
());
ASSERT_TRUE
(
host_z
.
empty
());
ASSERT_TRUE
(
host_z
.
empty
());
ASSERT_TRUE
(
host_z
.
shape
().
is_empty
());
ASSERT_TRUE
(
host_z
.
shape
().
is_empty
());
ASSERT_SHAPE_EQ
(
host_z
.
shape
(),
TensorShape
({
0
,
8
,
1
,
7
}));
MGB_
ASSERT_SHAPE_EQ
(
host_z
.
shape
(),
TensorShape
({
0
,
8
,
1
,
7
}));
// Scalar broadcast
// Scalar broadcast
z
=
x
+
x
.
make_scalar
(
1.
f
);
z
=
x
+
x
.
make_scalar
(
1.
f
);
...
@@ -1012,7 +1001,7 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) {
...
@@ -1012,7 +1001,7 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) {
ASSERT_NO_THROW
(
func
->
execute
().
wait
());
ASSERT_NO_THROW
(
func
->
execute
().
wait
());
ASSERT_TRUE
(
host_z
.
empty
());
ASSERT_TRUE
(
host_z
.
empty
());
ASSERT_TRUE
(
host_z
.
shape
().
is_empty
());
ASSERT_TRUE
(
host_z
.
shape
().
is_empty
());
ASSERT_SHAPE_EQ
(
host_z
.
shape
(),
TensorShape
({
0
,
8
,
1
,
7
}));
MGB_
ASSERT_SHAPE_EQ
(
host_z
.
shape
(),
TensorShape
({
0
,
8
,
1
,
7
}));
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
src/opr/test/indexing.cpp
浏览文件 @
f2e1bb41
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/utility.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h"
#include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h"
#include "megbrain/test/megdnn_helper.h"
...
@@ -1195,6 +1196,92 @@ TEST(TestOprIndexing, SetMeshIndexing) {
...
@@ -1195,6 +1196,92 @@ TEST(TestOprIndexing, SetMeshIndexing) {
}
}
}
}
namespace
{
template
<
class
Opr
>
void
run_multi_axis_vec_empty_shape
(
const
TensorShape
&
ishp
,
const
TensorShape
&
idx0
,
const
TensorShape
&
idx1
,
const
TensorShape
&
tshp
)
{
mgb_assert
(
ishp
.
ndim
>=
4
);
mgb_assert
(
idx0
.
is_empty
()
||
idx1
.
is_empty
());
using
AI
=
opr
::
indexing
::
AxisIndexer
;
auto
graph
=
ComputingGraph
::
make
();
HostTensorGenerator
<>
gen
;
HostTensorGenerator
<
dtype
::
Int32
>
gen_idx
;
auto
host_x
=
gen
(
ishp
);
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
idx_dynamic_shape
=
opr
::
MarkDynamicVar
::
make
(
opr
::
ImmutableTensor
::
make
(
*
graph
,
*
gen_idx
(
idx0
))),
idx_static_shape
=
opr
::
ImmutableTensor
::
make
(
*
graph
,
*
gen_idx
(
idx1
)),
y
=
Opr
::
make
(
x
,
{
AI
::
make_interval
(
-
1
,
None
,
None
,
x
.
make_scalar
(
2
)),
AI
::
make_index
(
1
,
idx_dynamic_shape
),
AI
::
make_index
(
2
,
idx_static_shape
)});
HostTensorND
host_y
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
y
,
host_y
)});
func
->
execute
();
ASSERT_TRUE
(
host_y
.
shape
().
is_empty
());
MGB_ASSERT_SHAPE_EQ
(
host_y
.
shape
(),
tshp
);
}
template
<
class
Opr
>
void
run_modify_multi_axis_vec_empty_shape
(
const
TensorShape
&
ishp
,
const
TensorShape
&
vshp
,
const
TensorShape
&
idx0
,
const
TensorShape
&
idx1
)
{
mgb_assert
(
ishp
.
ndim
>=
4
);
mgb_assert
(
vshp
.
is_empty
()
&&
(
idx0
.
is_empty
()
||
idx1
.
is_empty
()));
using
AI
=
opr
::
indexing
::
AxisIndexer
;
auto
graph
=
ComputingGraph
::
make
();
HostTensorGenerator
<>
gen
;
HostTensorGenerator
<
dtype
::
Int32
>
gen_idx
;
auto
host_x
=
gen
(
ishp
),
host_v
=
gen
(
vshp
);
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
v
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_v
),
idx_dynamic_shape
=
opr
::
MarkDynamicVar
::
make
(
opr
::
ImmutableTensor
::
make
(
*
graph
,
*
gen_idx
(
idx0
))),
idx_static_shape
=
opr
::
ImmutableTensor
::
make
(
*
graph
,
*
gen_idx
(
idx1
)),
y
=
Opr
::
make
(
x
,
v
,
{
AI
::
make_interval
(
-
1
,
None
,
None
,
x
.
make_scalar
(
2
)),
AI
::
make_index
(
1
,
idx_dynamic_shape
),
AI
::
make_index
(
2
,
idx_static_shape
)});
HostTensorND
host_y
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
y
,
host_y
)});
func
->
execute
();
MGB_ASSERT_TENSOR_EQ
(
*
host_x
,
host_y
);
}
}
TEST
(
TestOprIndexing
,
MultiAxisVecEmptyShape
)
{
TensorShape
ishp
{
8
,
2
,
3
,
4
};
size_t
n
=
ishp
[
0
],
last_ndim
=
ishp
[
ishp
.
ndim
-
1
]
/
2
;
run_multi_axis_vec_empty_shape
<
opr
::
IndexingMultiAxisVec
>
(
ishp
,
{
0
},
{
0
},
{
n
,
0
,
last_ndim
});
run_multi_axis_vec_empty_shape
<
opr
::
MeshIndexing
>
(
ishp
,
{
0
},
{
2
},
{
n
,
0
,
2
,
last_ndim
});
run_multi_axis_vec_empty_shape
<
opr
::
MeshIndexing
>
(
ishp
,
{
3
},
{
0
},
{
n
,
3
,
0
,
last_ndim
});
run_multi_axis_vec_empty_shape
<
opr
::
BatchedMeshIndexing
>
(
ishp
,
{
n
,
0
},
{
n
,
2
},
{
n
,
0
,
2
,
last_ndim
});
run_multi_axis_vec_empty_shape
<
opr
::
BatchedMeshIndexing
>
(
ishp
,
{
n
,
4
},
{
n
,
0
},
{
n
,
4
,
0
,
last_ndim
});
run_modify_multi_axis_vec_empty_shape
<
opr
::
IndexingIncrMultiAxisVec
>
(
ishp
,
{
n
,
0
,
last_ndim
},
{
0
},
{
0
});
run_modify_multi_axis_vec_empty_shape
<
opr
::
IndexingSetMultiAxisVec
>
(
ishp
,
{
n
,
0
,
last_ndim
},
{
0
},
{
0
});
run_modify_multi_axis_vec_empty_shape
<
opr
::
IncrMeshIndexing
>
(
ishp
,
{
n
,
0
,
2
,
last_ndim
},
{
0
},
{
2
});
run_modify_multi_axis_vec_empty_shape
<
opr
::
SetMeshIndexing
>
(
ishp
,
{
n
,
3
,
0
,
last_ndim
},
{
3
},
{
0
});
run_modify_multi_axis_vec_empty_shape
<
opr
::
BatchedIncrMeshIndexing
>
(
ishp
,
{
n
,
4
,
0
,
last_ndim
},
{
n
,
4
},
{
n
,
0
});
run_modify_multi_axis_vec_empty_shape
<
opr
::
BatchedSetMeshIndexing
>
(
ishp
,
{
n
,
0
,
5
,
last_ndim
},
{
n
,
0
},
{
n
,
5
});
}
#endif // MGB_ENABLE_EXCEPTION
#endif // MGB_ENABLE_EXCEPTION
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
test/src/helper.cpp
浏览文件 @
f2e1bb41
...
@@ -158,6 +158,16 @@ namespace mgb {
...
@@ -158,6 +158,16 @@ namespace mgb {
return
::
testing
::
AssertionSuccess
();
return
::
testing
::
AssertionSuccess
();
}
}
::
testing
::
AssertionResult
mgb
::
__assert_shape_equal
(
const
TensorShape
&
v0
,
const
TensorShape
&
v1
)
{
if
(
v0
.
eq_shape
(
v1
))
return
::
testing
::
AssertionSuccess
()
<<
v0
.
to_string
()
<<
" == "
<<
v1
.
to_string
();
else
return
::
testing
::
AssertionFailure
()
<<
v0
.
to_string
()
<<
" != "
<<
v1
.
to_string
();
}
#if WIN32
#if WIN32
#include <io.h>
#include <io.h>
#include <fcntl.h>
#include <fcntl.h>
...
...
test/src/include/megbrain/test/helper.h
浏览文件 @
f2e1bb41
...
@@ -133,6 +133,11 @@ decltype(auto) container_to_vector(Container &&ct) {
...
@@ -133,6 +133,11 @@ decltype(auto) container_to_vector(Container &&ct) {
#define MGB_ASSERT_TENSOR_EQ(v0, v1) \
#define MGB_ASSERT_TENSOR_EQ(v0, v1) \
MGB_ASSERT_TENSOR_NEAR(v0, v1, 1e-6)
MGB_ASSERT_TENSOR_NEAR(v0, v1, 1e-6)
::
testing
::
AssertionResult
__assert_shape_equal
(
const
TensorShape
&
v0
,
const
TensorShape
&
v1
);
#define MGB_ASSERT_SHAPE_EQ(v0, v1) \
ASSERT_TRUE(::mgb::__assert_shape_equal(v0, v1))
/*!
/*!
* \brief xorshift+ RNG, which is very fast
* \brief xorshift+ RNG, which is very fast
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录