From 432fdb7e6aa19f64378b1faaacf928890e07a481 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 7 Aug 2021 18:48:42 +0800 Subject: [PATCH] feat(mgb/opr): let SetSubtensor support empty IO GitOrigin-RevId: 13909e3b11f54a1f3ea0acc7af259a4804c2d77b --- .../python/test/unit/core/test_indexing_op.py | 47 +++++++++++++++++++ .../src/impl/interpreter/interpreter_impl.cpp | 5 ++ src/opr/impl/tensor_manip.cpp | 30 +++++++++++- src/opr/include/megbrain/opr/tensor_manip.h | 1 + src/opr/test/tensor_manip.cpp | 42 +++++++++++++++++ 5 files changed, 123 insertions(+), 2 deletions(-) diff --git a/imperative/python/test/unit/core/test_indexing_op.py b/imperative/python/test/unit/core/test_indexing_op.py index d962c7e3..5f3c6689 100644 --- a/imperative/python/test/unit/core/test_indexing_op.py +++ b/imperative/python/test/unit/core/test_indexing_op.py @@ -608,3 +608,50 @@ def test_subtensor_on_empty_tensor(symbolic): run_test(lambda x: x[3, 10:1:1, 5:-1]) run_test(lambda x: x[:100, :100, :100]) run_test(lambda x: x[100:200, 300:400, 500:600]) + + +@pytest.mark.parametrize("symbolic", [True, False, None]) +def test_setsubtensor_on_empty_tensor(symbolic): + def run_test(inp_shp, fn): + np_x = np.random.randn(*inp_shp).astype(np.float32) + mge_x = megengine.tensor(np_x) + out_ref = fn(np_x) + if symbolic is not None: + fn = jit.trace(symbolic=symbolic)(fn) + for i in range(3): + out = fn(mge_x) + np.testing.assert_equal(out.numpy(), out_ref) + + def test1(x): + x[1:100:2, :, :] = x[1:100:2, :, :] + return x + + def test2(x): + x[-10:5:2, :, :] = x[-10:5:2, :, :] + return x + + def test3(x): + x[5:1:-1, :, :] = x[5:1:-1, :, :] + return x + + def test4(x): + x[3, 10:1:1, 5:-1] = x[3, 10:1:1, 5:-1] + return x + + def test5(x): + x[:100, :100, :100] = x[:100, :100, :100] + return x + + def test6(x): + x[100:200, 300:400, 500:600] = x[100:200, 300:400, 500:600] + return x + + run_test((10, 0, 10), test1) + run_test((10, 0, 10), test2) + run_test((10, 0, 10), test3) + run_test((10, 0, 10), test4) + run_test((10, 0, 10), test5) + run_test((10, 0, 10), test6) + run_test((10, 10, 10), test4) + run_test((10, 10, 10), test5) + run_test((10, 10, 10), test6) diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index d57430fe..e2507497 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -134,6 +134,11 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) { } TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) { + if (value.empty()) { + auto layout = value.layout(); + layout.init_contiguous_stride(); + const_cast(value).reset(value.storage(), layout); + } auto info = alloc(); init(info, {value.layout(), value.comp_node(), value.proxy_to_default_cpu()}); info->mem_desc.id = StorageIdentifier::make(++m_storage_id); diff --git a/src/opr/impl/tensor_manip.cpp b/src/opr/impl/tensor_manip.cpp index cf66e0d3..196f7200 100644 --- a/src/opr/impl/tensor_manip.cpp +++ b/src/opr/impl/tensor_manip.cpp @@ -819,10 +819,36 @@ void ModifySubtensorImplHelper::init_output_static_infer_desc() { /* f{{{ ======================= SetSubtensor ======================= */ -MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetSubtensor, "set_subtensor", true); +SetSubtensor::SetSubtensor(VarNode *inp, VarNode *value, const IndexDesc &desc, + const OperatorNodeConfig &config, + const InputTensorReplacer &input_tensor_replacer): + Super({inp->owner_graph(), config, "set_subtensor", {inp, value}}, + inp, value, desc, true, input_tensor_replacer) { + output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); +} + +SymbolVar SetSubtensor::make(SymbolVar inp, SymbolVar value, const IndexDesc &desc, + const OperatorNodeConfig &config, + const InputTensorReplacer &input_tensor_replacer) { + return inp.insert_single_output_opr( + inp.node(), value.node(), desc, config, input_tensor_replacer); +} + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(SetSubtensor); void SetSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) { - sub.copy_from_fixlayout(val); + if (!val.layout().is_empty()) { + sub.copy_from_fixlayout(val); + } +} + +SetSubtensor::NodeProp* SetSubtensor::do_make_node_prop() const { + auto ret = Super::do_make_node_prop(); + ret->add_dep_type_existing_var(input(0), + NodeProp::DepType::VALUE_ALLOW_EMPTY); + ret->add_dep_type_existing_var(input(1), + NodeProp::DepType::VALUE_ALLOW_EMPTY); + return ret; } #if MGB_ENABLE_GRAD diff --git a/src/opr/include/megbrain/opr/tensor_manip.h b/src/opr/include/megbrain/opr/tensor_manip.h index 9f9be3a0..ef7eca64 100644 --- a/src/opr/include/megbrain/opr/tensor_manip.h +++ b/src/opr/include/megbrain/opr/tensor_manip.h @@ -374,6 +374,7 @@ MGB_DEFINE_OPR_CLASS(Subtensor, MGB_DEFINE_OPR_CLASS(SetSubtensor, intl::ModifySubtensorImplHelper) // { void modify(DeviceTensorND &sub, const DeviceTensorND &val) override; + NodeProp* do_make_node_prop() const override; public: MGB_DECL_FANCY_INDEXING_OPR_MODIFY(SetSubtensor); diff --git a/src/opr/test/tensor_manip.cpp b/src/opr/test/tensor_manip.cpp index a07757da..6b510e14 100644 --- a/src/opr/test/tensor_manip.cpp +++ b/src/opr/test/tensor_manip.cpp @@ -935,6 +935,48 @@ TEST(TestTensorManip, SubtensorEmptyIO) { }); } +TEST(TestTensorManip, SetSubtensorEmptyIO) { + using AIdx = opr::SetSubtensor::AxisIndexer; + using IndexDesc = std::vector; + using IndexDescCreater = thin_function; + HostTensorGenerator<> gen; + auto run = [&](const TensorShape& inp_shp, const TensorShape& val_shp, const IndexDescCreater& c) { + auto host_x = gen(inp_shp), + host_v = gen(val_shp); + auto graph = ComputingGraph::make(); + auto x = opr::Host2DeviceCopy::make(*graph, host_x), + v = opr::Host2DeviceCopy::make(*graph, host_v); + + auto y = opr::SetSubtensor::make(x, v, c(x)); + HostTensorND host_y; + auto func = graph->compile({make_callback_copy(y, host_y)}); + func->execute(); + ASSERT_EQ(host_y.shape(), inp_shp); + }; + // x.shape = {0}, v.shape = {0}, x[:0] = v + run({0}, {0}, [&](SymbolVar x)->IndexDesc { + return {AIdx::make_interval(0, None, x.make_scalar(0), None)}; + }); + // x.shape = {100, 0}, v.shape = {45, 0}, x[0:-10:2] = v + run({100, 0}, {45, 0}, [&](SymbolVar x)->IndexDesc { + return {AIdx::make_interval(0, x.make_scalar(0), x.make_scalar(-10), x.make_scalar(2))}; + }); + // x.shape = {100, 0}, v.shape = {40, 0}, x[10:-10:2, 0:0] = v + run({100, 0}, {40, 0}, [&](SymbolVar x)->IndexDesc { + return {AIdx::make_interval(0, x.make_scalar(10), x.make_scalar(-10), x.make_scalar(2)), + AIdx::make_interval(1, x.make_scalar(0), x.make_scalar(0), None)}; + }); + // x.shape = {10, 0, 10}, v.shape = {0, 10}, x[5, 10:-10:-2] = v + run({10, 0, 10}, {0, 10}, [&](SymbolVar x)->IndexDesc { + return {AIdx::make_index(0, x.make_scalar(5)), + AIdx::make_interval(1, x.make_scalar(10), x.make_scalar(-10), x.make_scalar(2))}; + }); + // x.shape = {10}, v.shape = {0}, x[100:] = v + run({10}, {0}, [&](SymbolVar x)->IndexDesc { + return {AIdx::make_interval(0, x.make_scalar(100), None, None)}; + }); +} + namespace { void test_subtensor_fwdonly(bool dyn_inp, bool dyn_idx) { -- GitLab