diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp
index 4e7e1f34a440ba0fefb89ddf8a3b13e3d2a86ddb..44f03c3234bcdc52fda2d558ca3d9c8829675f0d 100644
--- a/imperative/python/src/tensor_utils.cpp
+++ b/imperative/python/src/tensor_utils.cpp
@@ -924,78 +924,67 @@ bool enable_fastpath(py::handle inp) {
     return true;
 }
 
-py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) {
-    py::object shape_hdl = _expand_args(args);
-    bool auto_infer = false;
-    py::list lis;
-    py::list new_shape;
-    if (PyList_Check(shape_hdl.ptr()) || PyTuple_Check(shape_hdl.ptr())) {
-        lis = py::reinterpret_steal<py::list>(PySequence_List(shape_hdl.ptr()));
-        for (size_t i = 0; i < lis.size(); ++i) {
-            if (lis[i].is_none()) {
-                auto_infer = true;
-                size_t right = lis.size() - i;
-                py::object tshp = getattr(inp_hdl, "_tuple_shape");
-                if (tshp.is_none()) {
-                    throw py::index_error("does not support `None` with unknown shape");
+py::object _broadcast_cpp(py::handle input, py::handle args) {
+    py::object shape = _expand_args(args);
+    py::list dims;
+    bool all_imm;
+    if (PyList_Check(shape.ptr()) || PyTuple_Check(shape.ptr())) {
+        dims = py::reinterpret_steal<py::list>(PySequence_List(shape.ptr()));
+        mgb_assert(!dims.is_none());
+        all_imm = true;
+        py::object inp_shape = py::none();
+        size_t inp_ndim;
+        for (size_t i = 0; i < dims.size(); ++i) {
+            py::object dim = dims[i];
+            if (dim.is_none()) {
+                ptrdiff_t right = (ptrdiff_t)i - dims.size();
+                if (inp_shape.is_none()) {
+                    inp_shape = input.attr("shape");
+                    mgb_assert(!inp_shape.is_none());
+                    inp_ndim = py::len(inp_shape);
                 }
-                py::tuple inp_shape = py::reinterpret_borrow<py::tuple>(tshp);
-                if (inp_shape.size() >= right) {
-                    if (enable_fastpath(inp_hdl)) {
-                        lis[i] = inp_shape[inp_shape.size() - right];
-                    }
-                    new_shape.append(inp_shape[inp_shape.size() - right]);
-                } else {
-                    throw py::value_error("invalid broadcast shape");
+                if ((ptrdiff_t)inp_ndim + right < 0) {
+                    throw py::value_error("size cannot be `None` for new axis");
                 }
-            } else {
-                new_shape.append(lis[i]);
-                if (PyLong_Check(lis[i].ptr())) {
-                    int32_t s = lis[i].cast<int32_t>();
-                    if (s < 0) {
-                        throw py::value_error(
-                                "expect shape[" + std::to_string(i) +
-                                "] >= 0 or use `None` to auto infer, got " +
-                                std::to_string(s));
-                    }
+                dim = inp_shape.attr("__getitem__")(right);
+                dims[i] = dim;
+            }
+            if (py::int_::check_(dim)) {
+                if (dim.cast<int32_t>() < 0) {
+                    throw py::value_error(ssprintf(
+                            "expect shape[%zu] >= 0 or use `None` to auto infer, got "
+                            "%s",
+                            i, py::repr(dims[i]).cast<std::string>().c_str()));
                 }
+            } else {
+                all_imm = false;
             }
         }
+        shape = dims;
+    } else {
+        all_imm = false;
     }
-    if (auto_infer) {
-        if (enable_fastpath(inp_hdl)) {
-            shape_hdl = py::reinterpret_borrow<py::object>(lis);
-        } else {
-            shape_hdl = _astensor1d_cpp(
-                    new_shape, py::cast((mgb::DType)dtype::Int32()),
-                    getattr(inp_hdl, "device"), inp_hdl);
-        }
+    bool fastpath = all_imm && enable_fastpath(input);
+    if ((!fastpath) && (!is_tensor(shape))) {
+        shape = _astensor1d_cpp(
+                shape, py::cast((mgb::DType)dtype::Int32()), input.attr("device"),
+                input);
     }
-    py::object shape_tuple;
-    try {
-        shape_tuple = _make_shape_tuple(shape_hdl);
-    } catch (py::error_already_set& err) {
-        shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
-    }
-    auto [shape, fastpath] = tuple2vector(shape_tuple);
-    fastpath &= enable_fastpath(inp_hdl);
     std::shared_ptr<OpDef> op;
-    std::vector<PyObject*> p;
-    py::object shape_tensor;
+    SmallVector<PyObject*> p(2);
     if (fastpath) {
-        op = Broadcast::make(shape);
-        p.resize(2);
+        std::vector<int32_t> shape_vec;
+        for (auto&& dim : dims) {
+            shape_vec.push_back(dim.cast<int32_t>());
+        }
+        op = Broadcast::make(shape_vec);
     } else {
         op = Broadcast::make();
-        shape_tensor = _astensor1d_cpp(
-                shape_hdl, py::cast((mgb::DType)dtype::Int32()),
-                getattr(inp_hdl, "device"), inp_hdl);
-        p.resize(3);
-        p[2] = shape_tensor.ptr();
+        p.push_back(shape.ptr());
     }
-    py::object Op = py::cast(op);
-    p[0] = Op.ptr();
-    p[1] = inp_hdl.ptr();
+    py::object py_op = py::cast(op);
+    p[0] = py_op.ptr();
+    p[1] = input.ptr();
     py::tuple ret = py::reinterpret_steal<py::tuple>(
             py_apply(NULL, p.data(), p.size()));
     return ret[0];
@@ -1675,4 +1664,4 @@ PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
     PYEXT17_TRANSLATE_EXC_RET(nullptr)
 }
 
-} // namespace mgb::imperative::python
\ No newline at end of file
+} // namespace mgb::imperative::python
diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py
index 69a2ed4302e13e96552964a9989dc1fd1069cba7..5651855bc39e36f58d9d735b43a64b73d9cf14a3 100644
--- a/imperative/python/test/unit/functional/test_tensor.py
+++ b/imperative/python/test/unit/functional/test_tensor.py
@@ -753,6 +753,40 @@ def test_broadcast_on_empty_tensor(is_trace):
         test(func, inp, comp, target_shp)
 
 
+@pytest.mark.parametrize(
+    "input_shape, target_shapes",
+    [
+        ((3,), [(2, 1, 3), (1, 2, 3), (2, 2, 3)]),
+        ((1, 3, 1), [(2, None, 3), (3, None, 3), (1, None, 1)]),
+    ],
+)
+@pytest.mark.parametrize("is_symbolic", [True, False])
+def test_broadcast_on_trace(is_symbolic, input_shape, target_shapes):
+    x = F.ones(input_shape)
+
+    @trace(symbolic=is_symbolic)
+    def broadcast(inp, shape):
+        return F.broadcast_to(inp, shape)
+
+    for target_shape in target_shapes:
+        if None in target_shape:
+            symbolic_target_shape = tuple(
+                map(lambda x: None if x is None else Tensor(x), target_shape)
+            )
+            output = broadcast(x, symbolic_target_shape)
+            for i in range(len(target_shape)):
+                if target_shape[i] is not None:
+                    assert output._tuple_shape[i] == target_shape[i]
+                else:
+                    assert (
+                        output._tuple_shape[i] == x._tuple_shape[i - len(target_shape)]
+                    )
+        else:
+            symbolic_target_shape = Tensor(target_shape)
+            output = broadcast(x, symbolic_target_shape)
+            assert output._tuple_shape == target_shape
+
+
 @pytest.mark.parametrize("is_varnode", [True, False])
 def test_utils_astensor1d(is_varnode):
     if is_varnode:
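
Note on the semantics this patch settles on (a sketch inferred from the rewritten
`_broadcast_cpp` and its test, not part of the patch itself): a `None` entry in the
target shape is resolved against the input's trailing axes, so it must land on an
axis the input already has; a `None` on a newly prepended axis raises ValueError.
Assuming a standard MegEngine install, eager-mode usage looks roughly like:

    import megengine.functional as F

    x = F.ones((1, 3, 1))
    # None is filled from the input's trailing axes: target index 1 maps to
    # input axis -2 (size 3), so the resolved target shape is (2, 3, 3).
    y = F.broadcast_to(x, (2, None, 3))
    assert y._tuple_shape == (2, 3, 3)

    # A None on a newly added leading axis has no input axis to read from:
    # F.broadcast_to(F.ones((3,)), (None, 2, 3))  # raises ValueError

When every resolved dimension is a plain non-negative int, the rewritten
`_broadcast_cpp` takes the fast path and builds `Broadcast` with a static shape;
otherwise (e.g. under `trace`, where `input.attr("shape")` can be symbolic) the
shape is converted to a 1-d Int32 tensor and passed as a third apply input, which
is what makes the `None`-shape cases in `test_broadcast_on_trace` work.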