提交 703b783c 编写于 作者: M Megvii Engine Team

feat(mgb/opr): let Indexing(Set)MultiAxisVec support empty input

GitOrigin-RevId: f15b1d45a1f8ea9b51ccfe1fc4e42414de496fe2
上级 a430c912
......@@ -291,8 +291,10 @@ template<class Opr>
cg::OperatorNodeBase::NodeProp*
IndexingMultiAxisVecBase<Opr>::do_make_node_prop() const {
auto prop = Super::do_make_node_prop();
using DT = NodeProp::DepType;
// TODO: should also allow input shape is empty if any
// indexer's shape is empty
prop->add_dep_type_existing_var(input(0), DT::VALUE_ALLOW_EMPTY);
for (auto i: m_input2idxonly_axis_indexer) {
if (i) {
prop->add_dep_type_existing_var(
......@@ -415,7 +417,7 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() {
auto inp = this->fancy_indexing_get_tensors_for_modify_in_scn_do_execute();
auto index_desc = this->make_megdnn_index_desc(
inp.first.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val);
if (index_desc.second){
if (inp.first.shape().is_empty() || index_desc.second){
mgb_assert(inp.second.shape().is_empty());
return;
}
......@@ -476,10 +478,19 @@ MGB_IMPL_FANCY_INDEXING_OPR_GET(
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
);
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
IndexingSetMultiAxisVec, "indexing_set_multi_axis_vec", false);
IndexingSetMultiAxisVec, "indexing_set_multi_axis_vec", false,
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
);
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false);
IndexingSetMultiAxisVec::NodeProp* IndexingSetMultiAxisVec::do_make_node_prop() const {
auto prop = Super::do_make_node_prop();
prop->add_dep_type_existing_var(input(0),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return prop;
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) {
if (wrt_idx)
......
......@@ -132,11 +132,11 @@ namespace intl {
void init_output_static_infer_desc() override final;
void scn_do_execute() override final;
NodeProp* do_make_node_prop() const override;
void add_input_layout_constraint() override final;
protected:
using Super::Super;
NodeProp* do_make_node_prop() const override;
};
} // namespace intl
......@@ -158,6 +158,7 @@ public:
MGB_DEFINE_OPR_CLASS(IndexingSetMultiAxisVec,
intl::IndexingModifyMultiAxisVecHelper<megdnn::IndexingSetMultiAxisVec>
) // {
NodeProp* do_make_node_prop() const override;
public:
MGB_DECL_FANCY_INDEXING_OPR_MODIFY(IndexingSetMultiAxisVec);
......
......@@ -241,13 +241,15 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(_opr)
const OperatorNodeConfig &config = {}, \
const InputTensorReplacer &input_tensor_replacer = {})
#define MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(_opr, _name, _require_scalar_index) \
#define MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(_opr, _name, _require_scalar_index, \
ctor_body...) \
_opr::_opr(VarNode *inp, VarNode *value, const IndexDesc &desc, \
const OperatorNodeConfig &config, \
const InputTensorReplacer &input_tensor_replacer): \
Super({inp->owner_graph(), config, _name, {inp, value}}, \
inp, value, desc, _require_scalar_index, input_tensor_replacer) \
{ \
ctor_body; \
} \
SymbolVar _opr::make(SymbolVar inp, SymbolVar value, const IndexDesc &desc, \
const OperatorNodeConfig &config, \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册