提交 bccda5c4 编写于 作者: M Megvii Engine Team

fix(mgb/imperative): fix repeat bug in trace mode

GitOrigin-RevId: 9547fc6102dbe10e5fe8879bdd463303bbffc866
上级 fca6c76a
......@@ -530,11 +530,23 @@ py::object _astensor1d_cpp(
return get_res_by_refhdl(value, dtype, device, ref);
}
if (lis.size() > 1) {
std::vector<PyObject*> c_args(lis.size() + 1);
for (size_t i = 0; i < lis.size(); ++i) {
c_args[i] = lis[i].ptr();
py::list flat_list;
for (auto item : lis) {
if (!PyList_Check(item.ptr())) {
flat_list.append(item);
} else {
py::list sub_lis =
py::reinterpret_steal<py::list>(PySequence_List(item.ptr()));
for (auto sub_item : sub_lis) {
flat_list.append(sub_item);
}
}
}
std::vector<PyObject*> c_args(flat_list.size() + 1);
for (size_t i = 0; i < flat_list.size(); ++i) {
c_args[i] = flat_list[i].ptr();
}
c_args[lis.size()] = Py_None;
c_args[flat_list.size()] = Py_None;
py::tuple inp_tup = py::reinterpret_steal<py::tuple>(
convert_inputs_cpp(NULL, c_args.data(), c_args.size()));
if (device_obj.is_none()) {
......
......@@ -161,6 +161,18 @@ def test_elemwise_fuse_in_grad(trace_mode):
y.numpy()
def test_repeat_in_trace():
@trace(symbolic=False)
def fun(data, repeats):
F.repeat(data, repeats)
data = tensor(np.random.random([1, 2, 3]).astype(np.float32))
for i in range(1, 5):
repeats = tensor(i)
fun(data, repeats)
def test_print_in_trace():
for symbolic in [False]: # cannot read value in symbolic mode
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册