diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py
index aec0eb41a06f13949272125419d1ca3c5b4314cc..57100bd56013de1aa6f59fe68bcddad806a45867 100755
--- a/imperative/python/megengine/functional/tensor.py
+++ b/imperative/python/megengine/functional/tensor.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 from functools import lru_cache
-from typing import Iterable, Optional, Sequence, Tuple, Union
+from typing import Iterable, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 
@@ -36,6 +36,7 @@ __all__ = [
     "full_like",
     "gather",
     "linspace",
+    "meshgrid",
     "ones",
     "ones_like",
     "repeat",
@@ -1205,3 +1206,49 @@ def cumsum(inp: Tensor, axis: int):
     assert isinstance(inp, Tensor), "input of cumsum must be type of Tensor"
     op = builtin.Cumsum(axis=axis, exclusive=False, reverse=False)
     return apply(op, inp)[0]
+
+
+def meshgrid(*inputs: Tensor, indexing: str = "xy") -> List[Tensor]:
+    r"""Returns coordinate matrices from coordinate vectors.
+
+    Args:
+        inputs: an arbitrary number of one-dimensional tensors representing grid
+            coordinates. Each input should have the same numeric data type.
+        indexing: Cartesian ``'xy'`` or matrix ``'ij'`` indexing of the output.
+            If zero or one one-dimensional tensor is provided, the ``indexing``
+            keyword has no effect and is ignored.
+
+    Returns:
+        out: list of N tensors, where N is the number of provided one-dimensional
+            input tensors. Each returned tensor has rank N. For N one-dimensional
+            tensors with lengths ``Ni = len(xi)``,
+
+            * if matrix indexing ``ij``, each returned tensor has shape ``(N1, N2, N3, ..., Nn)``;
+            * if Cartesian indexing ``xy``, each returned tensor has shape ``(N2, N1, N3, ..., Nn)``.
+
+            Accordingly, for the two-dimensional case with input one-dimensional tensors
+            of lengths ``M`` and ``N``, each returned tensor has shape ``(M, N)`` under
+            matrix indexing ``ij`` and shape ``(N, M)`` under Cartesian indexing ``xy``.
+
+            Similarly, for the three-dimensional case with input one-dimensional tensors
+            of lengths ``M``, ``N``, and ``P``, each returned tensor has shape ``(M, N, P)``
+            under matrix indexing ``ij`` and shape ``(N, M, P)`` under Cartesian indexing ``xy``.
+
+            Each returned tensor has the same data type as the input tensors.
+
+    Examples:
+        >>> nx, ny = (3, 2)
+        >>> x = F.linspace(0, 1, nx)
+        >>> y = F.linspace(0, 1, ny)
+        >>> xv, yv = F.meshgrid(x, y)
+        >>> xv
+        Tensor([[0.  0.5 1. ]
+         [0.  0.5 1. ]], device=xpux:0)
+        >>> yv
+        Tensor([[0. 0. 0.]
+         [1. 1. 1.]], device=xpux:0)
+    """
+    op = builtin.MeshGrid(indexing)
+    return apply(op, *inputs)
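
The docstring above fixes the output-shape contract for the two indexing modes. As a quick, self-contained illustration (a sketch written against the `F.meshgrid` API added in this file; the checks simply restate the documented `(M, N)` / `(N, M)` behavior):

    # Illustration only; relies on the meshgrid() wrapper introduced in this diff.
    import megengine.functional as F

    x = F.linspace(0, 1, 3)  # length M = 3
    y = F.linspace(0, 1, 2)  # length N = 2

    xv, yv = F.meshgrid(x, y, indexing="xy")  # Cartesian indexing: shape (N, M)
    assert xv.numpy().shape == (2, 3) and yv.numpy().shape == (2, 3)

    xi, yi = F.meshgrid(x, y, indexing="ij")  # matrix indexing: shape (M, N)
    assert xi.numpy().shape == (3, 2) and yi.numpy().shape == (3, 2)
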
diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp
index 8112ab1c1cc44cabd08f4b0e73ec77082207aac9..ff35cfed98679a3ad7c03faa5cc8b68e5c990ae9 100644
--- a/imperative/src/impl/ops/broadcast.cpp
+++ b/imperative/src/impl/ops/broadcast.cpp
@@ -1,13 +1,129 @@
+#include <numeric>
+#include "megbrain/graph/helper.h"
 #include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/opr/io.h"
 #include "megbrain/opr/tensor_manip.h"
-#include "megbrain/graph/helper.h"
-
 #include "../op_trait.h"
 
 namespace mgb {
 namespace imperative {
 
+namespace meshgrid {
+SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+    return SmallVector<VarNode::LayoutConstraintCallback>(inputs.size());
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    for (size_t i = 0; i < inputs.size() - 1; i++) {
+        mgb_assert(inputs[i].layout.dtype == inputs[i + 1].layout.dtype);
+        mgb_assert(inputs[i].comp_node == inputs[i + 1].comp_node);
+    }
+    auto&& op = def.cast_final_safe<MeshGrid>();
+    mgb_assert(op.indexing == "xy" || op.indexing == "ij");
+    bool success = true;
+    SmallVector<size_t> shp;
+    for (size_t i = 0; i < inputs.size(); i++) {
+        mgb_assert(inputs[i].layout.ndim <= 1);
+        if (inputs[i].layout.ndim == 0) {
+            success = false;
+        }
+        shp.push_back(inputs[i].layout.total_nr_elems());
+    }
+    if (op.indexing == "xy" and shp.size() >= 2) {
+        std::swap(shp[0], shp[1]);
+    }
+    TensorShape tshp(shp);
+    SmallVector<LogicalTensorDesc> descs;
+    for (size_t i = 0; i < inputs.size(); i++) {
+        if (success) {
+            descs.push_back(
+                    {TensorLayout(tshp, inputs[0].layout.dtype), inputs[0].comp_node});
+        } else {
+            descs.push_back(
+                    {TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node});
+        }
+    }
+    return {descs, success};
+}
+VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = def.cast_final_safe<MeshGrid>();
+    std::vector<size_t> indexs(inputs.size());
+    std::iota(indexs.begin(), indexs.end(), 0);
+    auto cn = inputs[0]->comp_node();
+    auto graph = inputs[0]->owner_graph();
+    if (op.indexing == "xy") {
+        if (indexs.size() >= 2) {
+            std::swap(indexs[0], indexs[1]);
+        }
+    } else {
+        mgb_assert(op.indexing == "ij", "meshgrid only support \"ij\" or \"xy\"");
+    }
+    VarNodeArray shps;
+    for (size_t ind = 0; ind < inputs.size(); ind++) {
+        auto&& inp = inputs[indexs[ind]];
+        shps.push_back(opr::GetVarShape::make(inp).node());
+    }
+    VarNode* tshp = opr::Concat::make(shps, 0, cn).node();
+    VarNodeArray results;
+    auto t_ndim = inputs.size();
+    for (size_t ind = 0; ind < inputs.size(); ind++) {
+        auto axis = indexs[ind];
+        HostTensorND hv = HostTensorND(cn, {t_ndim}, dtype::Int32());
+        auto* ptr = hv.ptr<dt_int32>();
+        std::fill_n(ptr, t_ndim, 1);
+        ptr[axis] = -1;
+        auto shp = opr::ImmutableTensor::make(*graph, hv, cn).node();
+        auto tmp = opr::Reshape::make(inputs[ind], shp, axis).node();
+        results.push_back(opr::Broadcast::make(tmp, tshp).node());
+    }
+    return results;
+}
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    auto&& op = def.cast_final_safe<MeshGrid>();
+    TensorShape tshp;
+    TensorShape view_shp;
+    tshp.ndim = inputs.size();
+    view_shp.ndim = inputs.size();
+    std::vector<size_t> indexs(inputs.size());
+    std::iota(indexs.begin(), indexs.end(), 0);
+
+    if (op.indexing == "xy") {
+        if (indexs.size() >= 2) {
+            std::swap(indexs[0], indexs[1]);
+        }
+    } else {
+        mgb_assert(op.indexing == "ij", "meshgrid only support \"ij\" or \"xy\"");
+    }
+    for (size_t ind = 0; ind < inputs.size(); ind++) {
+        auto&& inp = inputs[indexs[ind]];
+        mgb_assert(inp->layout().ndim <= 1);
+        tshp[ind] = inp->layout().total_nr_elems();
+        view_shp[ind] = 1;
+    }
+    SmallVector<TensorPtr> grids;
+    for (size_t i = 0; i < inputs.size(); i++) {
+        auto&& src = inputs[i];
+        TensorLayout layout;
+        view_shp[indexs[i]] = src->layout().total_nr_elems();
+        mgb_assert(src->layout().try_reshape(layout, view_shp));
+        layout = layout.broadcast(tshp);
+        view_shp[indexs[i]] = 1;
+        grids.push_back(Tensor::make(src->blob(), src->offset(), layout));
+    }
+    return grids;
+}
+OP_TRAIT_REG(MeshGrid, MeshGrid)
+        .apply_on_var_node(apply_on_var_node)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .get_input_layout_constraint(get_input_layout_constraint)
+        .fallback();
+}  // namespace meshgrid
+
 namespace broadcast {
 
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
@@ -211,7 +327,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
             tshp, tshp_nd->get_value().proxy_to_default_cpu());
     }
     if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
-        mgb_assert(tshp[op.axis] == -1);
         tshp[op.axis] = 1;
         tshp[op.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
     }
@@ -237,7 +352,6 @@ SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
             tshp, inputs[1]->get_value().proxy_to_default_cpu());
     }
     if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
-        mgb_assert(tshp[op.axis] == -1);
         tshp[op.axis] = 1;
         tshp[op.axis] = layout.total_nr_elems() / tshp.total_nr_elems();
     }
@@ -250,7 +364,7 @@ SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
     return layout_checker;
 }
 
-OP_TRAIT_REG(Reshape, Reshape)
+OP_TRAIT_REG(Reshape, Reshape, opr::Reshape)
         .apply_on_var_node(apply_on_var_node)
         .infer_output_attrs_fallible(infer_output_attrs_fallible)
         .apply_on_physical_tensor(apply_on_physical_tensor)
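
Both `apply_on_var_node` and `apply_on_physical_tensor` above use the same construction: input `i` is reshaped so that every axis has length one except its own axis, then broadcast to the combined shape, and `"xy"` indexing swaps the first two axes; the physical-tensor path returns broadcast views over the original blobs instead of copying. A NumPy sketch of that construction (illustration only, not MegEngine code; `meshgrid_sketch` is a name invented for this note):

    import numpy as np

    def meshgrid_sketch(*vectors, indexing="xy"):
        # axis order: identity for "ij"; "xy" swaps the first two axes
        index = list(range(len(vectors)))
        if indexing == "xy" and len(index) >= 2:
            index[0], index[1] = index[1], index[0]
        # combined output shape, e.g. (N, M) for "xy" with lengths M and N
        target = tuple(vectors[i].shape[0] for i in index)
        outputs = []
        for ind, vec in enumerate(vectors):
            view = [1] * len(vectors)
            view[index[ind]] = -1  # keep vec's data along its own axis
            outputs.append(np.broadcast_to(vec.reshape(view), target))  # view, no copy
        return outputs

    # agrees with numpy.meshgrid in both modes
    x, y = np.arange(3.0), np.arange(2.0)
    for mode in ("xy", "ij"):
        for a, b in zip(meshgrid_sketch(x, y, indexing=mode), np.meshgrid(x, y, indexing=mode)):
            assert np.array_equal(a, b)
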
diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt
index d8d21eb7c9d1a91f0cfcdfe2c62611af37d31512..35080e895349ed7f138157c87e8851aeb720f3f1 100644
--- a/imperative/tablegen/generated/hash.txt
+++ b/imperative/tablegen/generated/hash.txt
@@ -1,7 +1,7 @@
 905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py
-e35e13523f43b7bea4034a0bf75937b7 ../../src/core/include/megbrain/ir/ops.td
-240dccd6f8d42cadfd08c6ca90fe61b1 generated/opdef.h.inl
-a79a4058ff18ffd9593ee5db3deef6c4 generated/opdef.cpp.inl
-83c179ee7416824fbfab978a097cd4d3 generated/opdef.py.inl
-86f70b1052331130f5e4c0ca53e68423 generated/opdef.cpy.inl
+40708c56b1f05fdb7d06cc097a300330 ../../src/core/include/megbrain/ir/ops.td
+9f3af118c7fe8d0c9db433825d5ad77b generated/opdef.h.inl
+4041e44a8ba3cca3b3affa1ed9ed44a2 generated/opdef.cpp.inl
+319e1d170c989fe793a4e9c45decefc4 generated/opdef.py.inl
+26a18a7593566128ecce76e8f74dcc5d generated/opdef.cpy.inl
 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl
index a8ef16ababe1772bf2687c37723fe3a5bfa46ff3..08449b9daabfe8e5b9ad2c416529c1c109d0acd2 100644
--- a/imperative/tablegen/generated/opdef.cpp.inl
+++ b/imperative/tablegen/generated/opdef.cpp.inl
@@ -4672,6 +4672,43 @@ OP_TRAIT_REG(MatrixMul, MatrixMul)
     .props(MatrixMul_props_impl)
     .make_name(MatrixMul_make_name_impl);
 
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(MeshGrid);
+
+namespace {
+size_t MeshGrid_hash_impl(const OpDef& def_) {
+    auto&& op_ = def_.cast_final_safe<MeshGrid>();
+    static_cast<void>(op_);
+    size_t val = mgb::hash(op_.dyn_typeinfo());
+    val = mgb::hash_pair_combine(val, mgb::hash(op_.indexing));
+    return val;
+}
+bool MeshGrid_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
+    auto &&a_ = lhs_.cast_final_safe<MeshGrid>(),
+         &&b_ = rhs_.cast_final_safe<MeshGrid>();
+    static_cast<void>(a_);
+    static_cast<void>(b_);
+    if (a_.indexing != b_.indexing) return false;
+    return true;
+}
+std::vector<std::pair<const char*, std::string>> MeshGrid_props_impl(const OpDef& def_) {
+    auto&& op_ = def_.cast_final_safe<MeshGrid>();
+    static_cast<void>(op_);
+    std::vector<std::pair<const char*, std::string>> props_;
+    props_.emplace_back("indexing", op_.indexing);
+    return props_;
+}
+std::string MeshGrid_make_name_impl(const OpDef& def_) {
+    auto&& op_ = def_.cast_final_safe<MeshGrid>();
+    static_cast<void>(op_);
+    return "MeshGrid";
+}
+} // anonymous namespace
+OP_TRAIT_REG(MeshGrid, MeshGrid)
+    .hash(MeshGrid_hash_impl)
+    .is_same_st(MeshGrid_is_same_st_impl)
+    .props(MeshGrid_props_impl)
+    .make_name(MeshGrid_make_name_impl);
+
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(MeshIndexing);
 
 namespace {
diff --git a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl
index fa1fa9fcedc5f35d292bdb137b8413fd4a6c6cc0..070dc1f2b655daa5755f455732a04f3b3fff3926 100644
--- a/imperative/tablegen/generated/opdef.cpy.inl
+++ b/imperative/tablegen/generated/opdef.cpy.inl
@@ -12467,6 +12467,95 @@ void _init_py_MatrixMul(py::module m) {
     mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MatrixMul::typeinfo(), &py_type).second);
 }
 
+PyOpDefBegin(MeshGrid) // {
+    static PyGetSetDef py_getsetters[];
+    static PyMethodDef tp_methods[];
+
+    static PyObject* getstate(PyObject* self, PyObject*) {
+        auto& opdef = reinterpret_cast<PyOp(MeshGrid)*>(self)->inst();
+        static_cast<void>(opdef);
+        std::unordered_map<std::string, py::object> state {
+
+            {"indexing", serialization::dump(opdef.indexing)}
+        };
+        return py::cast(state).release().ptr();
+    }
+    static PyObject* setstate(PyObject* self, PyObject* args) {
+        PyObject* dict = PyTuple_GetItem(args, 0);
+        if (!dict) return NULL;
+        auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
+        auto& opdef = reinterpret_cast<PyOp(MeshGrid)*>(self)->inst();
+        static_cast<void>(opdef);
+
+        {
+        auto&& iter = state.find("indexing");
+        if (iter != state.end()) {
+            opdef.indexing = serialization::load<decltype(opdef.indexing)>(iter->second);
+        }
+        }
+        Py_RETURN_NONE;
+    }
+    static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
+// };
+PyOpDefEnd(MeshGrid)
+
+int PyOp(MeshGrid)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
+    static const char* kwlist[] = {"indexing", "scope", NULL};
+    PyObject *indexing = NULL, *scope = NULL;
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO", const_cast<char**>(kwlist), &indexing, &scope))
+    return -1;
+
+    if (indexing) {
+        try {
+            // TODO: remove this guard which is used for pybind11 implicit conversion
+            py::detail::loader_life_support guard{};
+            reinterpret_cast<PyOp(MeshGrid)*>(self)->inst().indexing =
+                    py::cast<decltype(MeshGrid::indexing)>(py::handle(indexing));
+        } CATCH_ALL(-1)
+    }
+
+    if (scope) {
+        try {
+            reinterpret_cast<PyOp(OpDef)*>(self)->op
+                ->set_scope(py::cast<std::string>(py::handle(scope)));
+        } CATCH_ALL(-1)
+    }
+
+    return 0;
+}
+
+PyGetSetDef PyOp(MeshGrid)::py_getsetters[] = {
+    {const_cast<char*>("indexing"), py_get_generic(MeshGrid, indexing), py_set_generic(MeshGrid, indexing), const_cast<char*>("indexing"), NULL},
+    {NULL} /* Sentinel */
+};
+
+    PyMethodDef PyOp(MeshGrid)::tp_methods[] = {
+        {const_cast<char*>("__getstate__"), PyOp(MeshGrid)::getstate, METH_NOARGS, "MeshGrid getstate"},
+        {const_cast<char*>("__setstate__"), PyOp(MeshGrid)::setstate, METH_VARARGS, "MeshGrid setstate"},
+        {NULL} /* Sentinel */
+    };
+
+void _init_py_MeshGrid(py::module m) {
+    using py_op = PyOp(MeshGrid);
+    auto& py_type = PyOpType(MeshGrid);
+    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
+    py_type.tp_name = "megengine.core._imperative_rt.ops.MeshGrid";
+    py_type.tp_basicsize = sizeof(PyOp(MeshGrid));
+    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
+    py_type.tp_doc = "MeshGrid";
+    py_type.tp_base = &PyOpType(OpDef);
+    py_type.tp_dealloc = py_dealloc_generic;
+    py_type.tp_new = py_new_generic;
+    py_type.tp_init = py_op::py_init;
+    py_type.tp_methods = py_op::tp_methods;
+    py_type.tp_getset = py_op::py_getsetters;
+    mgb_assert(PyType_Ready(&py_type) >= 0);
+
+    PyType_Modified(&py_type);
+    m.add_object("MeshGrid", reinterpret_cast<PyObject*>(&py_type));
+    mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(MeshGrid::typeinfo(), &py_type).second);
+}
+
 PyOpDefBegin(MeshIndexing) // {
     static PyGetSetDef py_getsetters[];
     static PyMethodDef tp_methods[];
@@ -18594,6 +18683,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
     _init_py_MagicMindRuntime(m); \
     _init_py_MatrixInverse(m); \
     _init_py_MatrixMul(m); \
+    _init_py_MeshGrid(m); \
     _init_py_MeshIndexing(m); \
    _init_py_NMSKeep(m); \
     _init_py_NvOf(m); \
diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl
index 47d78f116720812fdd9705b3b8489641c48266e7..294a6b299c493fadf5d92e3e0765ba5b052a926f 100644
--- a/imperative/tablegen/generated/opdef.h.inl
+++ b/imperative/tablegen/generated/opdef.h.inl
@@ -1262,6 +1262,15 @@ public:
     }
 };
 
+class MeshGrid : public OpDefImplBase<MeshGrid> {
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+
+public:
+    std::string indexing;
+    MeshGrid() = default;
+    MeshGrid(std::string indexing_, std::string scope_ = {}): indexing(indexing_) { set_scope(scope_); }
+};
+
 class MeshIndexing : public OpDefImplBase<MeshIndexing> {
     MGB_DYN_TYPE_OBJ_FINAL_DECL;
 
diff --git a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl
index 696a6e0da49c6b5b6f74549e5b0c61b8dcc4fd73..781157005592ae56d983bbde7ee208aef38890ef 100644
--- a/imperative/tablegen/generated/opdef.py.inl
+++ b/imperative/tablegen/generated/opdef.py.inl
@@ -1365,6 +1365,13 @@ MatrixMulInst
     .def_readwrite("dimA", &MatrixMul::dimA)
     .def_readwrite("dimB", &MatrixMul::dimB);
 
+py::class_<MeshGrid, std::shared_ptr<MeshGrid>, OpDef> MeshGridInst(m, "MeshGrid");
+
+MeshGridInst
+    .def(py::init<std::string, std::string>(), py::arg("indexing"), py::arg("scope") = {})
+    .def(py::init<>())
+    .def_readwrite("indexing", &MeshGrid::indexing);
+
 py::class_<MeshIndexing, std::shared_ptr<MeshIndexing>, OpDef> MeshIndexingInst(m, "MeshIndexing");
 
 MeshIndexingInst
diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td
index a7206006c13c1f9f074686d16a73744aa4b445ef..ac6968076ee836ad206a8abde286cd1e4f02c111 100644
--- a/src/core/include/megbrain/ir/ops.td
+++ b/src/core/include/megbrain/ir/ops.td
@@ -515,4 +515,9 @@ def Dropout: MgbHashableOp<"Dropout", [DropoutParam]> {
   let cmpFunction = [{return $0.handle == $1.handle && $0.drop_prob == $1.drop_prob;}];
 }
 
+def MeshGrid: MgbHashableOp<"MeshGrid"> {
+  let extraArguments = (ins
+    MgbStringAttr:$indexing
+  );
+}
 #endif // MGB_OPS
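
For reference, the new op can also be driven directly through the binding registered above. The import paths below follow the `tp_name` in `opdef.cpy.inl` and the `builtin`/`apply` helpers already used by `functional/tensor.py`; they are assumptions of this note, not something introduced by the patch:

    from megengine import tensor
    from megengine.core._imperative_rt.core2 import apply  # assumed location of apply
    from megengine.core.ops import builtin                 # assumed location of builtin

    op = builtin.MeshGrid(indexing="ij")  # keyword arg per the "indexing"/"scope" kwlist above
    assert op.indexing == "ij"            # exposed via def_readwrite("indexing", ...)
    xv, yv = apply(op, tensor([0.0, 0.5, 1.0]), tensor([0.0, 1.0]))
    assert xv.numpy().shape == (3, 2)     # matrix indexing: (M, N) with M=3, N=2
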