From f2e1bb41b4539006f6967b1bd96d93a260026c8d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 18 May 2020 15:37:03 +0800 Subject: [PATCH] feat(mgb/opr): let more indexing ops support empty shape GitOrigin-RevId: db4eba5877293cb1865801e315f53026527f6c6f --- src/opr/impl/indexing.cpp | 46 +++++++++---- src/opr/include/megbrain/opr/indexing.h | 5 +- src/opr/test/basic_arith/elemwise.cpp | 21 ++---- src/opr/test/indexing.cpp | 87 +++++++++++++++++++++++++ test/src/helper.cpp | 10 +++ test/src/include/megbrain/test/helper.h | 5 ++ 6 files changed, 145 insertions(+), 29 deletions(-) diff --git a/src/opr/impl/indexing.cpp b/src/opr/impl/indexing.cpp index cfbdbc6a9..133f175f3 100644 --- a/src/opr/impl/indexing.cpp +++ b/src/opr/impl/indexing.cpp @@ -226,17 +226,19 @@ void mixin::IndexingMultiAxisVecMegDNNOprHolder::record_megdnn_opr( } /* ==================== MultiAxisVecFancyIndexingHelper ==================== */ -const megdnn::IndexingMultiAxisVec::IndexDesc& +std::pair intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc( size_t inp_ndim, bool warn_all_scalar) { auto &&index = m_megdnn_index_cache; index.clear(); + bool is_empty_shape = false; for (auto i: reverse_adaptor(m_input2idxonly_axis_indexer)) { if (i) { index.push_back({ i->axis.get(inp_ndim), i->idx.node()->dev_tensor().as_megdnn()}); + is_empty_shape |= index.back().vec.layout.is_empty(); } } @@ -264,7 +266,7 @@ intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc( m_scalar_idx_warn_printed = true; } - return index; + return {index, is_empty_shape}; } /* ==================== IndexingMultiAxisVecBase ==================== */ @@ -272,6 +274,8 @@ template cg::OperatorNodeBase::NodeProp* IndexingMultiAxisVecBase::do_make_node_prop() const { auto prop = Super::do_make_node_prop(); + // TODO: should also allow input shape is empty if any + // indexer's shape is empty for (auto i: m_input2idxonly_axis_indexer) { if (i) { prop->add_dep_type_existing_var( @@ -360,13 +364,13 @@ void IndexingMultiAxisVecBase::scn_do_execute() { auto &&index_desc = make_megdnn_index_desc( inp.layout().ndim, ShouldWarnOnScalarIndexer::val); auto &&odev = output(0)->dev_tensor(); - if (index_desc.empty()) { + if (index_desc.first.empty()) { odev.copy_from_fixlayout(inp); } else { - if (index_desc[0].vec.layout[0]) { + if (!index_desc.second) { // only call megdnn exec if result is not empty this->megdnn_opr(*this).exec( - inp.as_megdnn(), index_desc, odev.as_megdnn(), + inp.as_megdnn(), index_desc.first, odev.as_megdnn(), intl::get_megdnn_workspace_from_var(output(1))); } else { mgb_assert(odev.empty()); @@ -391,7 +395,11 @@ void intl::IndexingModifyMultiAxisVecHelper::scn_do_execute() { auto inp = this->fancy_indexing_get_tensors_for_modify_in_scn_do_execute(); auto index_desc = this->make_megdnn_index_desc( inp.first.layout().ndim, ShouldWarnOnScalarIndexer::val); - if (index_desc.empty()) { + if (index_desc.second){ + mgb_assert(inp.second.shape().is_empty()); + return; + } + if (index_desc.first.empty()) { using IMT = IndexingModifyType; static constexpr auto modify_type = IndexingModifyTypeGetter::value; @@ -410,11 +418,28 @@ void intl::IndexingModifyMultiAxisVecHelper::scn_do_execute() { } else { this->megdnn_opr(*this).exec( inp.first.as_megdnn(), inp.second.as_megdnn(), - index_desc, + index_desc.first, intl::get_megdnn_workspace_from_var(output(1))); } } +template +cg::OperatorNodeBase::NodeProp* +intl::IndexingModifyMultiAxisVecHelper::do_make_node_prop() const { + auto prop = Super::do_make_node_prop(); + using DT = NodeProp::DepType; + // TODO: should also allow input shape is empty if any + // indexer's shape is empty + prop->add_dep_type_existing_var(input(1), DT::VALUE_ALLOW_EMPTY); + for (auto i: m_input2idxonly_axis_indexer) { + if (i) { + prop->add_dep_type_existing_var( + i->idx.node(), DT::VALUE_ALLOW_EMPTY); + } + } + return prop; +} + template void intl::IndexingModifyMultiAxisVecHelper:: add_input_layout_constraint() { @@ -429,7 +454,6 @@ add_input_layout_constraint() { MGB_IMPL_FANCY_INDEXING_OPR_GET( IndexingMultiAxisVec, "indexing_multi_axis_vec", false, output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); - output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); ); MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( IndexingSetMultiAxisVec, "indexing_set_multi_axis_vec", false); @@ -469,12 +493,10 @@ MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) { MGB_IMPL_FANCY_INDEXING_OPR_GET( MeshIndexing, "mesh_indexing", false, - output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); - output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);); + output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);); MGB_IMPL_FANCY_INDEXING_OPR_GET( BatchedMeshIndexing, "batched_mesh_indexing", false, - output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); - output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);); + output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);); MGB_IMPL_OPR_GRAD(MeshIndexing) { if (wrt_idx != 0) { diff --git a/src/opr/include/megbrain/opr/indexing.h b/src/opr/include/megbrain/opr/indexing.h index fd7537b71..d9104100d 100644 --- a/src/opr/include/megbrain/opr/indexing.h +++ b/src/opr/include/megbrain/opr/indexing.h @@ -117,7 +117,9 @@ namespace intl { protected: using Super::Super; - const megdnn::IndexingMultiAxisVec::IndexDesc& + //! return IndexDesc and whether it has an AxisIndexer with + //! empty shape + std::pair make_megdnn_index_desc( size_t inp_ndim, bool warn_all_scalar = true); }; @@ -130,6 +132,7 @@ namespace intl { void init_output_static_infer_desc() override final; void scn_do_execute() override final; + NodeProp* do_make_node_prop() const override; void add_input_layout_constraint() override final; protected: diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp index 231abc22c..08cf6ad4b 100644 --- a/src/opr/test/basic_arith/elemwise.cpp +++ b/src/opr/test/basic_arith/elemwise.cpp @@ -649,17 +649,6 @@ namespace { > TernaryTraitTypes; TYPED_TEST_CASE(TestOprBasicArithTernaryElemwise, TernaryTraitTypes); - ::testing::AssertionResult assert_shape_equal(const TensorShape& v0, - const TensorShape& v1) { - if (v0.eq_shape(v1)) - return ::testing::AssertionSuccess() - << v0.to_string() << " == " << v1.to_string(); - else - return ::testing::AssertionFailure() - << v0.to_string() << " != " << v1.to_string(); - } -#define ASSERT_SHAPE_EQ(v0, v1) ASSERT_TRUE(assert_shape_equal(v0, v1)) - } // anonymous namespace template @@ -974,14 +963,14 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputUnary) { ASSERT_NO_THROW(func->execute().wait()); ASSERT_TRUE(host_y.empty()); ASSERT_TRUE(host_y.shape().is_empty()); - ASSERT_SHAPE_EQ(host_y.shape(), TensorShape({3, 0, 1, 3})); + MGB_ASSERT_SHAPE_EQ(host_y.shape(), TensorShape({3, 0, 1, 3})); } TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) { HostTensorGenerator<> gen; auto graph = ComputingGraph::make(); auto host_x = gen({0, 8, 1, 7}), host_y = gen({0, 8, 1, 7}); - + auto x = opr::Host2DeviceCopy::make(*graph, host_x), y = opr::Host2DeviceCopy::make(*graph, host_y), z = x + y; @@ -997,14 +986,14 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) { ASSERT_NO_THROW(func->execute().wait()); ASSERT_TRUE(host_z.empty()); ASSERT_TRUE(host_z.shape().is_empty()); - ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 0, 7})); + MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 0, 7})); // Broadcast to 0 (2) host_y->resize({2, 8, 1, 7}); ASSERT_NO_THROW(func->execute().wait()); ASSERT_TRUE(host_z.empty()); ASSERT_TRUE(host_z.shape().is_empty()); - ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7})); + MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7})); // Scalar broadcast z = x + x.make_scalar(1.f); @@ -1012,7 +1001,7 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) { ASSERT_NO_THROW(func->execute().wait()); ASSERT_TRUE(host_z.empty()); ASSERT_TRUE(host_z.shape().is_empty()); - ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7})); + MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7})); } // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/test/indexing.cpp b/src/opr/test/indexing.cpp index cbb22bae9..76096ef2e 100644 --- a/src/opr/test/indexing.cpp +++ b/src/opr/test/indexing.cpp @@ -14,6 +14,7 @@ #include "megbrain/opr/basic_arith.h" #include "megbrain/opr/io.h" #include "megbrain/opr/misc.h" +#include "megbrain/opr/utility.h" #include "megbrain/test/autocheck.h" #include "megbrain/test/helper.h" #include "megbrain/test/megdnn_helper.h" @@ -1195,6 +1196,92 @@ TEST(TestOprIndexing, SetMeshIndexing) { } } +namespace { + +template +void run_multi_axis_vec_empty_shape( + const TensorShape& ishp, const TensorShape& idx0, + const TensorShape& idx1, const TensorShape& tshp) { + mgb_assert(ishp.ndim >= 4); + mgb_assert(idx0.is_empty() || idx1.is_empty()); + using AI = opr::indexing::AxisIndexer; + auto graph = ComputingGraph::make(); + HostTensorGenerator<> gen; + HostTensorGenerator gen_idx; + auto host_x = gen(ishp); + auto x = opr::Host2DeviceCopy::make(*graph, host_x), + idx_dynamic_shape = opr::MarkDynamicVar::make( + opr::ImmutableTensor::make(*graph, *gen_idx(idx0))), + idx_static_shape = + opr::ImmutableTensor::make(*graph, *gen_idx(idx1)), + y = Opr::make(x, { + AI::make_interval(-1, None, None, x.make_scalar(2)), + AI::make_index(1, idx_dynamic_shape), + AI::make_index(2, idx_static_shape)}); + HostTensorND host_y; + auto func = graph->compile({make_callback_copy(y, host_y)}); + func->execute(); + ASSERT_TRUE(host_y.shape().is_empty()); + MGB_ASSERT_SHAPE_EQ(host_y.shape(), tshp); +} + +template +void run_modify_multi_axis_vec_empty_shape( + const TensorShape& ishp, const TensorShape& vshp, + const TensorShape& idx0, const TensorShape& idx1) { + mgb_assert(ishp.ndim >= 4); + mgb_assert(vshp.is_empty() && (idx0.is_empty() || idx1.is_empty())); + using AI = opr::indexing::AxisIndexer; + auto graph = ComputingGraph::make(); + HostTensorGenerator<> gen; + HostTensorGenerator gen_idx; + auto host_x = gen(ishp), host_v = gen(vshp); + auto x = opr::Host2DeviceCopy::make(*graph, host_x), + v = opr::Host2DeviceCopy::make(*graph, host_v), + idx_dynamic_shape = opr::MarkDynamicVar::make( + opr::ImmutableTensor::make(*graph, *gen_idx(idx0))), + idx_static_shape = + opr::ImmutableTensor::make(*graph, *gen_idx(idx1)), + y = Opr::make(x, v, { + AI::make_interval(-1, None, None, x.make_scalar(2)), + AI::make_index(1, idx_dynamic_shape), + AI::make_index(2, idx_static_shape)}); + HostTensorND host_y; + auto func = graph->compile({make_callback_copy(y, host_y)}); + func->execute(); + MGB_ASSERT_TENSOR_EQ(*host_x, host_y); +} + +} + +TEST(TestOprIndexing, MultiAxisVecEmptyShape) { + TensorShape ishp{8, 2, 3, 4}; + size_t n = ishp[0], last_ndim = ishp[ishp.ndim - 1] / 2; + run_multi_axis_vec_empty_shape( + ishp, {0}, {0}, {n, 0, last_ndim}); + run_multi_axis_vec_empty_shape( + ishp, {0}, {2}, {n, 0, 2, last_ndim}); + run_multi_axis_vec_empty_shape( + ishp, {3}, {0}, {n, 3, 0, last_ndim}); + run_multi_axis_vec_empty_shape( + ishp, {n, 0}, {n, 2}, {n, 0, 2, last_ndim}); + run_multi_axis_vec_empty_shape( + ishp, {n, 4}, {n, 0}, {n, 4, 0, last_ndim}); + + run_modify_multi_axis_vec_empty_shape( + ishp, {n, 0, last_ndim}, {0}, {0}); + run_modify_multi_axis_vec_empty_shape( + ishp, {n, 0, last_ndim}, {0}, {0}); + run_modify_multi_axis_vec_empty_shape( + ishp, {n, 0, 2, last_ndim}, {0}, {2}); + run_modify_multi_axis_vec_empty_shape( + ishp, {n, 3, 0, last_ndim}, {3}, {0}); + run_modify_multi_axis_vec_empty_shape( + ishp, {n, 4, 0, last_ndim}, {n, 4}, {n, 0}); + run_modify_multi_axis_vec_empty_shape( + ishp, {n, 0, 5, last_ndim}, {n, 0}, {n, 5}); +} + #endif // MGB_ENABLE_EXCEPTION // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/test/src/helper.cpp b/test/src/helper.cpp index d8acb5748..bedca8fb5 100644 --- a/test/src/helper.cpp +++ b/test/src/helper.cpp @@ -158,6 +158,16 @@ namespace mgb { return ::testing::AssertionSuccess(); } +::testing::AssertionResult mgb::__assert_shape_equal(const TensorShape& v0, + const TensorShape& v1) { + if (v0.eq_shape(v1)) + return ::testing::AssertionSuccess() + << v0.to_string() << " == " << v1.to_string(); + else + return ::testing::AssertionFailure() + << v0.to_string() << " != " << v1.to_string(); +} + #if WIN32 #include #include diff --git a/test/src/include/megbrain/test/helper.h b/test/src/include/megbrain/test/helper.h index 66dfa1e5d..e4e9f51eb 100644 --- a/test/src/include/megbrain/test/helper.h +++ b/test/src/include/megbrain/test/helper.h @@ -133,6 +133,11 @@ decltype(auto) container_to_vector(Container &&ct) { #define MGB_ASSERT_TENSOR_EQ(v0, v1) \ MGB_ASSERT_TENSOR_NEAR(v0, v1, 1e-6) +::testing::AssertionResult __assert_shape_equal(const TensorShape& v0, + const TensorShape& v1); + +#define MGB_ASSERT_SHAPE_EQ(v0, v1) \ + ASSERT_TRUE(::mgb::__assert_shape_equal(v0, v1)) /*! * \brief xorshift+ RNG, which is very fast -- GitLab