Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
dea52781
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
dea52781
编写于
6月 25, 2021
作者:
M
Megvii Engine Team
提交者:
huangxinda
7月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/opr): let PowC & TypeCvt support empty IO
GitOrigin-RevId: f97b3005fd3d60c7c8d1159672debf3a5e30cc64
上级
2f68aeb9
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
61 addition
and
1 deletion
+61
-1
src/opr/impl/basic_arith.cpp
src/opr/impl/basic_arith.cpp
+27
-1
src/opr/include/megbrain/opr/basic_arith.h
src/opr/include/megbrain/opr/basic_arith.h
+2
-0
src/opr/test/basic_arith/others.cpp
src/opr/test/basic_arith/others.cpp
+32
-0
未找到文件。
src/opr/impl/basic_arith.cpp
浏览文件 @
dea52781
...
...
@@ -776,6 +776,10 @@ void TypeCvt::perform(DeviceTensorND &dest,
intl
::
UniqPtrWithCN
<
megdnn
::
TypeCvt
>
&
opr
)
{
mgb_assert
(
src
.
comp_node
()
==
opr
.
comp_node
());
mgb_assert
(
dest_type
.
valid
());
if
(
src
.
empty
())
{
mgb_assert
(
dest
.
empty
());
return
;
}
if
(
src
.
dtype
()
==
dest_type
)
{
dest
.
copy_from
(
src
);
return
;
...
...
@@ -1739,7 +1743,13 @@ void Reduce::record_execute_deps(ExecDependencyArray& deps) {
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
PowC
);
MEGDNN_OPR_CTOR_INIT1
(
PowC
,
ssprintf
(
"powc_%g"
,
param
.
exp
))
PowC
::
PowC
(
VarNode
*
i0
,
const
Param
&
param
,
const
OperatorNodeConfig
&
config
)
:
Super
(
OperatorNodeBaseCtorParam
{
i0
->
owner_graph
(),
config
,
ssprintf
(
"powc_%g"
,
param
.
exp
),
{
i0
}}
)
{
init_megdnn_opr
(
*
this
,
param
);
add_input
({
i0
});
output
(
0
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
intl
::
MegDNNOprInitPostCtor
<
PowC
>::
apply
(
*
this
);
}
SymbolVar
PowC
::
make
(
SymbolVar
x
,
const
Param
&
param
,
const
OperatorNodeConfig
&
config
)
{
...
...
@@ -1778,6 +1788,22 @@ void PowC::init_output_static_infer_desc() {
{
SourceType
::
DEP
,
{{
input
(
0
),
DepType
::
VALUE
}},
infer_value
});
}
void
PowC
::
scn_do_execute
()
{
if
(
input
(
0
)
->
dev_tensor
().
empty
())
{
mgb_assert
(
output
(
0
)
->
dev_tensor
().
empty
());
return
;
}
mgb_assert
(
!
output
(
0
)
->
dev_tensor
().
empty
());
Super
::
scn_do_execute
();
}
PowC
::
NodeProp
*
PowC
::
do_make_node_prop
()
const
{
auto
ret
=
Super
::
do_make_node_prop
();
ret
->
add_dep_type_existing_var
(
input
(
0
),
NodeProp
::
DepType
::
VALUE_ALLOW_EMPTY
);
return
ret
;
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD
(
PowC
)
{
auto
exp
=
opr
.
param
().
exp
;
...
...
src/opr/include/megbrain/opr/basic_arith.h
浏览文件 @
dea52781
...
...
@@ -352,6 +352,8 @@ MGB_DEFINE_OPR_CLASS(PowC, intl::MegDNNOprWrapperFwd<megdnn::PowC>) // {
void
add_input_layout_constraint
()
override
;
void
init_output_static_infer_desc
()
override
;
void
mem_plan_fwd_in2out_writable
()
override
;
NodeProp
*
do_make_node_prop
()
const
override
;
void
scn_do_execute
()
override
;
public
:
PowC
(
VarNode
*
inp
,
const
Param
&
param
,
const
OperatorNodeConfig
&
config
);
...
...
src/opr/test/basic_arith/others.cpp
浏览文件 @
dea52781
...
...
@@ -589,6 +589,23 @@ TEST(TestOprBasicArith, TypeCvtFromBool) {
ASSERT_EQ
(
TensorShape
({
2
}),
host_y
.
shape
());
}
TEST
(
TestOprBasicArith
,
TypeCvtPerformEmptyIO
)
{
HostTensorGenerator
<>
gen
;
auto
cn
=
CompNode
::
load
(
"xpu0"
);
auto
host_x
=
gen
({
2
,
0
,
3
,
4
});
auto
dev_x
=
std
::
make_shared
<
DeviceTensorND
>
(
cn
);
dev_x
->
copy_from
(
*
host_x
);
auto
dev_y
=
std
::
make_shared
<
DeviceTensorND
>
(
cn
,
dtype
::
Int32
{});
dev_y
->
resize
(
dev_x
->
shape
());
auto
dnn_opr
=
opr
::
intl
::
create_megdnn_opr
<
megdnn
::
TypeCvt
>
(
cn
);
ASSERT_NO_THROW
(
opr
::
TypeCvt
::
perform
(
*
dev_y
,
dtype
::
Int32
{},
*
dev_x
,
dnn_opr
));
ASSERT_TRUE
(
dev_y
->
empty
());
ASSERT_TRUE
(
dev_y
->
shape
().
is_empty
());
MGB_ASSERT_SHAPE_EQ
(
dev_x
->
shape
(),
dev_y
->
shape
());
}
TEST
(
TestOprBasicArith
,
ElemwiseMemFwd
)
{
auto
graph
=
ComputingGraph
::
make
();
graph
->
options
().
graph_opt_level
=
0
;
...
...
@@ -756,4 +773,19 @@ TEST(TestOprBasicArith, PowCInfer) {
run
(
true
);
}
TEST
(
TestOprBasicArith
,
PowCEmptyIO
)
{
HostTensorGenerator
<>
gen
;
auto
graph
=
ComputingGraph
::
make
();
// empty input
auto
host_x
=
gen
({
4
,
0
,
2
,
3
});
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
y
=
opr
::
PowC
::
make
(
x
,
3.
f
);
HostTensorND
host_y
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
y
,
host_y
)});
ASSERT_NO_THROW
(
func
->
execute
().
wait
());
ASSERT_TRUE
(
host_y
.
empty
());
ASSERT_TRUE
(
host_y
.
shape
().
is_empty
());
MGB_ASSERT_SHAPE_EQ
(
host_x
->
shape
(),
host_y
.
shape
());
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录