diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py
index da591cc343e9f566a9e4b790cd7c2fa95733586b..00148234148c4b1d3d472ceb1322c70e1cf20fda 100644
--- a/imperative/python/megengine/core/tensor/array_method.py
+++ b/imperative/python/megengine/core/tensor/array_method.py
@@ -15,9 +15,15 @@ import numpy as np
 
 from .. import _config
 from .._imperative_rt.common import CompNode
-from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
+from .._imperative_rt.core2 import (
+    SymbolVar,
+    Tensor,
+    apply,
+    broadcast_cpp,
+    dtype_promotion,
+)
 from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
-from .._imperative_rt.core2 import squeeze_cpp, transpose_cpp
+from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp
 from ..ops import builtin
 from . import amp
 from .indexing import getitem, setitem
@@ -331,70 +337,6 @@ def _matmul(
     return result
 
 
-def _broadcast(inp, shape):
-    auto_infer = False
-    if isinstance(shape, (list, tuple)):
-        shape_tuple = list(shape)
-        for i, s in enumerate(shape_tuple):
-            if isinstance(s, type(None)):
-                if s is None:
-                    right = i - len(shape_tuple)
-                    inp_shape = inp._tuple_shape
-                    if len(inp_shape) + right >= 0:
-                        shape_tuple[right] = list(inp_shape)[right]
-                        auto_infer = True
-                        continue
-                    else:
-                        raise ValueError("invalided Broadcast shape")
-                else:
-                    raise ValueError(
-                        "expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format(
-                            i, s
-                        )
-                    )
-            if s < 0:
-                raise ValueError(
-                    "expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format(
-                        i, s
-                    )
-                )
-        if auto_infer:
-            shape = tuple(shape_tuple)
-    try:
-        shape_tuple = make_shape_tuple(shape)
-    except ValueError:
-        shape_tuple = shape
-    shape = astensor1d(shape_tuple, inp, dtype="int32", device=inp.device)
-    (result,) = apply(builtin.Broadcast(), inp, shape)
-    return result
-
-
-def _reshape(x, shape):
-    unspec_axis = None
-    try:
-        shape_tuple = make_shape_tuple(shape)
-    except ValueError:
-        pass
-    else:
-        # XXX: assume unspec_axis is not changed in trace
-        for i, s in enumerate(shape_tuple):
-            if s < 0:
-                if s != -1:
-                    raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
-                if unspec_axis is not None:
-                    raise ValueError(
-                        "multiple -1 in shape: {} & {}".format(unspec_axis, i)
-                    )
-                unspec_axis = i
-    shape = astensor1d(shape, x, dtype="int32", device=x.device)
-    if unspec_axis is None:
-        op = builtin.Reshape()
-    else:
-        op = builtin.Reshape(axis=unspec_axis)
-    (x,) = apply(op, x, shape)
-    return x
-
-
 def _unary_elwise(mode):
     def f(self):
         return _elwise(self, mode=mode)
@@ -667,11 +609,11 @@ class ArrayMethodMixin(abc.ABC):
 
     def reshape(self, *args):
         r"""See :func:`~.reshape`."""
-        return _reshape(self, _expand_args(args))
+        return reshape_cpp(self, args)
 
     # FIXME: remove this method
     def _broadcast(self, *args):
-        return _broadcast(self, _expand_args(args))
+        return broadcast_cpp(self, args)
 
     def transpose(self, *args):
         r"""See :func:`~.transpose`."""
@@ -679,7 +621,7 @@
 
     def flatten(self):
         r"""See :func:`~.flatten`."""
-        return self.reshape(-1)
+        return reshape_cpp(self, (-1,))
 
     def sum(self, axis=None, keepdims: bool = False):
         r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
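Note (illustration, not part of the patch): the removed Python-level `_reshape` validated the target shape before dispatching to the C++ op, and `_reshape_cpp` below re-implements the same rule in C++. A minimal standalone sketch of that rule, using the hypothetical helper name `infer_unspec_axis` chosen only for this example:

def infer_unspec_axis(shape):
    # Same validation the removed _reshape performed: every entry must be
    # >= 0 except for at most one -1, whose index is returned (or None).
    unspec_axis = None
    for i, s in enumerate(shape):
        if s < 0:
            if s != -1:
                raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
            if unspec_axis is not None:
                raise ValueError("multiple -1 in shape: {} & {}".format(unspec_axis, i))
            unspec_axis = i
    return unspec_axis

assert infer_unspec_axis((2, -1, 4)) == 1
assert infer_unspec_axis((2, 3)) is None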
diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py
index 68ce956cbc217197312a423bf78cd6ad4f5d52d7..9f4134d7f19f83bb03515ab651a22303a13e123a 100755
--- a/imperative/python/megengine/functional/tensor.py
+++ b/imperative/python/megengine/functional/tensor.py
@@ -15,6 +15,7 @@ from ..core._imperative_rt import CompNode
 from ..core._imperative_rt.core2 import (
     SymbolVar,
     apply,
+    broadcast_cpp,
     dtype_promotion,
     expand_dims_cpp,
     split_cpp,
@@ -24,7 +25,6 @@ from ..core._wrap import as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
 from ..core.ops.special import Const
-from ..core.tensor.array_method import _broadcast
 from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
 from ..device import get_default_device
 from ..tensor import Tensor
@@ -360,7 +360,7 @@ def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
        [[0. 1. 2.]
         [0. 1. 2.]]
     """
-    return _broadcast(inp, shape)
+    return broadcast_cpp(inp, shape)
 
 
 def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp
index 0e37499d3b9a9429fc4ddbb957588580c638d898..45a7b03f9a0c96185549fb3bf94efbc01cce5099 100644
--- a/imperative/python/src/grad_override.cpp
+++ b/imperative/python/src/grad_override.cpp
@@ -135,23 +135,24 @@ std::optional<ValueRefList> elemwise_grad_rule(
 std::optional<ValueRefList> reshape_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
-    mgb_assert(inputs.size() == 2);
+    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
+    size_t nr_inp = inputs.size();
     std::array<ValueRef, 2> input_shapes;
-    for (size_t i = 0; i < 2; ++i) {
+    for (size_t i = 0; i < nr_inp; ++i) {
         if (inputs_require_grad[i]) {
             input_shapes[i] = get_shape(inputs[i]);
         }
     }
     auto maker = CustomGradMaker(backward, inputs.size());
     maker.output_size(1).output_captured(0, false);
-    maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
+    maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        SmallVector<ValueRef> ret(2);
+        SmallVector<ValueRef> ret(nr_inp);
         if (!grad) {
             return ret;
         }
-        for (size_t i = 0; i < 2; ++i) {
+        for (size_t i = 0; i < nr_inp; ++i) {
             if (shapes[i]) {
                 ret[i] = reshape_to(grad, shapes[i]);
             }
@@ -162,6 +163,37 @@ std::optional<ValueRefList> reshape_grad_rule(
     return imperative::apply(ApplyOp(op), inputs);
 }
 
+std::optional<ValueRefList> broadcast_grad_rule(
+        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
+        CustomBackward& backward) {
+    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
+    size_t nr_inp = inputs.size();
+    std::array<ValueRef, 2> input_shapes;
+    for (size_t i = 0; i < nr_inp; ++i) {
+        if (inputs_require_grad[i]) {
+            input_shapes[i] = get_shape(inputs[i]);
+        }
+    }
+    auto maker = CustomGradMaker(backward, inputs.size());
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
+        mgb_assert(grads.size() == 1);
+        ValueRef grad = grads[0];
+        SmallVector<ValueRef> ret(nr_inp);
+        if (!grad) {
+            return ret;
+        }
+        for (size_t i = 0; i < nr_inp; ++i) {
+            if (shapes[i]) {
+                ret[i] = reduce_to(grad, shapes[i]);
+            }
+        }
+        return ret;
+    });
+    maker.finalize();
+    return imperative::apply(ApplyOp(op), inputs);
+}
+
 std::optional<ValueRefList> subtensor_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
@@ -330,6 +362,7 @@ struct Init {
     Init() {
CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule); CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule); + CustomBackward::register_grad_rule(Broadcast::typeinfo(), broadcast_grad_rule); CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule); CustomBackward::register_grad_rule( IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule); diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 58c9beb4baf94b57353a26c3343f2fa188b696bc..86f004aeac2dfa2f193b5ea20a06a5107311d757 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -637,6 +637,8 @@ WRAP_FUNC_PY35(split_cpp); WRAP_FUNC_PY35(expand_dims_cpp); WRAP_FUNC_PY35(squeeze_cpp); WRAP_FUNC_PY35(transpose_cpp); +WRAP_FUNC_PY35(broadcast_cpp); +WRAP_FUNC_PY35(reshape_cpp); #undef WRAP_FUNC_PY35 #define MGE_PY_INTERFACE(NAME, FUNC) \ { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } @@ -773,6 +775,8 @@ void init_tensor(py::module m) { MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp), MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp), MGE_PY_INTERFACE(transpose_cpp, transpose_cpp), + MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp), + MGE_PY_INTERFACE(reshape_cpp, reshape_cpp), {nullptr, nullptr, 0, nullptr}}; for (auto&& def : method_defs) { if (def.ml_meth != nullptr) { diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index 16f4062e1612e3d65a0b7acad3c27b98321143d1..2b0e3687b1b3af3c3ba955e7dfb06eb1444ebdd3 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -800,29 +800,46 @@ size_t fast_ndim(py::handle tensor) { return getattr(tensor, "ndim").cast(); } -py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { +py::object _expand_args(py::handle args) { + if (!PyTuple_Check(args.ptr())) { + return py::reinterpret_borrow(args); + } py::tuple args_tup = py::reinterpret_borrow(args.ptr()); + if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) || + is_tensor_or_symbolvar(args_tup[0].ptr()))) { + return py::reinterpret_borrow(args_tup[0]); + } else { + return py::reinterpret_steal(PySequence_List(args_tup.ptr())); + } +} + +py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { + py::object obj = _expand_args(args); + py::list lis; + if (!is_tensor_or_symbolvar(obj.ptr()) && PySequence_Check(obj.ptr())) { + lis = py::reinterpret_steal(PySequence_List(obj.ptr())); + } else { + py::object np = getattr(obj, "numpy")(); + PyArrayObject* arr = (PyArrayObject*)np.ptr(); + PyObject* maybe_list = PyArray_ToList(arr); + if (PyList_Check(maybe_list)) { + lis = py::reinterpret_steal(maybe_list); + } + } if (fast_ndim(inp_hdl) == 0) { - if (args_tup.size() != 0) { + if (lis.size() != 0) { throw py::index_error( "transpose for scalar does not accept additional args"); } return getattr(inp_hdl, "to")(getattr(inp_hdl, "device")); } std::vector pattern; - if (!args_tup.size()) { + if (!lis.size()) { size_t ndim = getattr(inp_hdl, "ndim").cast(); for (size_t i = 0; i < ndim; ++i) { pattern.push_back(ndim - i - 1); } } else { - py::list lis; - if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) || - is_tensor_or_symbolvar(args_tup[0].ptr()))) { - lis = py::reinterpret_steal(PySequence_List(args_tup[0].ptr())); - } else { - lis = py::reinterpret_steal(PySequence_List(args_tup.ptr())); - } for (size_t i = 0; i < lis.size(); ++i) { if (PyLong_Check(lis[i].ptr())) { 
pattern.push_back(lis[i].cast()); @@ -844,6 +861,182 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { return ret[0]; } +std::tuple, bool> tuple2vector(py::object shape) { + std::vector shp; + if (!PyTuple_Check(shape.ptr())) { + return {shp, false}; + } + py::tuple tup = py::reinterpret_borrow(shape); + for (size_t i = 0; i < tup.size(); ++i) { + if (!PyLong_Check(tup[i].ptr())) { + return {shp, false}; + } else { + shp.push_back(tup[i].cast()); + } + } + return {shp, true}; +} + +bool enable_fastpath(py::handle inp) { + if (!TensorWrapper::try_cast(inp.ptr()) || + TransformationManager::get_instance() + .segments[TransformationManager::Segment::Trace] + .size() > 0 || + TransformationManager::get_instance() + .segments[TransformationManager::Segment::ModuleTrace] + .size() > 0) { + return false; + } + return true; +} + +py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) { + py::object shape_hdl = _expand_args(args); + bool auto_infer = false; + py::list lis; + py::list new_shape; + if (PyList_Check(shape_hdl.ptr()) || PyTuple_Check(shape_hdl.ptr())) { + lis = py::reinterpret_steal(PySequence_List(shape_hdl.ptr())); + for (size_t i = 0; i < lis.size(); ++i) { + if (lis[i].ptr() == Py_None) { + auto_infer = true; + size_t right = lis.size() - i; + py::object tshp = getattr(inp_hdl, "_tuple_shape"); + if (tshp.ptr() == Py_None) { + throw py::index_error("does not support `None` with unknown shape"); + } + py::tuple inp_shape = py::reinterpret_borrow(tshp); + if (inp_shape.size() >= right) { + if (enable_fastpath(inp_hdl)) { + lis[i] = inp_shape[inp_shape.size() - right]; + } + new_shape.append(inp_shape[inp_shape.size() - right]); + } else { + throw py::value_error("invalid broadcast shape"); + } + } else { + new_shape.append(lis[i]); + if (PyLong_Check(lis[i].ptr())) { + int32_t s = lis[i].cast(); + if (s < 0) { + throw py::value_error( + "expect shape[" + std::to_string(i) + + "] >= 0 or use `None` to auto infer, got " + + std::to_string(s)); + } + } + } + } + } + if (auto_infer) { + if (enable_fastpath(inp_hdl)) { + shape_hdl = py::reinterpret_borrow(lis); + } else { + py::tuple args = py::make_tuple(new_shape, inp_hdl); + py::dict kwargs; + kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32()); + kwargs["device"] = getattr(inp_hdl, "device"); + shape_hdl = py::reinterpret_steal( + PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr())); + } + } + py::object shape_tuple; + try { + shape_tuple = _make_shape_tuple(shape_hdl); + } catch (py::error_already_set& err) { + shape_tuple = py::reinterpret_borrow(shape_hdl); + } + auto [shape, fastpath] = tuple2vector(shape_tuple); + fastpath &= enable_fastpath(inp_hdl); + std::shared_ptr op; + std::vector p; + py::object shape_tensor; + if (fastpath) { + op = Broadcast::make(shape); + p.resize(2); + } else { + op = Broadcast::make(); + py::tuple args = py::make_tuple(shape_hdl, inp_hdl); + py::dict kwargs; + kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32()); + kwargs["device"] = getattr(inp_hdl, "device"); + shape_tensor = py::reinterpret_steal( + PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr())); + p.resize(3); + p[2] = shape_tensor.ptr(); + } + py::object Op = py::cast(op); + p[0] = Op.ptr(); + p[1] = inp_hdl.ptr(); + py::tuple ret = + py::reinterpret_steal(py_apply(NULL, p.data(), p.size())); + return ret[0]; +} + +py::object _reshape_cpp(py::handle inp_hdl, py::handle args) { + py::object shape_hdl = _expand_args(args); + py::object shape_tuple; + try { + shape_tuple = _make_shape_tuple(shape_hdl); + 
} catch (py::error_already_set& err) { + shape_tuple = py::reinterpret_borrow(shape_hdl); + } + int32_t unspec_axis = -1; + if (PyTuple_Check(shape_tuple.ptr())) { + py::tuple tup = py::reinterpret_borrow(shape_tuple); + for (size_t i = 0; i < tup.size(); ++i) { + py::object obj = py::reinterpret_borrow(tup[i]); + if (obj < py::int_(0)) { + if (obj.not_equal(py::int_(-1))) { + throw py::value_error( + "expect shape [" + std::to_string(i) + "] >= -1, got " + + repr(obj).cast()); + } + if (unspec_axis >= 0) { + throw py::value_error( + "multiple -1 in shape: " + std::to_string(unspec_axis) + + " & " + std::to_string(i)); + } + unspec_axis = i; + } + } + } + auto [shape, fastpath] = tuple2vector(shape_tuple); + fastpath &= enable_fastpath(inp_hdl); + std::shared_ptr op; + std::vector p; + py::object shape_tensor; + if (fastpath) { + if (unspec_axis >= 0) { + op = Reshape::make(unspec_axis, shape); + } else { + op = Reshape::make(::megdnn::param::OptionalAxisV1::INVALID_AXIS, shape); + } + p.resize(2); + } else { + shape.clear(); + if (unspec_axis >= 0) { + op = Reshape::make(unspec_axis, shape); + } else { + op = Reshape::make(); + } + py::tuple args = py::make_tuple(shape_hdl, inp_hdl); + py::dict kwargs; + kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32()); + kwargs["device"] = getattr(inp_hdl, "device"); + shape_tensor = py::reinterpret_steal( + PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr())); + p.resize(3); + p[2] = shape_tensor.ptr(); + } + py::object Op = py::cast(op); + p[0] = Op.ptr(); + p[1] = inp_hdl.ptr(); + py::tuple ret = + py::reinterpret_steal(py_apply(NULL, p.data(), p.size())); + return ret[0]; +} + PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { try { return _make_shape_tuple(py::handle(args[0])).release().ptr(); @@ -900,4 +1093,18 @@ PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs) { PYEXT17_TRANSLATE_EXC_RET(nullptr) } +PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs) { + try { + return _broadcast_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr(); + } + PYEXT17_TRANSLATE_EXC_RET(nullptr) +} + +PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) { + try { + return _reshape_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr(); + } + PYEXT17_TRANSLATE_EXC_RET(nullptr) +} + } // namespace mgb::imperative::python diff --git a/imperative/python/src/tensor_utils.h b/imperative/python/src/tensor_utils.h index 422ee461881b73379a27e9ca47e058baedb01163..4c721ff18d205703745edf77dbc457b26033274c 100644 --- a/imperative/python/src/tensor_utils.h +++ b/imperative/python/src/tensor_utils.h @@ -16,4 +16,8 @@ PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs); PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs); +PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs); + +PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs); + } // namespace mgb::imperative::python \ No newline at end of file diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 9470c35ee700cfe2f6617568524feabddc7edc23..ddbae8d4258c4e9d95ead8d4c6b3c28d996c8617 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -267,7 +267,7 @@ def test_broadcast_auto_infer(is_varnode): F.broadcast_to(xx, (None, 1, 2, 3)) F.broadcast_to(xx, (1, None, 2, 
3)) - t = tensor(2, dtype=np.int32) + t = make_tensor(2, network) F.broadcast_to(xx, (t, None, 2, 3)) diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp index e97ef497200ac26a213ae2d55e676a414268d6e2..ba9100fb26f5745a6b580e08ba64eed5fa8295d1 100644 --- a/imperative/src/impl/ops/broadcast.cpp +++ b/imperative/src/impl/ops/broadcast.cpp @@ -51,57 +51,75 @@ bool valid_broadcast(const TensorShape& src_shape, const TensorShape& tar_shape) std::tuple, bool> infer_output_attrs_fallible( const OpDef& def, const SmallVector& inputs) { + auto&& op = def.cast_final_safe(); size_t nr_inp = inputs.size(); - mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); auto&& src = inputs[0]; - auto&& tshp = inputs[1]; - TensorShape out_shape; - if (tshp.layout.ndim == 0 || tshp.value.empty()) { - out_shape.ndim = 0; - return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false}; - } - mgb_assert( - tshp.layout.ndim == 1, - "target shape of Broadcast expects ndim=1; got ndim=%lu actually", - tshp.layout.ndim); - - size_t target_ndim = tshp.layout.shape[0]; - out_shape.ndim = target_ndim; - auto* ptr = tshp.value.ptr(); - for (size_t i = 0; i < target_ndim; ++i) { - out_shape[i] = ptr[i]; + if (nr_inp == 1) { + out_shape.ndim = op.shape.size(); + for (size_t i = 0; i < out_shape.ndim; ++i) { + out_shape[i] = op.shape[i]; + } + } else { + auto&& tshp = inputs[1]; + if (tshp.layout.ndim == 0 || tshp.value.empty()) { + out_shape.ndim = 0; + return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, + false}; + } + mgb_assert( + tshp.layout.ndim == 1, + "target shape of Broadcast expects ndim=1; got ndim=%lu actually", + tshp.layout.ndim); + size_t target_ndim = tshp.layout.shape[0]; + out_shape.ndim = target_ndim; + auto* ptr = tshp.value.ptr(); + for (size_t i = 0; i < target_ndim; ++i) { + out_shape[i] = ptr[i]; + } } mgb_assert( valid_broadcast(src.layout, out_shape), "the input shape %s can not be broadcasted to target shape %s", src.layout.to_string().c_str(), out_shape.to_string().c_str()); - return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true}; } SmallVector apply_on_physical_tensor( const OpDef& def, const SmallVector& inputs, SmallVector& output_descs, const bool& validated) { - def.cast_final_safe(); + auto&& op = def.cast_final_safe(); size_t nr_inp = inputs.size(); - mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); + TensorShape tshp; auto&& src = inputs[0]; - auto&& tshp_nd = inputs[1]; auto slayout = src->layout(); - - TensorShape tshp; - cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu()); + if (nr_inp == 1) { + tshp.ndim = op.shape.size(); + for (size_t i = 0; i < tshp.ndim; ++i) { + tshp[i] = op.shape[i]; + } + } else { + auto&& tshp_nd = inputs[1]; + cg::copy_tensor_value_to_shape( + tshp, tshp_nd->get_value().proxy_to_default_cpu()); + } TensorLayout tlayout = slayout.broadcast(tshp); // memory forward return {Tensor::make(src->blob(), src->offset(), tlayout)}; } +SmallVector get_input_layout_constraint( + const OpDef& def, const SmallVector& inputs) { + SmallVector layout_checker(inputs.size()); + return layout_checker; +} + OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) .make_from_op_node(make_from_op_node) .apply_on_var_node(apply_on_var_node) .infer_output_attrs_fallible(infer_output_attrs_fallible) .apply_on_physical_tensor(apply_on_physical_tensor) + .get_input_layout_constraint(get_input_layout_constraint) 
.fallback(); } // namespace broadcast @@ -118,35 +136,49 @@ std::tuple, bool> infer_output_attrs_fallible( const OpDef& def, const SmallVector& inputs) { auto&& op = def.cast_final_safe(); size_t nr_inp = inputs.size(); - mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp); auto&& src = inputs[0]; - auto&& tshp = inputs[1]; TensorShape out_shape; - if (tshp.layout.ndim == 0 || tshp.value.empty()) { - out_shape.ndim = 0; - return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false}; - } - mgb_assert( - tshp.layout.ndim == 1, - "target shape of Reshape expects ndim=1; got ndim=%lu actually", - tshp.layout.ndim); - - if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) { - return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false}; - } - size_t target_ndim = tshp.layout.shape[0]; - out_shape.ndim = target_ndim; - auto* ptr = tshp.value.ptr(); - for (size_t i = 0; i < target_ndim; ++i) { - out_shape[i] = ptr[i]; - } - - if (src.layout.ndim == 0) { - return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false}; + if (nr_inp == 1) { + if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) { + return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, + false}; + } + out_shape.ndim = op.shape.size(); + for (size_t i = 0; i < out_shape.ndim; ++i) { + out_shape[i] = op.shape[i]; + } + if (src.layout.ndim == 0) { + return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, + false}; + } + } else { + auto&& tshp = inputs[1]; + if (tshp.layout.ndim == 0 || tshp.value.empty()) { + out_shape.ndim = 0; + return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, + false}; + } + mgb_assert( + tshp.layout.ndim == 1, + "target shape of Reshape expects ndim=1; got ndim=%lu actually", + tshp.layout.ndim); + if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) { + return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, + false}; + } + size_t target_ndim = tshp.layout.shape[0]; + out_shape.ndim = target_ndim; + auto* ptr = tshp.value.ptr(); + for (size_t i = 0; i < target_ndim; ++i) { + out_shape[i] = ptr[i]; + } + if (src.layout.ndim == 0) { + return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, + false}; + } } - if (op.axis != opr::Reshape::Param::INVALID_AXIS) { mgb_assert(out_shape[op.axis] == -1); out_shape[op.axis] = 1; @@ -167,19 +199,27 @@ std::tuple, bool> infer_output_attrs_fallible( SmallVector apply_on_physical_tensor( const OpDef& def, const SmallVector& inputs, SmallVector& output_descs, const bool& validated) { - auto&& op_def = def.cast_final_safe(); + auto&& op = def.cast_final_safe(); size_t nr_inp = inputs.size(); - mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp); auto&& src = inputs[0]; - auto&& tshp_nd = inputs[1]; auto slayout = src->layout(); - TensorShape tshp; - cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu()); - if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) { - mgb_assert(tshp[op_def.axis] == -1); - tshp[op_def.axis] = 1; - tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems(); + + if (nr_inp == 1) { + tshp.ndim = op.shape.size(); + for (size_t i = 0; i < tshp.ndim; ++i) { + tshp[i] = op.shape[i]; + } + } else { + auto&& tshp_nd = inputs[1]; + + cg::copy_tensor_value_to_shape( + tshp, tshp_nd->get_value().proxy_to_default_cpu()); + } + if (op.axis != opr::Reshape::Param::INVALID_AXIS) { + 
mgb_assert(tshp[op.axis] == -1); + tshp[op.axis] = 1; + tshp[op.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems(); } TensorLayout tlayout; mgb_assert(slayout.try_reshape(tlayout, tshp)); @@ -188,17 +228,24 @@ SmallVector apply_on_physical_tensor( SmallVector get_input_layout_constraint( const OpDef& def, const SmallVector& inputs) { - auto&& op_def = def.cast_final_safe(); + auto&& op = def.cast_final_safe(); SmallVector layout_checker(inputs.size()); layout_checker[0] = [&](const TensorLayout& layout) { TensorShape tshp; TensorLayout ret; - cg::copy_tensor_value_to_shape( - tshp, inputs[1]->get_value().proxy_to_default_cpu()); - if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) { - mgb_assert(tshp[op_def.axis] == -1); - tshp[op_def.axis] = 1; - tshp[op_def.axis] = layout.total_nr_elems() / tshp.total_nr_elems(); + if (inputs.size() == 1) { + tshp.ndim = op.shape.size(); + for (size_t i = 0; i < tshp.ndim; ++i) { + tshp[i] = op.shape[i]; + } + } else { + cg::copy_tensor_value_to_shape( + tshp, inputs[1]->get_value().proxy_to_default_cpu()); + } + if (op.axis != opr::Reshape::Param::INVALID_AXIS) { + mgb_assert(tshp[op.axis] == -1); + tshp[op.axis] = 1; + tshp[op.axis] = layout.total_nr_elems() / tshp.total_nr_elems(); } if (layout.try_reshape(ret, tshp)) { return true; diff --git a/imperative/src/impl/transformations/scalar.cpp b/imperative/src/impl/transformations/scalar.cpp index ea1ba6725b8a762748e2af2cf211407a6b56252a..544ee51290ad84221898fc0be3d11522bbb19070 100644 --- a/imperative/src/impl/transformations/scalar.cpp +++ b/imperative/src/impl/transformations/scalar.cpp @@ -243,8 +243,10 @@ ValueRefList get_var_shape_rule( ValueRefList reshape_rule( const Reshape& reshape, Span inputs, Span inputs_mask, const Type& scalar_type) { - mgb_assert(inputs.size() == 2); - bool is_scalar = is_scalar_shape(inputs[1]); + mgb_assert(inputs.size() == 1 || inputs.size() == 2); + size_t nr_inp = inputs.size(); + bool is_scalar = (nr_inp == 2 && is_scalar_shape(inputs[1])) || + (nr_inp == 1 && reshape.shape.size() == 0); if (is_scalar) { return {scalar_type.make(imperative::apply( reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; @@ -256,8 +258,10 @@ ValueRefList reshape_rule( ValueRefList broadcast_rule( const Broadcast& broadcast, Span inputs, Span inputs_mask, const Type& scalar_type) { - mgb_assert(inputs.size() == 2); - bool is_scalar = is_scalar_shape(inputs[1]); + mgb_assert(inputs.size() == 1 || inputs.size() == 2); + size_t nr_inp = inputs.size(); + bool is_scalar = (nr_inp == 2 && is_scalar_shape(inputs[1])) || + (nr_inp == 1 && broadcast.shape.size() == 0); if (is_scalar) { return {scalar_type.make(imperative::apply( broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 68432bb7f8e354f3ceb6263c2602057fa97f2841..b244e1a36fc7cf256b00b3080d86693ffe3144e6 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -250,7 +250,11 @@ def Concat: MgbHashableOp<"Concat", [AxisParam]> { ); } -def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]>; +def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]> { + let extraArguments = (ins + MgbArrayAttr:$shape + ); +} def Identity: MgbHashableOp<"Identity">; @@ -318,7 +322,11 @@ def Dimshuffle: MgbHashableOp<"Dimshuffle"> { let results = (outs AnyMemRef); } -def Reshape: MgbHashableOp<"Reshape", [OptionalAxisV1Param]>; +def Reshape: MgbHashableOp<"Reshape", [OptionalAxisV1Param]> 
{
+    let extraArguments = (ins
+        MgbArrayAttr<MgbI32Attr>:$shape
+    );
+}
 
 // TODO: merge Add/Remove Axis into AxisAddRemove as megbrain?
 def AddAxis: MgbHashableOp<"AddAxis"> {
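Usage sketch (illustration, not part of the patch): the Python-facing behavior exercised by this diff, assuming a MegEngine build with these changes applied; the output shapes follow from the auto-infer rule implemented in `_broadcast_cpp` above.

import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.arange(6, dtype="float32").reshape(2, 3))

# reshape now dispatches to reshape_cpp; a single -1 is inferred from the size.
y = x.reshape(-1)
print(y.shape)  # (6,)

# broadcast_to now dispatches to broadcast_cpp; `None` entries are filled from
# the corresponding (right-aligned) dims of the input shape.
z = F.broadcast_to(x, (4, None, 3))
print(z.shape)  # (4, 2, 3)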