Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
703b783c
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
703b783c
编写于
8月 18, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/opr): let Indexing(Set)MultiAxisVec support empty input
GitOrigin-RevId: f15b1d45a1f8ea9b51ccfe1fc4e42414de496fe2
上级
a430c912
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
20 addition
and
6 deletion
+20
-6
src/opr/impl/indexing.cpp
src/opr/impl/indexing.cpp
+13
-2
src/opr/include/megbrain/opr/indexing.h
src/opr/include/megbrain/opr/indexing.h
+4
-3
src/opr/include/megbrain/opr/internal/indexing_helper.h
src/opr/include/megbrain/opr/internal/indexing_helper.h
+3
-1
未找到文件。
src/opr/impl/indexing.cpp
浏览文件 @
703b783c
...
...
@@ -291,8 +291,10 @@ template<class Opr>
cg
::
OperatorNodeBase
::
NodeProp
*
IndexingMultiAxisVecBase
<
Opr
>::
do_make_node_prop
()
const
{
auto
prop
=
Super
::
do_make_node_prop
();
using
DT
=
NodeProp
::
DepType
;
// TODO: should also allow input shape is empty if any
// indexer's shape is empty
prop
->
add_dep_type_existing_var
(
input
(
0
),
DT
::
VALUE_ALLOW_EMPTY
);
for
(
auto
i
:
m_input2idxonly_axis_indexer
)
{
if
(
i
)
{
prop
->
add_dep_type_existing_var
(
...
...
@@ -415,7 +417,7 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() {
auto
inp
=
this
->
fancy_indexing_get_tensors_for_modify_in_scn_do_execute
();
auto
index_desc
=
this
->
make_megdnn_index_desc
(
inp
.
first
.
layout
().
ndim
,
ShouldWarnOnScalarIndexer
<
Opr
>::
val
);
if
(
index_desc
.
second
){
if
(
in
p
.
first
.
shape
().
is_empty
()
||
in
dex_desc
.
second
){
mgb_assert
(
inp
.
second
.
shape
().
is_empty
());
return
;
}
...
...
@@ -476,10 +478,19 @@ MGB_IMPL_FANCY_INDEXING_OPR_GET(
output
(
0
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
);
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY
(
IndexingSetMultiAxisVec
,
"indexing_set_multi_axis_vec"
,
false
);
IndexingSetMultiAxisVec
,
"indexing_set_multi_axis_vec"
,
false
,
output
(
0
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
);
);
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY
(
IndexingIncrMultiAxisVec
,
"indexing_incr_multi_axis_vec"
,
false
);
IndexingSetMultiAxisVec
::
NodeProp
*
IndexingSetMultiAxisVec
::
do_make_node_prop
()
const
{
auto
prop
=
Super
::
do_make_node_prop
();
prop
->
add_dep_type_existing_var
(
input
(
0
),
NodeProp
::
DepType
::
VALUE_ALLOW_EMPTY
);
return
prop
;
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD
(
IndexingMultiAxisVec
)
{
if
(
wrt_idx
)
...
...
src/opr/include/megbrain/opr/indexing.h
浏览文件 @
703b783c
...
...
@@ -132,11 +132,11 @@ namespace intl {
void
init_output_static_infer_desc
()
override
final
;
void
scn_do_execute
()
override
final
;
NodeProp
*
do_make_node_prop
()
const
override
;
void
add_input_layout_constraint
()
override
final
;
protected
:
using
Super
::
Super
;
protected
:
using
Super
::
Super
;
NodeProp
*
do_make_node_prop
()
const
override
;
}
;
}
// namespace intl
...
...
@@ -158,6 +158,7 @@ public:
MGB_DEFINE_OPR_CLASS
(
IndexingSetMultiAxisVec
,
intl
::
IndexingModifyMultiAxisVecHelper
<
megdnn
::
IndexingSetMultiAxisVec
>
)
// {
NodeProp
*
do_make_node_prop
()
const
override
;
public
:
MGB_DECL_FANCY_INDEXING_OPR_MODIFY
(
IndexingSetMultiAxisVec
);
...
...
src/opr/include/megbrain/opr/internal/indexing_helper.h
浏览文件 @
703b783c
...
...
@@ -241,13 +241,15 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(_opr)
const OperatorNodeConfig &config = {}, \
const InputTensorReplacer &input_tensor_replacer = {})
#define MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(_opr, _name, _require_scalar_index) \
#define MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(_opr, _name, _require_scalar_index, \
ctor_body...) \
_opr::_opr(VarNode *inp, VarNode *value, const IndexDesc &desc, \
const OperatorNodeConfig &config, \
const InputTensorReplacer &input_tensor_replacer): \
Super({inp->owner_graph(), config, _name, {inp, value}}, \
inp, value, desc, _require_scalar_index, input_tensor_replacer) \
{ \
ctor_body; \
} \
SymbolVar _opr::make(SymbolVar inp, SymbolVar value, const IndexDesc &desc, \
const OperatorNodeConfig &config, \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录