提交 50d285fc 编写于 作者: M Megvii Engine Team

fix(mgb/opr): fix IndexingModifyMultiAxisVecHelper

GitOrigin-RevId: 06f5a9110580116561ff6625956a8f9c630ebba5
上级 07d1d0ab
......@@ -32,6 +32,27 @@ namespace {
index = opr::TypeCvt::make(index, dtype::Int32());
}
}
enum IndexingModifyType {
SET, INCR
};
template<typename Opr>
struct IndexingModifyTypeGetter {};
#define REG(op, type) \
template<> \
struct IndexingModifyTypeGetter<megdnn::op> { \
static constexpr IndexingModifyType value = IndexingModifyType::type; \
};
REG(IndexingIncrMultiAxisVec, INCR)
REG(IncrMeshIndexing, INCR)
REG(BatchedIncrMeshIndexing, INCR)
REG(IndexingSetMultiAxisVec, SET)
REG(SetMeshIndexing, SET)
REG(BatchedSetMeshIndexing, SET)
#undef REG
}
namespace mgb {
......@@ -371,15 +392,20 @@ void intl::IndexingModifyMultiAxisVecHelper<Opr>::scn_do_execute() {
auto index_desc = this->make_megdnn_index_desc(
inp.first.layout().ndim, ShouldWarnOnScalarIndexer<Opr>::val);
if (index_desc.empty()) {
if (std::is_same<Opr, megdnn::IndexingSetMultiAxisVec>::value) {
inp.first.copy_from_fixlayout(inp.second);
} else {
static constexpr bool is_incr = std::is_same<
Opr, megdnn::IndexingIncrMultiAxisVec>::value;
mgb_assert(is_incr);
megdnn::AddUpdate* add_update = intl::get_megdnn_global_opr<
megdnn::AddUpdate>(comp_node());
add_update->exec(inp.first.as_megdnn(), inp.second.as_megdnn());
using IMT = IndexingModifyType;
static constexpr auto modify_type =
IndexingModifyTypeGetter<Opr>::value;
switch (modify_type) {
case IMT::SET: {
inp.first.copy_from_fixlayout(inp.second);
break;
} case IMT::INCR: {
megdnn::AddUpdate* add_update = intl::get_megdnn_global_opr<
megdnn::AddUpdate>(comp_node());
add_update->exec(inp.first.as_megdnn(), inp.second.as_megdnn());
break;
} default:
mgb_throw(MegBrainError, "bad modify type");
}
} else {
this->megdnn_opr(*this).exec(
......
......@@ -1165,6 +1165,34 @@ TEST(TestOprIndexing, SetMeshIndexing) {
checker.run({TensorShape{8, 20, 10, 7, 7}, {1}, {9}, {3, 9, 1, 7, 7}},
opt);
}
{ // only interval AxisIndexer given
using Checker = AutoOprChecker<2, 1>;
auto make_graph =
[&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
SymbolVar x = inputs[0], val = inputs[1];
return {opr::SetMeshIndexing::make(
x, val,
{AIdx::make_interval(0, x.make_scalar(1),
None, x.make_scalar(2))})};
};
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
dest[0].copy_from(*inp[0]);
auto value = *inp[1];
auto value_iter = megdnn::tensor_iter<float>(value.as_megdnn()).begin();
size_t n = dest[0].layout().stride[0];
float* raw_ptr = dest[0].ptr<float>();
for (size_t i = 0; i < value.shape().total_nr_elems(); ++i) {
ptrdiff_t offset = (i / n * 2 + 1) * n + i % n;
*(raw_ptr + offset) = *value_iter;
++ value_iter;
}
};
Checker checker{make_graph, fwd};
checker.run({TensorShape{11}, {5}});
checker.run({TensorShape{6, 7}, {3, 7}});
checker.run({TensorShape{4, 7, 1}, {2, 7, 1}});
checker.run({TensorShape{7, 1, 1, 2}, {3, 1, 1, 2}});
}
}
#endif // MGB_ENABLE_EXCEPTION
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册