Commit e954b8f9 authored by Megvii Engine Team

feat(mgb/opr): let Subtensor support empty IO

GitOrigin-RevId: a768498104fbb6ece0b54f5fe4b901b07e2026ac
Parent 1e83ab63
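In user terms, interval indexing on an empty tensor (or with an out-of-range interval) now yields an empty result that matches NumPy, instead of failing the bound check. A minimal illustrative sketch of the intended behavior (not part of the diff below; it just mirrors one case from the new Python test):

import numpy as np
import megengine

# a tensor with an empty axis: shape (10, 0, 10), zero elements
x = megengine.tensor(np.empty((10, 0, 10), dtype=np.float32))

# out-of-range intervals are clamped, so the result is simply empty,
# exactly as for the equivalent NumPy slicing
print(x[100:200, 300:400, 500:600].shape)                      # expected (0, 0, 0)
print(np.empty((10, 0, 10))[100:200, 300:400, 500:600].shape)  # (0, 0, 0)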
@@ -15,6 +15,7 @@ from utils import make_tensor
import megengine
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
import megengine.jit as jit
from megengine.core._imperative_rt.core2 import apply
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops import builtin
@@ -584,3 +585,26 @@ def test_advance_indexing_with_bool(test_varnode):
np.testing.assert_equal(
a[:, b, 0:2, [True, False]], aa[:, bb, 0:2, [True, False]].numpy()
)
@pytest.mark.parametrize("symbolic", [True, False, None])
def test_subtensor_on_empty_tensor(symbolic):
np_x = np.array([], dtype=np.float32).reshape(10, 0, 10)
mge_x = megengine.tensor(np_x)
def run_test(fn):
out_ref = fn(np_x)
if symbolic is not None:
fn = jit.trace(symbolic=symbolic)(fn)
for i in range(3):
out = fn(mge_x)
np.testing.assert_equal(out.numpy(), out_ref)
run_test(lambda x: x[0:1, :, :])
run_test(lambda x: x[1:100:2, :, :])
run_test(lambda x: x[-10:5:2, :, :])
run_test(lambda x: x[5:1:-1, :, :])
run_test(lambda x: x[3, 10:1:1, 5])
run_test(lambda x: x[3, 10:1:1, 5:-1])
run_test(lambda x: x[:100, :100, :100])
run_test(lambda x: x[100:200, 300:400, 500:600])
@@ -133,27 +133,42 @@ SubTensorSpec Slice::apply(TensorLayout layout, int axis) const {
return "None";
return std::to_string(v.val());
};
auto mod_size = [size_ax](ptrdiff_t v) {
auto mod_size = [size_ax](ptrdiff_t v)->ptrdiff_t {
if (size_ax == 0) return 0;
return v < 0 ? v + size_ax : v;
};
MGB_MARK_USED_VAR(tostr);
#define CHECK(cond) \
mgb_assert(cond, \
"index out of bound: layout=%s; request begin=%s end=%s step=%s " \
"axis=%d", \
layout.to_string().c_str(), tostr(m_begin).c_str(), \
tostr(m_end).c_str(), tostr(m_step).c_str(), axis)
#define CHECK(cond) \
if (m_is_scalar_idx) { \
mgb_assert(cond, \
"index out of bound: layout=%s; request index=%s, axis=%d", \
layout.to_string().c_str(), tostr(m_begin).c_str(), axis); \
} else { \
mgb_assert(cond, \
"index out of bound: layout=%s; request begin=%s end=%s step=%s " \
"axis=%d", \
layout.to_string().c_str(), tostr(m_begin).c_str(), \
tostr(m_end).c_str(), tostr(m_step).c_str(), axis); \
}
if (step > 0) {
begin = mod_size(m_begin.val_with_default(0));
end = mod_size(m_end.val_with_default(size_ax));
CHECK(begin >= 0 && end >= begin && end <= size_ax);
if (!m_is_scalar_idx) {
end = std::min(end, size_ax);
begin = std::min(begin, end);
}
CHECK(begin >= 0 && end >= begin && end <= size_ax)
} else {
begin = mod_size(m_begin.val_with_default(size_ax - 1));
end = m_end.valid() ? mod_size(m_end.val()) : -1;
if (!m_is_scalar_idx) {
begin = std::min(begin, std::max<ptrdiff_t>(size_ax-1, 0));
end = std::min(end, begin);
}
CHECK(step < 0 && begin >= 0 && end <= begin && begin < size_ax &&
end >= -1);
end >= -1)
}
auto step_abs = std::abs(step);
layout.shape[axis] = (std::abs(end - begin) + step_abs - 1) / step_abs;
......
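The new clamping above (applied only when the index is not a scalar) lets out-of-range intervals shrink to an empty range instead of tripping the assertion. A rough Python mirror of the step > 0 branch, written purely to illustrate the arithmetic; the function name and structure are my own, not MegBrain API:

def interval_len(size, begin=None, end=None, step=1):
    # normalize negative indices; on a zero-sized axis everything maps to 0
    def mod(v):
        if size == 0:
            return 0
        return v + size if v < 0 else v
    b = mod(0 if begin is None else begin)
    e = mod(size if end is None else end)
    e = min(e, size)   # clamp end into the axis (the non-scalar case)
    b = min(b, e)      # begin can only shrink the interval, never extend it
    assert 0 <= b <= e <= size, "index out of bound"
    return (e - b + step - 1) // step

# values taken from the tests in this commit
assert interval_len(100, 0, -10, 2) == 45   # x.shape = {100, 0}, x[0:-10:2]
assert interval_len(10, 100) == 0           # x.shape = {10}, x[100:]
assert interval_len(0, 0, 0) == 0           # x.shape = {0}, x[:0]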
@@ -83,16 +83,20 @@ class SubTensorSpec {
/*!
* \brief slice along some axis; index as in Python, with negative indices
* supported
* supported. Scalar index can also be represented as a Slice, where
* m_begin = idx, m_end = idx+1 and m_step = 1. The flag m_is_scalar_idx
* indicates whether the Slice comes from a scalar index.
*/
class Slice {
Maybe<ptrdiff_t> m_begin, m_end, m_step;
bool m_is_scalar_idx;
public:
Slice(Maybe<ptrdiff_t> begin = None,
Maybe<ptrdiff_t> end = None,
Maybe<ptrdiff_t> step = None):
m_begin{begin}, m_end{end}, m_step{step}
Maybe<ptrdiff_t> step = None,
bool is_scalar_idx = false):
m_begin{begin}, m_end{end}, m_step{step}, m_is_scalar_idx{is_scalar_idx}
{ }
/*!
......
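The m_is_scalar_idx flag matters because NumPy (and hence MegEngine) clamps interval indices but still rejects an out-of-range scalar index; a short NumPy illustration of the distinction the flag preserves:

import numpy as np

x = np.empty((10, 0, 10), dtype=np.float32)

# interval: out-of-range bounds are clamped, the result is just empty
print(x[100:200].shape)   # (0, 0, 10)

# scalar: the same position as an integer index is an error, so a Slice
# built from a scalar index must skip the clamping and keep the bound check
try:
    x[100]
except IndexError as exc:
    print(exc)            # index 100 is out of bounds for axis 0 with size 10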
@@ -178,7 +178,9 @@ SubTensorSpec FancyIndexingHelper::do_make_sub_spec(
i.axis.get_raw(), axis);
prev_axis = axis;
Maybe<ptrdiff_t> begin, end, step;
bool is_scalar_idx = false;
if (i.idx.node()) {
is_scalar_idx = true;
if (!m_require_scalar_index) {
continue;
}
@@ -195,7 +197,7 @@ SubTensorSpec FancyIndexingHelper::do_make_sub_spec(
step = next_iv();
}
spec.merge_with(Slice(begin, end, step).apply(spec.layout(), axis));
spec.merge_with(Slice(begin, end, step, is_scalar_idx).apply(spec.layout(), axis));
}
mgb_assert(iv_iter == m_value_infer_result.end());
......
@@ -660,7 +660,19 @@ MGB_IMPL_OPR_GRAD(AxisAddRemove) {
/* f{{{ ======================= Subtensor ======================= */
MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true);
Subtensor::Subtensor(VarNode *inp, const IndexDesc &desc,
const OperatorNodeConfig &config):
Super({inp->owner_graph(), config, "subtensor", {inp}},
inp, nullptr, desc, true) {
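// the result of slicing may legally contain zero elements, so the
// output var is allowed to carry an empty shape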
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
SymbolVar Subtensor::make(SymbolVar inp, const IndexDesc &desc,
const OperatorNodeConfig &config) {
return inp.insert_single_output_opr<Subtensor>(inp.node(), desc, config);
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Subtensor);
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Subtensor) {
@@ -722,6 +734,13 @@ void Subtensor::init_rt_force_dynamic_mem_alloc_imply_chain() {
out->add_rt_force_dynamic_mem_alloc_imply_chain(inp);
}
Subtensor::NodeProp* Subtensor::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
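// the input value itself may be empty (e.g. a tensor with a 0-sized axis);
// declare the dependency type so execution does not reject an empty input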
ret->add_dep_type_existing_var(input(0),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}
// f}}}
/* f{{{ ================== ModifySubtensorImplHelper ================== */
......
@@ -358,6 +358,7 @@ MGB_DEFINE_OPR_CLASS(Subtensor,
void scn_do_execute() override;
void mem_plan_fwd_in2out_readonly() override;
void init_rt_force_dynamic_mem_alloc_imply_chain() override;
NodeProp* do_make_node_prop() const override;
public:
Subtensor(VarNode *inp, const IndexDesc &desc,
......
@@ -894,6 +894,47 @@ TEST(TestTensorManip, SubtensorIdxChange) {
run(false);
}
TEST(TestTensorManip, SubtensorEmptyIO) {
using AIdx = opr::Subtensor::AxisIndexer;
using IndexDesc = std::vector<AIdx>;
using IndexDescCreater = thin_function<IndexDesc(SymbolVar)>;
HostTensorGenerator<> gen;
auto run = [&](const TensorShape& inp_shp, const TensorShape& out_shp, const IndexDescCreater& c) {
auto host_x = gen(inp_shp);
auto graph = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
auto y = opr::Subtensor::make(x, c(x));
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
ASSERT_EQ(host_y.shape(), out_shp);
ASSERT_TRUE(host_y.empty());
};
// x.shape = {0}, x[:0]
run({0}, {0}, [&](SymbolVar x)->IndexDesc {
return {AIdx::make_interval(0, None, x.make_scalar(0), None)};
});
// x.shape = {100, 0}, x[0:-10:2]
run({100, 0}, {45, 0}, [&](SymbolVar x)->IndexDesc {
return {AIdx::make_interval(0, x.make_scalar(0), x.make_scalar(-10), x.make_scalar(2))};
});
// x.shape = {100, 0}, x[10:-10:2, 0:0]
run({100, 0}, {40, 0}, [&](SymbolVar x)->IndexDesc {
return {AIdx::make_interval(0, x.make_scalar(10), x.make_scalar(-10), x.make_scalar(2)),
AIdx::make_interval(1, x.make_scalar(0), x.make_scalar(0), None)};
});
// x.shape = {10, 0, 10}, x[5, 10:-10:2]
run({10, 0, 10}, {0, 10}, [&](SymbolVar x)->IndexDesc {
return {AIdx::make_index(0, x.make_scalar(5)),
AIdx::make_interval(1, x.make_scalar(10), x.make_scalar(-10), x.make_scalar(2))};
});
// x.shape = {10}, x[100:]
run({10}, {0}, [&](SymbolVar x)->IndexDesc {
return {AIdx::make_interval(0, x.make_scalar(100), None, None)};
});
}
namespace {
void test_subtensor_fwdonly(bool dyn_inp, bool dyn_idx) {
......