From d03fe58039a4a7e20133d12dcb51540af631f43c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 29 Dec 2022 19:18:37 +0800 Subject: [PATCH] fix(imperative): fix the split dump problem GitOrigin-RevId: 0a0265e59819a89b12853229eff9d14d1e55ace6 --- imperative/python/src/tensor_utils.cpp | 20 ++++++++++++++--- .../python/test/unit/jit/test_tracing.py | 22 +++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index 268b34d01..638b538ec 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -400,7 +400,6 @@ py::object get_res_by_refhdl( ref = py::reinterpret_borrow(ref_hdl); } if (PyObject_TypeCheck(ref.ptr(), py_varnode_type)) { - auto temp = dtype.cast(); ComputingGraph* graph = getattr(ref, "graph").cast(); cg::VarNode* node = getattr(ref, "var").cast(); CompNode cn; @@ -1473,8 +1472,23 @@ py::object _split_cpp( std::to_string(axis) + " cannot be split into " + std::to_string(n_sections) + " sections"); } - op = Split::make(axis, n_sections); - p.resize(2); + if (enable_fastpath(inp_hdl)) { + op = Split::make(axis, n_sections); + p.resize(2); + } else { + size_t n_total_ = n_total.cast(); + for (size_t idx = 0; idx < n_sections; ++idx) { + auto section_size = (n_total_ + n_sections - idx - 1) / n_sections; + partitions.append(_Const( + py::int_(section_size), py::cast((mgb::DType)dtype::Int32()), + getattr(inp_hdl, "device"))); + } + op = Split::make(axis, 0); + p.resize(partitions.size() + 2); + for (size_t idx = 0; idx < partitions.size(); ++idx) { + p[idx + 2] = partitions[idx].ptr(); + } + } } py::object Op = py::cast(op); p[0] = Op.ptr(); diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index 4e4cff5d3..c12a27565 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -308,6 +308,28 @@ def test_dump_with_testcase(): f.dump(file, input_data=["#rand(0, 255, 1)"]) +def test_split_dump(): + class SimpleNet(Module): + def __init__(self, num_segments: int = 3): + super().__init__() + self.num_segments = num_segments + + def forward(self, x): + x = F.split(x, self.num_segments, axis=1) + return x + + model = SimpleNet() + model.eval() + data = tensor(np.random.random((1, 12, 224, 224))) + + @trace(symbolic=True, capture_as_const=True) + def fun(data, *, net): + return net(data) + + x = fun(data, net=model) + fun.dump(io.BytesIO(), arg_names=["data"]) + + @pytest.mark.parametrize("trace_mode", [False, True]) def test_trace_profiler(trace_mode): @trace(symbolic=trace_mode, profiling=True) -- GitLab