Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
95eb6ae3
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
95eb6ae3
编写于
8月 31, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/opr): let more ops support empty IO
GitOrigin-RevId: 84dddb4b23638b29950e438bba2af8b5fd5166fa
上级
296a2885
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
125 addition
and
23 deletion
+125
-23
dnn/src/common/basic_types.cpp
dnn/src/common/basic_types.cpp
+14
-7
src/opr/impl/tensor_manip.cpp
src/opr/impl/tensor_manip.cpp
+24
-13
src/opr/include/megbrain/opr/tensor_manip.h
src/opr/include/megbrain/opr/tensor_manip.h
+4
-2
src/opr/test/tensor_manip.cpp
src/opr/test/tensor_manip.cpp
+83
-1
未找到文件。
dnn/src/common/basic_types.cpp
浏览文件 @
95eb6ae3
...
...
@@ -392,8 +392,6 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
TensorLayout
result
{
dtype
,
format
};
result
.
ndim
=
tshape
.
ndim
;
for
(
size_t
i
=
0
;
i
<
tshape
.
ndim
;
i
++
)
{
megdnn_throw_if
(
!
tshape
.
shape
[
i
],
tensor_reshape_error
,
megdnn_mangle
(
"target shape is 0"
));
result
.
shape
[
i
]
=
tshape
.
shape
[
i
];
result
.
stride
[
i
]
=
(
tshape
.
shape
[
i
]
==
1
);
}
...
...
@@ -409,8 +407,6 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
for
(
size_t
i
=
0
;
i
<
tshape
.
ndim
;
++
i
)
{
int
target_idx
=
tshape
.
ndim
-
i
-
1
;
int
cur_idx
=
ndim
-
i
-
1
;
megdnn_throw_if
(
!
tshape
.
shape
[
target_idx
],
tensor_reshape_error
,
megdnn_mangle
(
"target shape is 0"
));
size_t
cur_shape
=
(
cur_idx
>=
0
?
shape
[
cur_idx
]
:
1
),
cur_stride
=
(
cur_idx
>=
0
?
stride
[
cur_idx
]
:
0
);
if
(
tshape
.
shape
[
target_idx
]
!=
cur_shape
)
{
...
...
@@ -434,10 +430,16 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
bool
TensorLayout
::
try_reshape
(
TensorLayout
&
result
,
const
TensorShape
&
tshp
)
const
{
megdnn_assert
(
tshp
.
ndim
);
bool
is_empty_shape
=
false
;
for
(
size_t
i
=
0
;
i
<
tshp
.
ndim
;
++
i
)
{
megdnn_throw_if
(
!
tshp
.
shape
[
i
],
tensor_reshape_error
,
if
(
!
tshp
.
shape
[
i
])
{
megdnn_throw_if
(
!
format
.
is_default
(),
tensor_reshape_error
,
megdnn_mangle
(
ssprintf
(
"bad target tshp: %s"
,
tshp
.
to_string
().
c_str
())));
is_empty_shape
=
true
;
break
;
}
}
megdnn_throw_if
(
...
...
@@ -454,6 +456,11 @@ bool TensorLayout::try_reshape(TensorLayout& result,
result
.
format
=
this
->
format
;
result
.
TensorShape
::
operator
=
(
tshp
);
if
(
is_empty_shape
)
{
result
.
init_contiguous_stride
();
return
true
;
}
size_t
sdim
=
0
,
prod
=
1
,
cont_sdim
=
0
;
for
(
size_t
i
=
0
;
i
<
tshp
.
ndim
;
++
i
)
{
megdnn_assert
(
cont_sdim
<
cont
.
ndim
);
...
...
src/opr/impl/tensor_manip.cpp
浏览文件 @
95eb6ae3
...
...
@@ -237,7 +237,8 @@ void GetVarShape::record_execute_deps(ExecDependencyArray& deps) {
void
ReshapeBrdcastHelper
::
reshapebrdcast_init
(
VarNode
*
inp
,
VarNode
*
tshp
)
{
add_input
({
inp
,
tshp
});
add_output
(
None
)
->
dtype
(
inp
->
dtype
());
add_output
(
None
)
->
dtype
(
inp
->
dtype
())
.
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
if
(
reshapebrdcast_output_shape_need_input_shape
())
outshape_by_symvar_enable
(
1
,
1
);
else
...
...
@@ -340,6 +341,14 @@ void ReshapeBrdcastHelper::init_output_static_infer_desc() {
infer_value
});
}
ReshapeBrdcastHelper
::
NodeProp
*
ReshapeBrdcastHelper
::
do_make_node_prop
()
const
{
auto
ret
=
Super
::
do_make_node_prop
();
ret
->
add_dep_type_existing_var
(
input
(
0
),
NodeProp
::
DepType
::
VALUE_ALLOW_EMPTY
);
return
ret
;
}
// f}}}
/* f{{{ ======================= Reshape ======================= */
...
...
@@ -394,7 +403,7 @@ Maybe<TensorLayout> Reshape::reshapebrdcast_get_dest_layout(
}
auto
tot_nr_elem
=
src
.
total_nr_elems
();
actual_tshape
.
shape
[
unspec
]
=
0
;
mgb_throw_if
(
tot_nr_elem
%
rem_nr_elem
,
TensorReshapeError
,
mgb_throw_if
(
!
rem_nr_elem
||
tot_nr_elem
%
rem_nr_elem
,
TensorReshapeError
,
"could not reshape: src=%s tshape=%s unspec_axis=%zd"
,
static_cast
<
const
TensorShape
&>
(
src
).
to_string
().
c_str
(),
actual_tshape
.
to_string
().
c_str
(),
...
...
@@ -484,6 +493,17 @@ void AxisManipOprBase::init_output_static_infer_desc() {
{
SourceType
::
DEP
,
{{
input
(
0
),
DepType
::
VALUE
}},
infer_value
});
}
AxisManipOprBase
::
NodeProp
*
AxisManipOprBase
::
do_make_node_prop
()
const
{
auto
ret
=
Super
::
do_make_node_prop
();
ret
->
add_dep_type_existing_var
(
input
(
0
),
NodeProp
::
DepType
::
VALUE_ALLOW_EMPTY
);
return
ret
;
}
void
AxisManipOprBase
::
axis_manip_init
(
VarNode
*
inp
)
{
add_input
({
inp
});
add_output
(
None
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
}
// f}}}
...
...
@@ -504,8 +524,7 @@ Dimshuffle::Dimshuffle(VarNode *inp, const std::vector<int> &pattern,
mgb_throw_if
(
i
<
-
1
||
i
>=
int
(
ndim
),
GraphError
,
"bad Dimshuffle pattern"
);
}
add_input
({
inp
});
add_output
(
None
);
axis_manip_init
(
inp
);
add_equivalence_component
<
PODHash
<
int
>>
(
m_pattern
.
data
(),
m_pattern
.
size
());
}
...
...
@@ -587,8 +606,7 @@ AxisAddRemove::AxisAddRemove(
{
mgb_throw_if
(
desc
.
empty
(),
GraphError
,
"desc for AxisAddRemove could not be empty"
);
add_input
({
inp
});
add_output
(
None
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
axis_manip_init
(
inp
);
add_equivalence_component
<
PODHash
<
AxisDesc
>>
(
m_desc
.
data
(),
m_desc
.
size
());
}
...
...
@@ -631,13 +649,6 @@ TensorLayout AxisAddRemove::axis_manip_get_output_layout(
return
layout
;
}
AxisAddRemove
::
NodeProp
*
AxisAddRemove
::
do_make_node_prop
()
const
{
auto
ret
=
Super
::
do_make_node_prop
();
ret
->
add_dep_type_existing_var
(
input
(
0
),
NodeProp
::
DepType
::
VALUE_ALLOW_EMPTY
);
return
ret
;
}
#ifdef MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD
(
AxisAddRemove
)
{
MGB_MARK_USED_VAR
(
wrt_idx
);
...
...
src/opr/include/megbrain/opr/tensor_manip.h
浏览文件 @
95eb6ae3
...
...
@@ -92,6 +92,7 @@ MGB_DEFINE_CLS_WITH_SUPER(ReshapeBrdcastHelper,
void
scn_do_execute
()
override
final
;
void
add_input_layout_constraint
()
override
final
;
void
init_output_static_infer_desc
()
override
;
NodeProp
*
do_make_node_prop
()
const
override
;
protected:
using
Super
::
Super
;
...
...
@@ -199,11 +200,14 @@ MGB_DEFINE_CLS_WITH_SUPER(AxisManipOprBase,
void
mem_plan_fwd_in2out_readonly
()
override
final
;
void
scn_do_execute
()
override
final
;
void
init_output_static_infer_desc
()
override
final
;
NodeProp
*
do_make_node_prop
()
const
override
;
protected:
using
Super
::
Super
;
virtual
TensorLayout
axis_manip_get_output_layout
(
const
TensorLayout
&
inp_layout
)
const
=
0
;
void
axis_manip_init
(
VarNode
*
inp
);
};
}
...
...
@@ -319,8 +323,6 @@ MGB_DEFINE_OPR_CLASS(AxisAddRemove, intl::AxisManipOprBase) // {
TensorLayout
axis_manip_get_output_layout
(
const
TensorLayout
&
inp_layout
)
const
override
;
NodeProp
*
do_make_node_prop
()
const
override
;
}
;
namespace
intl
{
...
...
src/opr/test/tensor_manip.cpp
浏览文件 @
95eb6ae3
...
...
@@ -17,6 +17,7 @@
#include "megbrain/opr/io.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/misc.h"
#include "megbrain/utils/arith_helper.h"
using
namespace
mgb
;
...
...
@@ -138,7 +139,7 @@ TEST(TestTensorManip, Reshape) {
auto
&&
dep_map
=
opr0_reshp
.
node
()
->
owner_opr
()
->
node_prop
().
dep_map
();
using
DT
=
cg
::
OperatorNodeBase
::
NodeProp
::
DepType
;
ASSERT_EQ
(
2u
,
dep_map
.
size
());
ASSERT_EQ
(
DT
::
DEV_VALUE
,
dep_map
.
at
(
op
->
input
(
0
)));
ASSERT_EQ
(
DT
::
DEV_VALUE
|
DT
::
VALUE_ALLOW_EMPTY
,
dep_map
.
at
(
op
->
input
(
0
)));
ASSERT_EQ
(
DT
::
HOST_VALUE
,
dep_map
.
at
(
op
->
input
(
1
)));
}
...
...
@@ -318,6 +319,39 @@ TEST(TestTensorManip, ReshapeInferShapeForDynamicInput) {
run
({
23
,
12
,
5
});
}
TEST
(
TestTensorManip
,
ReshapeEmptyShape
)
{
HostTensorGenerator
<>
gen
;
constexpr
size_t
x_length
=
233
;
auto
host_x
=
gen
({
x_length
}),
host_v
=
gen
({
2
,
3
,
3
,
3
});
for
(
size_t
i
=
0
;
i
<
x_length
;
++
i
)
{
host_x
->
ptr
<
float
>
()[
i
]
=
1.
f
;
}
constexpr
auto
INVALID_AXIS
=
opr
::
Reshape
::
Param
::
INVALID_AXIS
;
for
(
auto
unspec_axis
:
{
INVALID_AXIS
,
0
,
1
,
3
})
{
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
TensorShape
tshape
{
2
,
3
,
3
,
3
};
auto
zero_axis
=
unspec_axis
;
if
(
unspec_axis
==
INVALID_AXIS
)
{
tshape
[
zero_axis
=
2
]
=
0
;
}
using
CondTakeMode
=
opr
::
CondTake
::
Param
::
Mode
;
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
x_empty
=
opr
::
CondTake
::
make
(
x
,
x
,
{
CondTakeMode
::
EQ
,
0.
f
})[
0
],
v
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_v
),
x_reshape
=
opr
::
Reshape
::
make
(
x_empty
,
tshape
,
{
unspec_axis
}),
y
=
opr
::
Concat
::
make
({
x_reshape
,
v
},
zero_axis
);
HostTensorND
host_empty
,
host_y
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
x_reshape
,
host_empty
),
make_callback_copy
(
y
,
host_y
)});
func
->
execute
().
wait
();
ASSERT_TRUE
(
host_empty
.
layout
().
is_empty
());
MGB_ASSERT_TENSOR_EQ
(
*
host_v
,
host_y
);
}
}
TEST
(
TestTensorManip
,
ReshapeWithNegativeUnspec
)
{
HostTensorGenerator
<>
gen
;
auto
host_x
=
gen
({
4
,
8
});
...
...
@@ -365,6 +399,26 @@ TEST(TestTensorManip, Broadcast) {
}
}
TEST
(
TestTensorManip
,
BroadcastEmptyShape
)
{
HostTensorGenerator
<>
gen
;
for
(
auto
&&
arg
:
{
std
::
make_pair
(
TensorShape
{
1
},
TensorShape
{
0
}),
{{
1
,
2
,
3
},
{
0
,
2
,
3
}},
{{
2
,
3
},
{
1
,
0
,
2
,
3
}},
{{
1
,
0
,
2
,
3
},
{
4
,
0
,
2
,
3
}},
{{
0
,
1
,
2
,
3
},
{
3
,
0
,
4
,
2
,
3
}}})
{
auto
host_x
=
gen
(
arg
.
first
);
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
y
=
opr
::
Broadcast
::
make
(
x
,
arg
.
second
);
HostTensorND
host_y
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
y
,
host_y
)});
func
->
execute
();
ASSERT_TRUE
(
host_y
.
shape
().
eq_shape
(
arg
.
second
));
}
}
TEST
(
TestTensorManip
,
Dimshuffle
)
{
HostTensorGenerator
<>
gen
;
constexpr
size_t
S0
=
8
,
S1
=
3
;
...
...
@@ -395,6 +449,34 @@ TEST(TestTensorManip, Dimshuffle) {
}
}
TEST
(
TestTensorManip
,
DimshuffleEmptyShape
)
{
HostTensorGenerator
<>
gen
;
for
(
auto
&&
arg
:
{
std
::
make_pair
(
TensorShape
{
3
,
0
},
std
::
vector
<
int
>
{
1
,
-
1
,
0
,
-
1
}),
{{
3
,
1
,
0
,
4
},
{
-
1
,
3
,
-
1
,
0
,
2
}},
{{
2
,
0
,
3
,
0
},
{
1
,
0
,
2
,
3
}}})
{
auto
host_x
=
gen
(
arg
.
first
);
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
y
=
opr
::
Dimshuffle
::
make
(
x
,
arg
.
second
);
HostTensorND
host_y
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
y
,
host_y
)});
func
->
execute
();
auto
&&
y_shape
=
host_y
.
shape
();
for
(
size_t
idx
=
0
;
idx
<
arg
.
second
.
size
();
++
idx
)
{
auto
elem
=
arg
.
second
[
idx
];
if
(
elem
==
-
1
)
{
ASSERT_EQ
(
y_shape
[
idx
],
1u
);
}
else
{
ASSERT_EQ
(
arg
.
first
[
elem
],
y_shape
[
idx
]);
}
}
}
}
TEST
(
TestTensorManip
,
DimshuffleCombined
)
{
using
Checker
=
AutoOprChecker
<
1
,
1
>
;
constexpr
int
RED0
=
2
,
RED1
=
3
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录