diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index 1c3c164a9a90e33c22cd73da5a15b12323d1159d..25e5b97876922992ecbb2554fb81b7c52b211be6 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -406,3 +406,16 @@ def test_clip(): for i in range(3): f(x, tensor([0]), tensor([1])) + + +# test returning noncontiguous tensor from trace +def test_slice(): + @trace + def f(x): + return x[:, 1::2] + + x = F.arange(8).reshape(2, 4) + f(x) + y = f(x) + np.testing.assert_array_equal(y.numpy(), x.numpy()[:, 1::2]) + y + y diff --git a/imperative/src/impl/opr_utility.cpp b/imperative/src/impl/opr_utility.cpp index cf0bf2a6c2ef750ed26d561ed1ca0b93b53d7c16..65052c6d6734eacc57ede497a78c3d912d31b603 100644 --- a/imperative/src/impl/opr_utility.cpp +++ b/imperative/src/impl/opr_utility.cpp @@ -156,6 +156,12 @@ cg::OperatorNodeBase::NodeProp* OutputCallback::do_make_node_prop() const { return prop; } +void OutputCallback::add_input_layout_constraint() { + if (m_param.require_contiguous) { + input(0)->add_layout_constraint_contiguous(); + } +} + void OutputCallback::scn_do_execute() { if (m_use_host_value) { m_param.callback(owner_graph()->static_infer_manager().infer_value(input(0))); diff --git a/imperative/src/include/megbrain/imperative/opr_utility.h b/imperative/src/include/megbrain/imperative/opr_utility.h index 14f5f2727abd5bf30037dccfaac5f9bb126397a0..fbd57d715d0397395af290dbdbd01233e87c77e2 100644 --- a/imperative/src/include/megbrain/imperative/opr_utility.h +++ b/imperative/src/include/megbrain/imperative/opr_utility.h @@ -62,6 +62,7 @@ public: callback_t callback; bool borrow = false; // do not obtain shared ownership on DeviceTensorND bool prefer_host_value = false; // use host value when possible + bool require_contiguous = true; }; OutputCallback(Param param, const VarNodeArray& inputs, @@ -80,6 +81,7 @@ protected: void scn_do_execute() override; void init_output_static_infer_desc() override; NodeProp* do_make_node_prop() const override; + void add_input_layout_constraint() override; private: Param m_param; mutable bool m_use_host_value;