Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b74afde8
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b74afde8
编写于
7月 16, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/opr): let reduce support empty IO
GitOrigin-RevId: 88b37123a8fa7f7dafbb1b0c506fb79f1e5a24c4
上级
1af350c6
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
143 addition
and
6 deletion
+143
-6
imperative/python/test/unit/functional/test_math.py
imperative/python/test/unit/functional/test_math.py
+26
-1
imperative/src/impl/proxy_graph.cpp
imperative/src/impl/proxy_graph.cpp
+5
-0
src/core/test/graph/misc.cpp
src/core/test/graph/misc.cpp
+2
-4
src/opr/impl/basic_arith.cpp
src/opr/impl/basic_arith.cpp
+52
-1
src/opr/include/megbrain/opr/basic_arith.h
src/opr/include/megbrain/opr/basic_arith.h
+1
-0
src/opr/test/basic_arith/reduction.cpp
src/opr/test/basic_arith/reduction.cpp
+57
-0
未找到文件。
imperative/python/test/unit/functional/test_math.py
浏览文件 @
b74afde8
...
@@ -13,7 +13,7 @@ import pytest
...
@@ -13,7 +13,7 @@ import pytest
from
utils
import
opr_test
from
utils
import
opr_test
import
megengine.functional
as
F
import
megengine.functional
as
F
from
megengine
import
tensor
from
megengine
import
jit
,
tensor
def
common_test_reduce
(
opr
,
ref_opr
):
def
common_test_reduce
(
opr
,
ref_opr
):
...
@@ -204,3 +204,28 @@ def test_topk(descending, sorted, inp1d, kth_only):
...
@@ -204,3 +204,28 @@ def test_topk(descending, sorted, inp1d, kth_only):
if
not
sorted
:
if
not
sorted
:
values
=
np_sort
(
values
)
values
=
np_sort
(
values
)
np
.
testing
.
assert_equal
(
values
,
np_sort
(
data
)[...,
:
k
])
np
.
testing
.
assert_equal
(
values
,
np_sort
(
data
)[...,
:
k
])
@
pytest
.
mark
.
parametrize
(
"is_trace"
,
[
True
,
False
])
def
test_reduce_on_empty_tensor
(
is_trace
):
dtypes
=
[
np
.
float32
,
np
.
int32
,
np
.
bool
]
inputs
=
[
(
np
.
random
.
random
((
0
,)),
None
),
(
np
.
random
.
random
((
3
,
0
,
2
)),
1
),
(
np
.
random
.
random
((
10
,
10
,
0
,
10
)),
0
),
]
def
run_test
(
fn
,
ref_fn
,
input
,
dtype
,
axis
=
None
,
symbolic
=
False
):
if
is_trace
:
fn
=
jit
.
trace
(
symbolic
=
symbolic
)(
fn
)
for
i
in
range
(
3
):
out
=
fn
(
tensor
(
input
,
dtype
=
dtype
),
axis
=
axis
).
numpy
()
out_ref
=
ref_fn
(
input
.
astype
(
dtype
),
axis
=
axis
)
np
.
testing
.
assert_equal
(
out
,
out_ref
)
for
dtype
in
dtypes
:
for
inp
,
axis
in
inputs
:
run_test
(
F
.
sum
,
np
.
sum
,
inp
,
dtype
,
axis
,
True
)
run_test
(
F
.
sum
,
np
.
sum
,
inp
,
dtype
,
axis
,
False
)
run_test
(
F
.
prod
,
np
.
prod
,
inp
,
dtype
,
axis
,
True
)
run_test
(
F
.
prod
,
np
.
prod
,
inp
,
dtype
,
axis
,
False
)
imperative/src/impl/proxy_graph.cpp
浏览文件 @
b74afde8
...
@@ -84,6 +84,11 @@ public:
...
@@ -84,6 +84,11 @@ public:
auto
&&
dev_tensor
=
tensor
.
dev_tensor
();
auto
&&
dev_tensor
=
tensor
.
dev_tensor
();
var
->
m_comp_node
=
dev_tensor
.
comp_node
();
var
->
m_comp_node
=
dev_tensor
.
comp_node
();
var
->
m_shape
=
dev_tensor
.
shape
();
var
->
m_shape
=
dev_tensor
.
shape
();
if
(
dev_tensor
.
empty
())
{
auto
layout
=
dev_tensor
.
layout
();
layout
.
init_contiguous_stride
();
dev_tensor
.
reset
(
dev_tensor
.
storage
(),
layout
);
}
var
->
m_dev_tensor
=
dev_tensor
;
var
->
m_dev_tensor
=
dev_tensor
;
var
->
m_mem_plan
.
reset_from_owner_var
().
chunk
()
var
->
m_mem_plan
.
reset_from_owner_var
().
chunk
()
.
mem_alloc_status
.
set_from_owner_var
();
.
mem_alloc_status
.
set_from_owner_var
();
...
...
src/core/test/graph/misc.cpp
浏览文件 @
b74afde8
...
@@ -1364,7 +1364,7 @@ TEST(TestGraph, EmptyShapeCheck) {
...
@@ -1364,7 +1364,7 @@ TEST(TestGraph, EmptyShapeCheck) {
using
Param
=
opr
::
CondTake
::
Param
;
using
Param
=
opr
::
CondTake
::
Param
;
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
y
=
opr
::
CondTake
::
make
(
x
,
x
,
{
Param
::
Mode
::
GT
})[
0
],
y
=
opr
::
CondTake
::
make
(
x
,
x
,
{
Param
::
Mode
::
GT
})[
0
],
z
=
opr
::
reduce_
sum
(
y
,
y
.
make_scalar
(
1
));
z
=
opr
::
reduce_
max
(
y
,
y
.
make_scalar
(
1
));
HostTensorND
host_z
;
HostTensorND
host_z
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
z
,
host_z
)});
auto
func
=
graph
->
compile
({
make_callback_copy
(
z
,
host_z
)});
func
->
execute
();
func
->
execute
();
...
@@ -1377,7 +1377,7 @@ TEST(TestGraph, EmptyShapeCheck) {
...
@@ -1377,7 +1377,7 @@ TEST(TestGraph, EmptyShapeCheck) {
func
->
execute
();
func
->
execute
();
}
catch
(
const
MegBrainError
&
exc
)
{
}
catch
(
const
MegBrainError
&
exc
)
{
std
::
string
msg
{
exc
.
what
()};
std
::
string
msg
{
exc
.
what
()};
ASSERT_TRUE
(
msg
.
find
(
"empty
output var
"
)
!=
ASSERT_TRUE
(
msg
.
find
(
"empty
input is not allowed
"
)
!=
std
::
string
::
npos
)
std
::
string
::
npos
)
<<
"bad message "
<<
msg
;
<<
"bad message "
<<
msg
;
throw
;
throw
;
...
@@ -2413,8 +2413,6 @@ TEST(TestMemReuse, ResetEmptyDevTensor) {
...
@@ -2413,8 +2413,6 @@ TEST(TestMemReuse, ResetEmptyDevTensor) {
y
=
opr
::
Reduce
::
make
(
x
,
{
opr
::
Reduce
::
Mode
::
MAX
,
0
});
y
=
opr
::
Reduce
::
make
(
x
,
{
opr
::
Reduce
::
Mode
::
MAX
,
0
});
HostTensorND
host_y
;
HostTensorND
host_y
;
auto
func
=
g
->
compile
({
make_callback_copy
(
y
,
host_y
)});
auto
func
=
g
->
compile
({
make_callback_copy
(
y
,
host_y
)});
auto
&&
recv
=
x
.
node
()
->
owner_graph
()
->
var_receiver_in_current_comp_seq
(
x
.
node
());
ASSERT_TRUE
(
!
recv
.
is_empty_allowed
());
if
(
inp_shp
.
is_empty
())
{
if
(
inp_shp
.
is_empty
())
{
ASSERT_ANY_THROW
(
func
->
execute
().
wait
());
ASSERT_ANY_THROW
(
func
->
execute
().
wait
());
}
else
{
}
else
{
...
...
src/opr/impl/basic_arith.cpp
浏览文件 @
b74afde8
...
@@ -1072,6 +1072,7 @@ class Reduce::KernScheduler {
...
@@ -1072,6 +1072,7 @@ class Reduce::KernScheduler {
m_apply_side_effect
;
m_apply_side_effect
;
std
::
unique_ptr
<
megdnn
::
Elemwise
>
m_elemwise_trans_opr
;
std
::
unique_ptr
<
megdnn
::
Elemwise
>
m_elemwise_trans_opr
;
std
::
unique_ptr
<
megdnn
::
TypeCvt
>
m_typecvt_opr
;
std
::
unique_ptr
<
megdnn
::
TypeCvt
>
m_typecvt_opr
;
std
::
unique_ptr
<
megdnn
::
Fill
>
m_fill_opr
;
DeviceTensorND
m_side_affect_wkspc
;
DeviceTensorND
m_side_affect_wkspc
;
};
};
...
@@ -1338,6 +1339,47 @@ void Reduce::KernScheduler::execute(
...
@@ -1338,6 +1339,47 @@ void Reduce::KernScheduler::execute(
}
}
mgb_assert
(
!
m_kern_param
.
empty
());
mgb_assert
(
!
m_kern_param
.
empty
());
// empty input
if
(
input
.
shape_valid
()
&&
input
.
empty
())
{
auto
mode
=
m_kern_param
[
0
].
kparam
.
mode
;
if
(
!
m_fill_opr
)
{
m_fill_opr
=
intl
::
get_megdnn_handle
(
dest
.
comp_node
())
->
create_operator
<
megdnn
::
Fill
>
();
}
std
::
string
err_msg
;
switch
(
mode
)
{
case
Reduce
::
Mode
::
SUM
:
if
(
!
dest
.
empty
())
{
m_fill_opr
->
param
()
=
0
;
m_fill_opr
->
exec
(
dest
.
as_megdnn
(),
{});
}
break
;
case
Reduce
::
Mode
::
PRODUCT
:
if
(
!
dest
.
empty
())
{
m_fill_opr
->
param
()
=
1
;
m_fill_opr
->
exec
(
dest
.
as_megdnn
(),
{});
}
break
;
case
Reduce
::
Mode
::
MEAN
:
err_msg
=
"mean"
;
break
;
case
Reduce
::
Mode
::
MIN
:
err_msg
=
"min"
;
break
;
case
Reduce
::
Mode
::
MAX
:
err_msg
=
"max"
;
break
;
case
Reduce
::
Mode
::
SUM_SQR
:
err_msg
=
"sum_sqr"
;
break
;
default:
mgb_throw
(
MegBrainError
,
"bad reduce mode"
);
}
if
(
!
err_msg
.
empty
())
{
mgb_throw
(
MegBrainError
,
"empty input is not allowed for reduce mode: %s"
,
err_msg
.
c_str
());
}
return
;
}
mgb_assert
(
input
.
layout
().
is_contiguous
()
&&
mgb_assert
(
input
.
layout
().
is_contiguous
()
&&
input
.
raw_ptr
()
==
m_kern_param
[
0
].
input
.
raw_ptr
&&
input
.
raw_ptr
()
==
m_kern_param
[
0
].
input
.
raw_ptr
&&
dest
.
raw_ptr
()
==
m_kern_param
.
back
().
output
.
raw_ptr
);
dest
.
raw_ptr
()
==
m_kern_param
.
back
().
output
.
raw_ptr
);
...
@@ -1425,7 +1467,9 @@ Reduce::Reduce(VarNode *inp, VarNode *target_shape, const Param ¶m,
...
@@ -1425,7 +1467,9 @@ Reduce::Reduce(VarNode *inp, VarNode *target_shape, const Param ¶m,
mgb_throw
(
GraphError
,
"invalid param data_type: %d"
,
mgb_throw
(
GraphError
,
"invalid param data_type: %d"
,
int
(
param
.
data_type
));
int
(
param
.
data_type
));
}
}
add_output
(
None
)
->
dtype
(
out_dtype
);
add_output
(
None
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
)
.
dtype
(
out_dtype
);
cg
::
add_workspace_output
(
this
);
cg
::
add_workspace_output
(
this
);
add_equivalence_component
<
PODHash
<
Param
>>
(
&
m_param
);
add_equivalence_component
<
PODHash
<
Param
>>
(
&
m_param
);
...
@@ -1703,6 +1747,13 @@ void Reduce::perform(
...
@@ -1703,6 +1747,13 @@ void Reduce::perform(
ksched
.
execute
(
opr
.
get
(),
*
input_contig
,
dest
);
ksched
.
execute
(
opr
.
get
(),
*
input_contig
,
dest
);
}
}
Reduce
::
NodeProp
*
Reduce
::
do_make_node_prop
()
const
{
auto
ret
=
Super
::
do_make_node_prop
();
ret
->
add_dep_type_existing_var
(
input
(
0
),
NodeProp
::
DepType
::
VALUE_ALLOW_EMPTY
);
return
ret
;
}
void
Reduce
::
create_megdnn_opr
()
{
void
Reduce
::
create_megdnn_opr
()
{
set_megdnn_opr
(
intl
::
get_megdnn_handle
(
comp_node
())
->
set_megdnn_opr
(
intl
::
get_megdnn_handle
(
comp_node
())
->
create_operator
<
megdnn
::
Reduce
>
());
create_operator
<
megdnn
::
Reduce
>
());
...
...
src/opr/include/megbrain/opr/basic_arith.h
浏览文件 @
b74afde8
...
@@ -335,6 +335,7 @@ MGB_DEFINE_OPR_CLASS(Reduce, intl::DynamicOutputIfInputDynamic<
...
@@ -335,6 +335,7 @@ MGB_DEFINE_OPR_CLASS(Reduce, intl::DynamicOutputIfInputDynamic<
void
add_input_layout_constraint
()
override
final
;
void
add_input_layout_constraint
()
override
final
;
void
scn_do_execute
()
override
final
;
void
scn_do_execute
()
override
final
;
void
init_output_static_infer_desc
()
override
final
;
void
init_output_static_infer_desc
()
override
final
;
NodeProp
*
do_make_node_prop
()
const
override
;
void
create_megdnn_opr
()
override
;
void
create_megdnn_opr
()
override
;
void
record_execute_deps
(
ExecDependencyArray
&
deps
)
override
;
void
record_execute_deps
(
ExecDependencyArray
&
deps
)
override
;
...
...
src/opr/test/basic_arith/reduction.cpp
浏览文件 @
b74afde8
...
@@ -900,4 +900,61 @@ TEST(TestBasicArithReduction, StaticInferValueDType) {
...
@@ -900,4 +900,61 @@ TEST(TestBasicArithReduction, StaticInferValueDType) {
run_test
(
F16
,
F16
,
ParamType
::
FLOAT_O16xC32
);
run_test
(
F16
,
F16
,
ParamType
::
FLOAT_O16xC32
);
}
}
TEST
(
TestBasicArithReduction
,
EmptyInput
)
{
using
Param
=
opr
::
Reduce
::
Param
;
using
Mode
=
opr
::
Reduce
::
Mode
;
auto
check_allow_empty
=
[](
const
Param
&
param
,
const
TensorShape
&
inpshp
,
double
target_val
)
{
HostTensorGenerator
<>
gen
;
auto
graph
=
ComputingGraph
::
make
();
auto
host_x
=
gen
(
inpshp
);
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
y
=
opr
::
Reduce
::
make
(
x
,
param
,
{});
HostTensorND
host_y
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
y
,
host_y
)});
func
->
execute
().
wait
();
if
(
!
host_y
.
shape
().
is_empty
())
{
size_t
size
=
host_y
.
layout
().
total_nr_elems
();
#define cb(DType) \
if (host_y.layout().dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
auto ptr = host_y.ptr<ctype>(); \
ctype target = static_cast<ctype>(target_val); \
for (size_t i = 0; i < size; ++i) { \
ASSERT_TRUE(ptr[i] == target); \
} \
}
MEGDNN_FOREACH_COMPUTING_DTYPE
(
cb
)
#undef cb
}
else
{
ASSERT_TRUE
(
host_y
.
empty
());
}
};
auto
check_forbid_empty
=
[](
const
Param
&
param
,
const
TensorShape
&
inpshp
)
{
HostTensorGenerator
<>
gen
;
auto
graph
=
ComputingGraph
::
make
();
auto
host_x
=
gen
(
inpshp
);
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
y
=
opr
::
Reduce
::
make
(
x
,
param
,
{});
HostTensorND
host_y
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
y
,
host_y
)});
ASSERT_ANY_THROW
(
func
->
execute
().
wait
());
};
check_allow_empty
({
Mode
::
SUM
,
0
,
{}},
{
0
},
0
);
check_allow_empty
({
Mode
::
SUM
,
-
1
,
{}},
{
2
,
0
,
3
},
0
);
check_allow_empty
({
Mode
::
SUM
,
1
,
{}},
{
2
,
0
,
3
},
0
);
check_allow_empty
({
Mode
::
PRODUCT
,
0
,
{}},
{
0
,
1
,
2
},
1
);
check_allow_empty
({
Mode
::
PRODUCT
,
1
,
{}},
{
0
,
0
,
0
},
1
);
check_allow_empty
({
Mode
::
PRODUCT
,
2
,
{}},
{
0
,
0
,
0
},
1
);
check_forbid_empty
({
Mode
::
MAX
,
0
,
{}},
{
0
});
check_forbid_empty
({
Mode
::
MIN
,
-
1
,
{}},
{
0
,
1
,
2
});
check_forbid_empty
({
Mode
::
MEAN
,
0
,
{}},
{
0
,
0
});
check_forbid_empty
({
Mode
::
SUM_SQR
,
1
,
{}},
{
2
,
1
,
0
});
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录