diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index 44f03c3234bcdc52fda2d558ca3d9c8829675f0d..e8f75b34ec43d3b6f9bd0823237f8e5fe058391a 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -530,11 +530,23 @@ py::object _astensor1d_cpp( return get_res_by_refhdl(value, dtype, device, ref); } if (lis.size() > 1) { - std::vector c_args(lis.size() + 1); - for (size_t i = 0; i < lis.size(); ++i) { - c_args[i] = lis[i].ptr(); + py::list flat_list; + for (auto item : lis) { + if (!PyList_Check(item.ptr())) { + flat_list.append(item); + } else { + py::list sub_lis = + py::reinterpret_steal(PySequence_List(item.ptr())); + for (auto sub_item : sub_lis) { + flat_list.append(sub_item); + } + } + } + std::vector c_args(flat_list.size() + 1); + for (size_t i = 0; i < flat_list.size(); ++i) { + c_args[i] = flat_list[i].ptr(); } - c_args[lis.size()] = Py_None; + c_args[flat_list.size()] = Py_None; py::tuple inp_tup = py::reinterpret_steal( convert_inputs_cpp(NULL, c_args.data(), c_args.size())); if (device_obj.is_none()) { diff --git a/imperative/python/test/unit/jit/test_tracing.py b/imperative/python/test/unit/jit/test_tracing.py index dbd7885c06b4285075561ad2c57feaf5b7bf85db..4e4cff5d3f7fa0318e39542fb135ce8d1fec0846 100644 --- a/imperative/python/test/unit/jit/test_tracing.py +++ b/imperative/python/test/unit/jit/test_tracing.py @@ -161,6 +161,18 @@ def test_elemwise_fuse_in_grad(trace_mode): y.numpy() +def test_repeat_in_trace(): + @trace(symbolic=False) + def fun(data, repeats): + F.repeat(data, repeats) + + data = tensor(np.random.random([1, 2, 3]).astype(np.float32)) + + for i in range(1, 5): + repeats = tensor(i) + fun(data, repeats) + + def test_print_in_trace(): for symbolic in [False]: # cannot read value in symbolic mode