提交 7252825c 编写于 作者: M Megvii Engine Team

fix(functional): broadcast_to supports mutable target shape

GitOrigin-RevId: ff79456d5d2d669d20112d57fdeb255ae837e868
上级 2484cd27
...@@ -924,78 +924,67 @@ bool enable_fastpath(py::handle inp) { ...@@ -924,78 +924,67 @@ bool enable_fastpath(py::handle inp) {
return true; return true;
} }
py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) { py::object _broadcast_cpp(py::handle input, py::handle args) {
py::object shape_hdl = _expand_args(args); py::object shape = _expand_args(args);
bool auto_infer = false; py::list dims;
py::list lis; bool all_imm;
py::list new_shape; if (PyList_Check(shape.ptr()) || PyTuple_Check(shape.ptr())) {
if (PyList_Check(shape_hdl.ptr()) || PyTuple_Check(shape_hdl.ptr())) { dims = py::reinterpret_steal<py::list>(PySequence_List(shape.ptr()));
lis = py::reinterpret_steal<py::list>(PySequence_List(shape_hdl.ptr())); mgb_assert(!dims.is_none());
for (size_t i = 0; i < lis.size(); ++i) { all_imm = true;
if (lis[i].is_none()) { py::object inp_shape = py::none();
auto_infer = true; size_t inp_ndim;
size_t right = lis.size() - i; for (size_t i = 0; i < dims.size(); ++i) {
py::object tshp = getattr(inp_hdl, "_tuple_shape"); py::object dim = dims[i];
if (tshp.is_none()) { if (dim.is_none()) {
throw py::index_error("does not support `None` with unknown shape"); ptrdiff_t right = (ptrdiff_t)i - dims.size();
if (inp_shape.is_none()) {
inp_shape = input.attr("shape");
mgb_assert(!inp_shape.is_none());
inp_ndim = py::len(inp_shape);
} }
py::tuple inp_shape = py::reinterpret_borrow<py::tuple>(tshp); if ((ptrdiff_t)inp_ndim + right < 0) {
if (inp_shape.size() >= right) { throw py::value_error("size connot be `None` for new axis");
if (enable_fastpath(inp_hdl)) {
lis[i] = inp_shape[inp_shape.size() - right];
}
new_shape.append(inp_shape[inp_shape.size() - right]);
} else {
throw py::value_error("invalid broadcast shape");
} }
} else { dim = inp_shape.attr("__getitem__")(right);
new_shape.append(lis[i]); dims[i] = dim;
if (PyLong_Check(lis[i].ptr())) { }
int32_t s = lis[i].cast<int32_t>(); if (py::int_::check_(dim)) {
if (s < 0) { if (dim.cast<long>() < 0) {
throw py::value_error( throw py::value_error(ssprintf(
"expect shape[" + std::to_string(i) + "expect shape[%zu] >= 0 or use `None` to auto infer, got "
"] >= 0 or use `None` to auto infer, got " + "%s",
std::to_string(s)); i, py::repr(dims[i]).cast<std::string>().c_str()));
}
} }
} else {
all_imm = false;
} }
} }
shape = dims;
} else {
all_imm = false;
} }
if (auto_infer) { bool fastpath = all_imm && enable_fastpath(input);
if (enable_fastpath(inp_hdl)) { if ((!fastpath) && (!is_tensor(shape))) {
shape_hdl = py::reinterpret_borrow<py::tuple>(lis); shape = _astensor1d_cpp(
} else { shape, py::cast((mgb::DType)dtype::Int32()), input.attr("device"),
shape_hdl = _astensor1d_cpp( input);
new_shape, py::cast((mgb::DType)dtype::Int32()),
getattr(inp_hdl, "device"), inp_hdl);
}
} }
py::object shape_tuple;
try {
shape_tuple = _make_shape_tuple(shape_hdl);
} catch (py::error_already_set& err) {
shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
}
auto [shape, fastpath] = tuple2vector(shape_tuple);
fastpath &= enable_fastpath(inp_hdl);
std::shared_ptr<OpDef> op; std::shared_ptr<OpDef> op;
std::vector<PyObject*> p; SmallVector<PyObject*> p(2);
py::object shape_tensor;
if (fastpath) { if (fastpath) {
op = Broadcast::make(shape); std::vector<int32_t> shape_vec;
p.resize(2); for (auto&& dim : dims) {
shape_vec.push_back(dim.cast<long>());
}
op = Broadcast::make(shape_vec);
} else { } else {
op = Broadcast::make(); op = Broadcast::make();
shape_tensor = _astensor1d_cpp( p.push_back(shape.ptr());
shape_hdl, py::cast((mgb::DType)dtype::Int32()),
getattr(inp_hdl, "device"), inp_hdl);
p.resize(3);
p[2] = shape_tensor.ptr();
} }
py::object Op = py::cast(op); py::object py_op = py::cast(op);
p[0] = Op.ptr(); p[0] = py_op.ptr();
p[1] = inp_hdl.ptr(); p[1] = input.ptr();
py::tuple ret = py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size())); py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0]; return ret[0];
...@@ -1675,4 +1664,4 @@ PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs) { ...@@ -1675,4 +1664,4 @@ PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
PYEXT17_TRANSLATE_EXC_RET(nullptr) PYEXT17_TRANSLATE_EXC_RET(nullptr)
} }
} // namespace mgb::imperative::python } // namespace mgb::imperative::python
\ No newline at end of file
...@@ -753,6 +753,40 @@ def test_broadcast_on_empty_tensor(is_trace): ...@@ -753,6 +753,40 @@ def test_broadcast_on_empty_tensor(is_trace):
test(func, inp, comp, target_shp) test(func, inp, comp, target_shp)
@pytest.mark.parametrize(
"input_shape, target_shapes",
[
((3,), [(2, 1, 3), (1, 2, 3), (2, 2, 3)]),
((1, 3, 1), [(2, None, 3), (3, None, 3), (1, None, 1)]),
],
)
@pytest.mark.parametrize("is_symbolic", [True, False])
def test_broadcast_on_trace(is_symbolic, input_shape, target_shapes):
x = F.ones(input_shape)
@trace(symbolic=is_symbolic)
def broadcast(inp, shape):
return F.broadcast_to(inp, shape)
for target_shape in target_shapes:
if None in target_shape:
symbolic_target_shape = tuple(
map(lambda x: None if x is None else Tensor(x), target_shape)
)
output = broadcast(x, symbolic_target_shape)
for i in range(len(target_shape)):
if target_shape[i] is not None:
assert output._tuple_shape[i] == target_shape[i]
else:
assert (
output._tuple_shape[i] == x._tuple_shape[i - len(target_shape)]
)
else:
symbolic_target_shape = Tensor(target_shape)
output = broadcast(x, symbolic_target_shape)
assert output._tuple_shape == target_shape
@pytest.mark.parametrize("is_varnode", [True, False]) @pytest.mark.parametrize("is_varnode", [True, False])
def test_utils_astensor1d(is_varnode): def test_utils_astensor1d(is_varnode):
if is_varnode: if is_varnode:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册