提交 f2e1bb41 编写于 作者: M Megvii Engine Team

feat(mgb/opr): let more indexing ops support empty shape

GitOrigin-RevId: db4eba5877293cb1865801e315f53026527f6c6f
上级 a4879fc6
...@@ -226,17 +226,19 @@ void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::record_megdnn_opr( ...@@ -226,17 +226,19 @@ void mixin::IndexingMultiAxisVecMegDNNOprHolder<Opr>::record_megdnn_opr(
} }
/* ==================== MultiAxisVecFancyIndexingHelper ==================== */ /* ==================== MultiAxisVecFancyIndexingHelper ==================== */
const megdnn::IndexingMultiAxisVec::IndexDesc& std::pair<const megdnn::IndexingMultiAxisVec::IndexDesc&, bool>
intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc( intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc(
size_t inp_ndim, bool warn_all_scalar) { size_t inp_ndim, bool warn_all_scalar) {
auto &&index = m_megdnn_index_cache; auto &&index = m_megdnn_index_cache;
index.clear(); index.clear();
bool is_empty_shape = false;
for (auto i: reverse_adaptor(m_input2idxonly_axis_indexer)) { for (auto i: reverse_adaptor(m_input2idxonly_axis_indexer)) {
if (i) { if (i) {
index.push_back({ index.push_back({
i->axis.get(inp_ndim), i->axis.get(inp_ndim),
i->idx.node()->dev_tensor().as_megdnn()}); i->idx.node()->dev_tensor().as_megdnn()});
is_empty_shape |= index.back().vec.layout.is_empty();
} }
} }
...@@ -264,7 +266,7 @@ intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc( ...@@ -264,7 +266,7 @@ intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc(
m_scalar_idx_warn_printed = true; m_scalar_idx_warn_printed = true;
} }
return index; return {index, is_empty_shape};
} }
/* ==================== IndexingMultiAxisVecBase ==================== */ /* ==================== IndexingMultiAxisVecBase ==================== */
...@@ -272,6 +274,8 @@ template<class Opr> ...@@ -272,6 +274,8 @@ template<class Opr>
cg::OperatorNodeBase::NodeProp* cg::OperatorNodeBase::NodeProp*
IndexingMultiAxisVecBase<Opr>::do_make_node_prop() const { IndexingMultiAxisVecBase<Opr>::do_make_node_prop() const {
auto prop = Super::do_make_node_prop(); auto prop = Super::do_make_node_prop();
// TODO: should also allow input shape is empty if any
// indexer's shape is empty
for (auto i: m_input2idxonly_axis_indexer) { for (auto i: m_input2idxonly_axis_indexer) {
if (i) { if (i) {
prop->add_dep_type_existing_var( prop->add_dep_type_existing_var(
...@@ -360,13 +364,13 @@ void IndexingMultiAxisVecBase<Opr>::scn_do_execute() { ...@@ -360,13 +364,13 @@ void IndexingMultiAxisVecBase<Opr>::scn_do_execute() {
auto &&index_desc = make_megdnn_index_desc( auto &&index_desc = make_megdnn_index_desc(
inp.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val); inp.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val);
auto &&odev = output(0)->dev_tensor(); auto &&odev = output(0)->dev_tensor();
if (index_desc.empty()) { if (index_desc.first.empty()) {
odev.copy_from_fixlayout(inp); odev.copy_from_fixlayout(inp);
} else { } else {
if (index_desc[0].vec.layout[0]) { if (!index_desc.second) {
// only call megdnn exec if result is not empty // only call megdnn exec if result is not empty
this->megdnn_opr(*this).exec( this->megdnn_opr(*this).exec(
inp.as_megdnn(), index_desc, odev.as_megdnn(), inp.as_megdnn(), index_desc.first, odev.as_megdnn(),
intl::get_megdnn_workspace_from_var(output(1))); intl::get_megdnn_workspace_from_var(output(1)));
} else { } else {
mgb_assert(odev.empty()); mgb_assert(odev.empty());
...@@ -391,7 +395,11 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() { ...@@ -391,7 +395,11 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() {
auto inp = this->fancy_indexing_get_tensors_for_modify_in_scn_do_execute(); auto inp = this->fancy_indexing_get_tensors_for_modify_in_scn_do_execute();
auto index_desc = this->make_megdnn_index_desc( auto index_desc = this->make_megdnn_index_desc(
inp.first.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val); inp.first.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val);
if (index_desc.empty()) { if (index_desc.second){
mgb_assert(inp.second.shape().is_empty());
return;
}
if (index_desc.first.empty()) {
using IMT = IndexingModifyType; using IMT = IndexingModifyType;
static constexpr auto modify_type = static constexpr auto modify_type =
IndexingModifyTypeGetter<Opr>::value; IndexingModifyTypeGetter<Opr>::value;
...@@ -410,11 +418,28 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() { ...@@ -410,11 +418,28 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() {
} else { } else {
this->megdnn_opr(*this).exec( this->megdnn_opr(*this).exec(
inp.first.as_megdnn(), inp.second.as_megdnn(), inp.first.as_megdnn(), inp.second.as_megdnn(),
index_desc, index_desc.first,
intl::get_megdnn_workspace_from_var(output(1))); intl::get_megdnn_workspace_from_var(output(1)));
} }
} }
template<class Opr>
cg::OperatorNodeBase::NodeProp*
intl::IndexingModifyMultiAxisVecHelper<Opr>::do_make_node_prop() const {
auto prop = Super::do_make_node_prop();
using DT = NodeProp::DepType;
// TODO: should also allow input shape is empty if any
// indexer's shape is empty
prop->add_dep_type_existing_var(input(1), DT::VALUE_ALLOW_EMPTY);
for (auto i: m_input2idxonly_axis_indexer) {
if (i) {
prop->add_dep_type_existing_var(
i->idx.node(), DT::VALUE_ALLOW_EMPTY);
}
}
return prop;
}
template<class Opr> template<class Opr>
void intl::IndexingModifyMultiAxisVecHelper<Opr>:: void intl::IndexingModifyMultiAxisVecHelper<Opr>::
add_input_layout_constraint() { add_input_layout_constraint() {
...@@ -429,7 +454,6 @@ add_input_layout_constraint() { ...@@ -429,7 +454,6 @@ add_input_layout_constraint() {
MGB_IMPL_FANCY_INDEXING_OPR_GET( MGB_IMPL_FANCY_INDEXING_OPR_GET(
IndexingMultiAxisVec, "indexing_multi_axis_vec", false, IndexingMultiAxisVec, "indexing_multi_axis_vec", false,
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
); );
MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(
IndexingSetMultiAxisVec, "indexing_set_multi_axis_vec", false); IndexingSetMultiAxisVec, "indexing_set_multi_axis_vec", false);
...@@ -469,12 +493,10 @@ MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) { ...@@ -469,12 +493,10 @@ MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) {
MGB_IMPL_FANCY_INDEXING_OPR_GET( MGB_IMPL_FANCY_INDEXING_OPR_GET(
MeshIndexing, "mesh_indexing", false, MeshIndexing, "mesh_indexing", false,
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
MGB_IMPL_FANCY_INDEXING_OPR_GET( MGB_IMPL_FANCY_INDEXING_OPR_GET(
BatchedMeshIndexing, "batched_mesh_indexing", false, BatchedMeshIndexing, "batched_mesh_indexing", false,
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
output(1)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE););
MGB_IMPL_OPR_GRAD(MeshIndexing) { MGB_IMPL_OPR_GRAD(MeshIndexing) {
if (wrt_idx != 0) { if (wrt_idx != 0) {
......
...@@ -117,7 +117,9 @@ namespace intl { ...@@ -117,7 +117,9 @@ namespace intl {
protected: protected:
using Super::Super; using Super::Super;
const megdnn::IndexingMultiAxisVec::IndexDesc& //! return IndexDesc and whether it has an AxisIndexer with
//! empty shape
std::pair<const megdnn::IndexingMultiAxisVec::IndexDesc&, bool>
make_megdnn_index_desc( make_megdnn_index_desc(
size_t inp_ndim, bool warn_all_scalar = true); size_t inp_ndim, bool warn_all_scalar = true);
}; };
...@@ -130,6 +132,7 @@ namespace intl { ...@@ -130,6 +132,7 @@ namespace intl {
void init_output_static_infer_desc() override final; void init_output_static_infer_desc() override final;
void scn_do_execute() override final; void scn_do_execute() override final;
NodeProp* do_make_node_prop() const override;
void add_input_layout_constraint() override final; void add_input_layout_constraint() override final;
protected: protected:
......
...@@ -649,17 +649,6 @@ namespace { ...@@ -649,17 +649,6 @@ namespace {
> TernaryTraitTypes; > TernaryTraitTypes;
TYPED_TEST_CASE(TestOprBasicArithTernaryElemwise, TernaryTraitTypes); TYPED_TEST_CASE(TestOprBasicArithTernaryElemwise, TernaryTraitTypes);
::testing::AssertionResult assert_shape_equal(const TensorShape& v0,
const TensorShape& v1) {
if (v0.eq_shape(v1))
return ::testing::AssertionSuccess()
<< v0.to_string() << " == " << v1.to_string();
else
return ::testing::AssertionFailure()
<< v0.to_string() << " != " << v1.to_string();
}
#define ASSERT_SHAPE_EQ(v0, v1) ASSERT_TRUE(assert_shape_equal(v0, v1))
} // anonymous namespace } // anonymous namespace
template<typename Trait, typename dtype> template<typename Trait, typename dtype>
...@@ -974,7 +963,7 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputUnary) { ...@@ -974,7 +963,7 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputUnary) {
ASSERT_NO_THROW(func->execute().wait()); ASSERT_NO_THROW(func->execute().wait());
ASSERT_TRUE(host_y.empty()); ASSERT_TRUE(host_y.empty());
ASSERT_TRUE(host_y.shape().is_empty()); ASSERT_TRUE(host_y.shape().is_empty());
ASSERT_SHAPE_EQ(host_y.shape(), TensorShape({3, 0, 1, 3})); MGB_ASSERT_SHAPE_EQ(host_y.shape(), TensorShape({3, 0, 1, 3}));
} }
TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) { TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) {
...@@ -997,14 +986,14 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) { ...@@ -997,14 +986,14 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) {
ASSERT_NO_THROW(func->execute().wait()); ASSERT_NO_THROW(func->execute().wait());
ASSERT_TRUE(host_z.empty()); ASSERT_TRUE(host_z.empty());
ASSERT_TRUE(host_z.shape().is_empty()); ASSERT_TRUE(host_z.shape().is_empty());
ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 0, 7})); MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 0, 7}));
// Broadcast to 0 (2) // Broadcast to 0 (2)
host_y->resize({2, 8, 1, 7}); host_y->resize({2, 8, 1, 7});
ASSERT_NO_THROW(func->execute().wait()); ASSERT_NO_THROW(func->execute().wait());
ASSERT_TRUE(host_z.empty()); ASSERT_TRUE(host_z.empty());
ASSERT_TRUE(host_z.shape().is_empty()); ASSERT_TRUE(host_z.shape().is_empty());
ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7})); MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7}));
// Scalar broadcast // Scalar broadcast
z = x + x.make_scalar(1.f); z = x + x.make_scalar(1.f);
...@@ -1012,7 +1001,7 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) { ...@@ -1012,7 +1001,7 @@ TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) {
ASSERT_NO_THROW(func->execute().wait()); ASSERT_NO_THROW(func->execute().wait());
ASSERT_TRUE(host_z.empty()); ASSERT_TRUE(host_z.empty());
ASSERT_TRUE(host_z.shape().is_empty()); ASSERT_TRUE(host_z.shape().is_empty());
ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7})); MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7}));
} }
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/io.h" #include "megbrain/opr/io.h"
#include "megbrain/opr/misc.h" #include "megbrain/opr/misc.h"
#include "megbrain/opr/utility.h"
#include "megbrain/test/autocheck.h" #include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h" #include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h" #include "megbrain/test/megdnn_helper.h"
...@@ -1195,6 +1196,92 @@ TEST(TestOprIndexing, SetMeshIndexing) { ...@@ -1195,6 +1196,92 @@ TEST(TestOprIndexing, SetMeshIndexing) {
} }
} }
namespace {
template<class Opr>
void run_multi_axis_vec_empty_shape(
const TensorShape& ishp, const TensorShape& idx0,
const TensorShape& idx1, const TensorShape& tshp) {
mgb_assert(ishp.ndim >= 4);
mgb_assert(idx0.is_empty() || idx1.is_empty());
using AI = opr::indexing::AxisIndexer;
auto graph = ComputingGraph::make();
HostTensorGenerator<> gen;
HostTensorGenerator<dtype::Int32> gen_idx;
auto host_x = gen(ishp);
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
idx_dynamic_shape = opr::MarkDynamicVar::make(
opr::ImmutableTensor::make(*graph, *gen_idx(idx0))),
idx_static_shape =
opr::ImmutableTensor::make(*graph, *gen_idx(idx1)),
y = Opr::make(x, {
AI::make_interval(-1, None, None, x.make_scalar(2)),
AI::make_index(1, idx_dynamic_shape),
AI::make_index(2, idx_static_shape)});
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
ASSERT_TRUE(host_y.shape().is_empty());
MGB_ASSERT_SHAPE_EQ(host_y.shape(), tshp);
}
template<class Opr>
void run_modify_multi_axis_vec_empty_shape(
const TensorShape& ishp, const TensorShape& vshp,
const TensorShape& idx0, const TensorShape& idx1) {
mgb_assert(ishp.ndim >= 4);
mgb_assert(vshp.is_empty() && (idx0.is_empty() || idx1.is_empty()));
using AI = opr::indexing::AxisIndexer;
auto graph = ComputingGraph::make();
HostTensorGenerator<> gen;
HostTensorGenerator<dtype::Int32> gen_idx;
auto host_x = gen(ishp), host_v = gen(vshp);
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
v = opr::Host2DeviceCopy::make(*graph, host_v),
idx_dynamic_shape = opr::MarkDynamicVar::make(
opr::ImmutableTensor::make(*graph, *gen_idx(idx0))),
idx_static_shape =
opr::ImmutableTensor::make(*graph, *gen_idx(idx1)),
y = Opr::make(x, v, {
AI::make_interval(-1, None, None, x.make_scalar(2)),
AI::make_index(1, idx_dynamic_shape),
AI::make_index(2, idx_static_shape)});
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
MGB_ASSERT_TENSOR_EQ(*host_x, host_y);
}
}
TEST(TestOprIndexing, MultiAxisVecEmptyShape) {
TensorShape ishp{8, 2, 3, 4};
size_t n = ishp[0], last_ndim = ishp[ishp.ndim - 1] / 2;
run_multi_axis_vec_empty_shape<opr::IndexingMultiAxisVec>(
ishp, {0}, {0}, {n, 0, last_ndim});
run_multi_axis_vec_empty_shape<opr::MeshIndexing>(
ishp, {0}, {2}, {n, 0, 2, last_ndim});
run_multi_axis_vec_empty_shape<opr::MeshIndexing>(
ishp, {3}, {0}, {n, 3, 0, last_ndim});
run_multi_axis_vec_empty_shape<opr::BatchedMeshIndexing>(
ishp, {n, 0}, {n, 2}, {n, 0, 2, last_ndim});
run_multi_axis_vec_empty_shape<opr::BatchedMeshIndexing>(
ishp, {n, 4}, {n, 0}, {n, 4, 0, last_ndim});
run_modify_multi_axis_vec_empty_shape<opr::IndexingIncrMultiAxisVec>(
ishp, {n, 0, last_ndim}, {0}, {0});
run_modify_multi_axis_vec_empty_shape<opr::IndexingSetMultiAxisVec>(
ishp, {n, 0, last_ndim}, {0}, {0});
run_modify_multi_axis_vec_empty_shape<opr::IncrMeshIndexing>(
ishp, {n, 0, 2, last_ndim}, {0}, {2});
run_modify_multi_axis_vec_empty_shape<opr::SetMeshIndexing>(
ishp, {n, 3, 0, last_ndim}, {3}, {0});
run_modify_multi_axis_vec_empty_shape<opr::BatchedIncrMeshIndexing>(
ishp, {n, 4, 0, last_ndim}, {n, 4}, {n, 0});
run_modify_multi_axis_vec_empty_shape<opr::BatchedSetMeshIndexing>(
ishp, {n, 0, 5, last_ndim}, {n, 0}, {n, 5});
}
#endif // MGB_ENABLE_EXCEPTION #endif // MGB_ENABLE_EXCEPTION
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -158,6 +158,16 @@ namespace mgb { ...@@ -158,6 +158,16 @@ namespace mgb {
return ::testing::AssertionSuccess(); return ::testing::AssertionSuccess();
} }
::testing::AssertionResult mgb::__assert_shape_equal(const TensorShape& v0,
const TensorShape& v1) {
if (v0.eq_shape(v1))
return ::testing::AssertionSuccess()
<< v0.to_string() << " == " << v1.to_string();
else
return ::testing::AssertionFailure()
<< v0.to_string() << " != " << v1.to_string();
}
#if WIN32 #if WIN32
#include <io.h> #include <io.h>
#include <fcntl.h> #include <fcntl.h>
......
...@@ -133,6 +133,11 @@ decltype(auto) container_to_vector(Container &&ct) { ...@@ -133,6 +133,11 @@ decltype(auto) container_to_vector(Container &&ct) {
#define MGB_ASSERT_TENSOR_EQ(v0, v1) \ #define MGB_ASSERT_TENSOR_EQ(v0, v1) \
MGB_ASSERT_TENSOR_NEAR(v0, v1, 1e-6) MGB_ASSERT_TENSOR_NEAR(v0, v1, 1e-6)
::testing::AssertionResult __assert_shape_equal(const TensorShape& v0,
const TensorShape& v1);
#define MGB_ASSERT_SHAPE_EQ(v0, v1) \
ASSERT_TRUE(::mgb::__assert_shape_equal(v0, v1))
/*! /*!
* \brief xorshift+ RNG, which is very fast * \brief xorshift+ RNG, which is very fast
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册