From 703b783c24c126360f783a24771be0b96e2952ca Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 18 Aug 2021 20:46:40 +0800 Subject: [PATCH] feat(mgb/opr): let Indexing(Set)MultiAxisVec support empty input GitOrigin-RevId: f15b1d45a1f8ea9b51ccfe1fc4e42414de496fe2 --- src/opr/impl/indexing.cpp | 15 +++++++++++++-- src/opr/include/megbrain/opr/indexing.h | 7 ++++--- .../megbrain/opr/internal/indexing_helper.h | 4 +++- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/opr/impl/indexing.cpp b/src/opr/impl/indexing.cpp index 11426900d..ad56bb48c 100644 --- a/src/opr/impl/indexing.cpp +++ b/src/opr/impl/indexing.cpp @@ -291,8 +291,10 @@ template cg::OperatorNodeBase::NodeProp* IndexingMultiAxisVecBase::do_make_node_prop() const { auto prop = Super::do_make_node_prop(); + using DT = NodeProp::DepType; // TODO: should also allow input shape is empty if any // indexer's shape is empty + prop->add_dep_type_existing_var(input(0), DT::VALUE_ALLOW_EMPTY); for (auto i: m_input2idxonly_axis_indexer) { if (i) { prop->add_dep_type_existing_var( @@ -415,7 +417,7 @@ void intl::IndexingModifyMultiAxisVecHelper::scn_do_execute() { auto inp = this->fancy_indexing_get_tensors_for_modify_in_scn_do_execute(); auto index_desc = this->make_megdnn_index_desc( inp.first.layout().ndim, ShouldWarnOnScalarIndexer::val); - if (index_desc.second){ + if (inp.first.shape().is_empty() || index_desc.second){ mgb_assert(inp.second.shape().is_empty()); return; } @@ -476,10 +478,19 @@ MGB_IMPL_FANCY_INDEXING_OPR_GET( output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); ); MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( - IndexingSetMultiAxisVec, "indexing_set_multi_axis_vec", false); + IndexingSetMultiAxisVec, "indexing_set_multi_axis_vec", false, + output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + ); MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false); +IndexingSetMultiAxisVec::NodeProp* IndexingSetMultiAxisVec::do_make_node_prop() const { + auto prop = Super::do_make_node_prop(); + prop->add_dep_type_existing_var(input(0), + NodeProp::DepType::VALUE_ALLOW_EMPTY); + return prop; +} + #if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) { if (wrt_idx) diff --git a/src/opr/include/megbrain/opr/indexing.h b/src/opr/include/megbrain/opr/indexing.h index a7ebbe24b..c73979506 100644 --- a/src/opr/include/megbrain/opr/indexing.h +++ b/src/opr/include/megbrain/opr/indexing.h @@ -132,11 +132,11 @@ namespace intl { void init_output_static_infer_desc() override final; void scn_do_execute() override final; - NodeProp* do_make_node_prop() const override; void add_input_layout_constraint() override final; - protected: - using Super::Super; + protected: + using Super::Super; + NodeProp* do_make_node_prop() const override; }; } // namespace intl @@ -158,6 +158,7 @@ public: MGB_DEFINE_OPR_CLASS(IndexingSetMultiAxisVec, intl::IndexingModifyMultiAxisVecHelper ) // { + NodeProp* do_make_node_prop() const override; public: MGB_DECL_FANCY_INDEXING_OPR_MODIFY(IndexingSetMultiAxisVec); diff --git a/src/opr/include/megbrain/opr/internal/indexing_helper.h b/src/opr/include/megbrain/opr/internal/indexing_helper.h index 9f6406d79..974dc2bcd 100644 --- a/src/opr/include/megbrain/opr/internal/indexing_helper.h +++ b/src/opr/include/megbrain/opr/internal/indexing_helper.h @@ -241,13 +241,15 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(_opr) const OperatorNodeConfig &config = {}, \ const InputTensorReplacer &input_tensor_replacer = {}) -#define MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(_opr, _name, _require_scalar_index) \ +#define MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(_opr, _name, _require_scalar_index, \ + ctor_body...) \ _opr::_opr(VarNode *inp, VarNode *value, const IndexDesc &desc, \ const OperatorNodeConfig &config, \ const InputTensorReplacer &input_tensor_replacer): \ Super({inp->owner_graph(), config, _name, {inp, value}}, \ inp, value, desc, _require_scalar_index, input_tensor_replacer) \ { \ + ctor_body; \ } \ SymbolVar _opr::make(SymbolVar inp, SymbolVar value, const IndexDesc &desc, \ const OperatorNodeConfig &config, \ -- GitLab