Commit 1709b394 authored by Megvii Engine Team

perf(mge/functional): speed up Broadcast and Reshape

GitOrigin-RevId: a72f5460b6966f815449fc44ad6e5ac1e0ec021c
Parent 0f736a0a
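
For context, a minimal usage sketch (assuming a standard MegEngine install) of the Python entry points this commit reroutes to the new C++ helpers `reshape_cpp` and `broadcast_cpp`; the concrete tensor values are illustrative only, not taken from the diff.

```python
import numpy as np

import megengine.functional as F
from megengine import tensor

x = tensor(np.arange(6, dtype="float32"))

# Tensor.reshape / flatten now dispatch to reshape_cpp;
# a single -1 entry is still inferred from the remaining elements.
y = x.reshape(2, -1)   # shape: (2, 3)
z = y.flatten()        # shape: (6,), same as reshape(-1)

# broadcast_to now dispatches to broadcast_cpp.
b = F.broadcast_to(y, (4, 2, 3))
print(b.shape)         # (4, 2, 3)
```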
......@@ -15,9 +15,15 @@ import numpy as np
from .. import _config
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
from .._imperative_rt.core2 import (
SymbolVar,
Tensor,
apply,
broadcast_cpp,
dtype_promotion,
)
from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
from .._imperative_rt.core2 import squeeze_cpp, transpose_cpp
from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp
from ..ops import builtin
from . import amp
from .indexing import getitem, setitem
......@@ -331,70 +337,6 @@ def _matmul(
return result
def _broadcast(inp, shape):
auto_infer = False
if isinstance(shape, (list, tuple)):
shape_tuple = list(shape)
for i, s in enumerate(shape_tuple):
if isinstance(s, type(None)):
if s is None:
right = i - len(shape_tuple)
inp_shape = inp._tuple_shape
if len(inp_shape) + right >= 0:
shape_tuple[right] = list(inp_shape)[right]
auto_infer = True
continue
else:
raise ValueError("invalided Broadcast shape")
else:
raise ValueError(
"expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format(
i, s
)
)
if s < 0:
raise ValueError(
"expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format(
i, s
)
)
if auto_infer:
shape = tuple(shape_tuple)
try:
shape_tuple = make_shape_tuple(shape)
except ValueError:
shape_tuple = shape
shape = astensor1d(shape_tuple, inp, dtype="int32", device=inp.device)
(result,) = apply(builtin.Broadcast(), inp, shape)
return result
def _reshape(x, shape):
unspec_axis = None
try:
shape_tuple = make_shape_tuple(shape)
except ValueError:
pass
else:
# XXX: assume unspec_axis is not changed in trace
for i, s in enumerate(shape_tuple):
if s < 0:
if s != -1:
raise ValueError("expect shape[{}] >= -1, got {}".format(i, s))
if unspec_axis is not None:
raise ValueError(
"multiple -1 in shape: {} & {}".format(unspec_axis, i)
)
unspec_axis = i
shape = astensor1d(shape, x, dtype="int32", device=x.device)
if unspec_axis is None:
op = builtin.Reshape()
else:
op = builtin.Reshape(axis=unspec_axis)
(x,) = apply(op, x, shape)
return x
def _unary_elwise(mode):
def f(self):
return _elwise(self, mode=mode)
......@@ -667,11 +609,11 @@ class ArrayMethodMixin(abc.ABC):
def reshape(self, *args):
r"""See :func:`~.reshape`."""
return _reshape(self, _expand_args(args))
return reshape_cpp(self, args)
# FIXME: remove this method
def _broadcast(self, *args):
return _broadcast(self, _expand_args(args))
return broadcast_cpp(self, args)
def transpose(self, *args):
r"""See :func:`~.transpose`."""
......@@ -679,7 +621,7 @@ class ArrayMethodMixin(abc.ABC):
def flatten(self):
r"""See :func:`~.flatten`."""
return self.reshape(-1)
return reshape_cpp(self, (-1,))
def sum(self, axis=None, keepdims: bool = False):
r"""Returns the sum of each row of the input tensor in the given dimension ``axis``.
......
......@@ -15,6 +15,7 @@ from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import (
SymbolVar,
apply,
broadcast_cpp,
dtype_promotion,
expand_dims_cpp,
split_cpp,
......@@ -24,7 +25,6 @@ from ..core._wrap import as_device
from ..core.ops import builtin
from ..core.ops.builtin import Copy, Identity
from ..core.ops.special import Const
from ..core.tensor.array_method import _broadcast
from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn
from ..device import get_default_device
from ..tensor import Tensor
......@@ -360,7 +360,7 @@ def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
[[0. 1. 2.]
[0. 1. 2.]]
"""
return _broadcast(inp, shape)
return broadcast_cpp(inp, shape)
def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
......
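
The `None` auto-inference path kept by `_broadcast_cpp` (and exercised by the updated test further down) can be pictured at the Python level as follows; the shapes are illustrative, a sketch rather than output from this diff.

```python
import numpy as np

import megengine.functional as F
from megengine import tensor

x = tensor(np.arange(6, dtype="float32").reshape(2, 3))

# Each `None` is filled from the corresponding trailing dim of the input,
# so (2, 3) is broadcast to (4, 2, 3) here.
y = F.broadcast_to(x, (4, None, None))
print(y.shape)  # (4, 2, 3)
```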
......@@ -135,23 +135,24 @@ std::optional<ValueRefList> elemwise_grad_rule(
std::optional<ValueRefList> reshape_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
mgb_assert(inputs.size() == 2);
mgb_assert(inputs.size() == 1 || inputs.size() == 2);
size_t nr_inp = inputs.size();
std::array<ValueRef, 2> input_shapes;
for (size_t i = 0; i < 2; ++i) {
for (size_t i = 0; i < nr_inp; ++i) {
if (inputs_require_grad[i]) {
input_shapes[i] = get_shape(inputs[i]);
}
}
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
SmallVector<ValueRef> ret(2);
SmallVector<ValueRef> ret(nr_inp);
if (!grad) {
return ret;
}
for (size_t i = 0; i < 2; ++i) {
for (size_t i = 0; i < nr_inp; ++i) {
if (shapes[i]) {
ret[i] = reshape_to(grad, shapes[i]);
}
......@@ -162,6 +163,37 @@ std::optional<ValueRefList> reshape_grad_rule(
return imperative::apply(ApplyOp(op), inputs);
}
std::optional<ValueRefList> broadcast_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
mgb_assert(inputs.size() == 1 || inputs.size() == 2);
size_t nr_inp = inputs.size();
std::array<ValueRef, 2> input_shapes;
for (size_t i = 0; i < nr_inp; ++i) {
if (inputs_require_grad[i]) {
input_shapes[i] = get_shape(inputs[i]);
}
}
auto maker = CustomGradMaker(backward, inputs.size());
maker.output_size(1).output_captured(0, false);
maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
SmallVector<ValueRef> ret(nr_inp);
if (!grad) {
return ret;
}
for (size_t i = 0; i < nr_inp; ++i) {
if (shapes[i]) {
ret[i] = reduce_to(grad, shapes[i]);
}
}
return ret;
});
maker.finalize();
return imperative::apply(ApplyOp(op), inputs);
}
std::optional<ValueRefList> subtensor_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
......@@ -330,6 +362,7 @@ struct Init {
Init() {
CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule);
CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule);
CustomBackward::register_grad_rule(Broadcast::typeinfo(), broadcast_grad_rule);
CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule);
CustomBackward::register_grad_rule(
IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule);
......
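
Conceptually, the new `broadcast_grad_rule` mirrors `reshape_grad_rule` but sends the incoming gradient through `reduce_to`, collapsing it back to the input shape. A hypothetical NumPy sketch of that reduction step (the function name and shapes below are illustrative, not the C++ `reduce_to` API):

```python
import numpy as np

def reduce_to(grad: np.ndarray, shape: tuple) -> np.ndarray:
    """Sum `grad` over the axes that a broadcast from `shape` added or
    expanded, so the result has exactly `shape` again."""
    # Sum away leading axes introduced by the broadcast.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    # Sum over axes that were size 1 before the broadcast.
    for axis, dim in enumerate(shape):
        if dim == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad.reshape(shape)

x_shape = (2, 1, 3)
grad_out = np.ones((4, 2, 5, 3), dtype="float32")  # grad w.r.t. the broadcast output
grad_in = reduce_to(grad_out, x_shape)
assert grad_in.shape == x_shape
```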
......@@ -637,6 +637,8 @@ WRAP_FUNC_PY35(split_cpp);
WRAP_FUNC_PY35(expand_dims_cpp);
WRAP_FUNC_PY35(squeeze_cpp);
WRAP_FUNC_PY35(transpose_cpp);
WRAP_FUNC_PY35(broadcast_cpp);
WRAP_FUNC_PY35(reshape_cpp);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
......@@ -773,6 +775,8 @@ void init_tensor(py::module m) {
MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp),
MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
{nullptr, nullptr, 0, nullptr}};
for (auto&& def : method_defs) {
if (def.ml_meth != nullptr) {
......
......@@ -800,29 +800,46 @@ size_t fast_ndim(py::handle tensor) {
return getattr(tensor, "ndim").cast<size_t>();
}
py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
py::object _expand_args(py::handle args) {
if (!PyTuple_Check(args.ptr())) {
return py::reinterpret_borrow<py::object>(args);
}
py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr());
if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) ||
is_tensor_or_symbolvar(args_tup[0].ptr()))) {
return py::reinterpret_borrow<py::object>(args_tup[0]);
} else {
return py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr()));
}
}
py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
py::object obj = _expand_args(args);
py::list lis;
if (!is_tensor_or_symbolvar(obj.ptr()) && PySequence_Check(obj.ptr())) {
lis = py::reinterpret_steal<py::list>(PySequence_List(obj.ptr()));
} else {
py::object np = getattr(obj, "numpy")();
PyArrayObject* arr = (PyArrayObject*)np.ptr();
PyObject* maybe_list = PyArray_ToList(arr);
if (PyList_Check(maybe_list)) {
lis = py::reinterpret_steal<py::list>(maybe_list);
}
}
if (fast_ndim(inp_hdl) == 0) {
if (args_tup.size() != 0) {
if (lis.size() != 0) {
throw py::index_error(
"transpose for scalar does not accept additional args");
}
return getattr(inp_hdl, "to")(getattr(inp_hdl, "device"));
}
std::vector<int32_t> pattern;
if (!args_tup.size()) {
if (!lis.size()) {
size_t ndim = getattr(inp_hdl, "ndim").cast<size_t>();
for (size_t i = 0; i < ndim; ++i) {
pattern.push_back(ndim - i - 1);
}
} else {
py::list lis;
if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) ||
is_tensor_or_symbolvar(args_tup[0].ptr()))) {
lis = py::reinterpret_steal<py::list>(PySequence_List(args_tup[0].ptr()));
} else {
lis = py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr()));
}
for (size_t i = 0; i < lis.size(); ++i) {
if (PyLong_Check(lis[i].ptr())) {
pattern.push_back(lis[i].cast<int32_t>());
......@@ -844,6 +861,182 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
return ret[0];
}
std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) {
std::vector<int32_t> shp;
if (!PyTuple_Check(shape.ptr())) {
return {shp, false};
}
py::tuple tup = py::reinterpret_borrow<py::tuple>(shape);
for (size_t i = 0; i < tup.size(); ++i) {
if (!PyLong_Check(tup[i].ptr())) {
return {shp, false};
} else {
shp.push_back(tup[i].cast<int32_t>());
}
}
return {shp, true};
}
bool enable_fastpath(py::handle inp) {
if (!TensorWrapper::try_cast(inp.ptr()) ||
TransformationManager::get_instance()
.segments[TransformationManager::Segment::Trace]
.size() > 0 ||
TransformationManager::get_instance()
.segments[TransformationManager::Segment::ModuleTrace]
.size() > 0) {
return false;
}
return true;
}
py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) {
py::object shape_hdl = _expand_args(args);
bool auto_infer = false;
py::list lis;
py::list new_shape;
if (PyList_Check(shape_hdl.ptr()) || PyTuple_Check(shape_hdl.ptr())) {
lis = py::reinterpret_steal<py::list>(PySequence_List(shape_hdl.ptr()));
for (size_t i = 0; i < lis.size(); ++i) {
if (lis[i].ptr() == Py_None) {
auto_infer = true;
size_t right = lis.size() - i;
py::object tshp = getattr(inp_hdl, "_tuple_shape");
if (tshp.ptr() == Py_None) {
throw py::index_error("does not support `None` with unknown shape");
}
py::tuple inp_shape = py::reinterpret_borrow<py::tuple>(tshp);
if (inp_shape.size() >= right) {
if (enable_fastpath(inp_hdl)) {
lis[i] = inp_shape[inp_shape.size() - right];
}
new_shape.append(inp_shape[inp_shape.size() - right]);
} else {
throw py::value_error("invalid broadcast shape");
}
} else {
new_shape.append(lis[i]);
if (PyLong_Check(lis[i].ptr())) {
int32_t s = lis[i].cast<int32_t>();
if (s < 0) {
throw py::value_error(
"expect shape[" + std::to_string(i) +
"] >= 0 or use `None` to auto infer, got " +
std::to_string(s));
}
}
}
}
}
if (auto_infer) {
if (enable_fastpath(inp_hdl)) {
shape_hdl = py::reinterpret_borrow<py::tuple>(lis);
} else {
py::tuple args = py::make_tuple(new_shape, inp_hdl);
py::dict kwargs;
kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32());
kwargs["device"] = getattr(inp_hdl, "device");
shape_hdl = py::reinterpret_steal<py::object>(
PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr()));
}
}
py::object shape_tuple;
try {
shape_tuple = _make_shape_tuple(shape_hdl);
} catch (py::error_already_set& err) {
shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
}
auto [shape, fastpath] = tuple2vector(shape_tuple);
fastpath &= enable_fastpath(inp_hdl);
std::shared_ptr<OpDef> op;
std::vector<PyObject*> p;
py::object shape_tensor;
if (fastpath) {
op = Broadcast::make(shape);
p.resize(2);
} else {
op = Broadcast::make();
py::tuple args = py::make_tuple(shape_hdl, inp_hdl);
py::dict kwargs;
kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32());
kwargs["device"] = getattr(inp_hdl, "device");
shape_tensor = py::reinterpret_steal<py::object>(
PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr()));
p.resize(3);
p[2] = shape_tensor.ptr();
}
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = inp_hdl.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
}
py::object _reshape_cpp(py::handle inp_hdl, py::handle args) {
py::object shape_hdl = _expand_args(args);
py::object shape_tuple;
try {
shape_tuple = _make_shape_tuple(shape_hdl);
} catch (py::error_already_set& err) {
shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl);
}
int32_t unspec_axis = -1;
if (PyTuple_Check(shape_tuple.ptr())) {
py::tuple tup = py::reinterpret_borrow<py::tuple>(shape_tuple);
for (size_t i = 0; i < tup.size(); ++i) {
py::object obj = py::reinterpret_borrow<py::object>(tup[i]);
if (obj < py::int_(0)) {
if (obj.not_equal(py::int_(-1))) {
throw py::value_error(
"expect shape [" + std::to_string(i) + "] >= -1, got " +
repr(obj).cast<std::string>());
}
if (unspec_axis >= 0) {
throw py::value_error(
"multiple -1 in shape: " + std::to_string(unspec_axis) +
" & " + std::to_string(i));
}
unspec_axis = i;
}
}
}
auto [shape, fastpath] = tuple2vector(shape_tuple);
fastpath &= enable_fastpath(inp_hdl);
std::shared_ptr<OpDef> op;
std::vector<PyObject*> p;
py::object shape_tensor;
if (fastpath) {
if (unspec_axis >= 0) {
op = Reshape::make(unspec_axis, shape);
} else {
op = Reshape::make(::megdnn::param::OptionalAxisV1::INVALID_AXIS, shape);
}
p.resize(2);
} else {
shape.clear();
if (unspec_axis >= 0) {
op = Reshape::make(unspec_axis, shape);
} else {
op = Reshape::make();
}
py::tuple args = py::make_tuple(shape_hdl, inp_hdl);
py::dict kwargs;
kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32());
kwargs["device"] = getattr(inp_hdl, "device");
shape_tensor = py::reinterpret_steal<py::object>(
PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr()));
p.resize(3);
p[2] = shape_tensor.ptr();
}
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = inp_hdl.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
}
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _make_shape_tuple(py::handle(args[0])).release().ptr();
......@@ -900,4 +1093,18 @@ PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _broadcast_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _reshape_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
} // namespace mgb::imperative::python
......@@ -16,4 +16,8 @@ PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);
} // namespace mgb::imperative::python
\ No newline at end of file
......@@ -267,7 +267,7 @@ def test_broadcast_auto_infer(is_varnode):
F.broadcast_to(xx, (None, 1, 2, 3))
F.broadcast_to(xx, (1, None, 2, 3))
t = tensor(2, dtype=np.int32)
t = make_tensor(2, network)
F.broadcast_to(xx, (t, None, 2, 3))
......
......@@ -51,57 +51,75 @@ bool valid_broadcast(const TensorShape& src_shape, const TensorShape& tar_shape)
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op = def.cast_final_safe<Broadcast>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto&& tshp = inputs[1];
TensorShape out_shape;
if (tshp.layout.ndim == 0 || tshp.value.empty()) {
out_shape.ndim = 0;
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
}
mgb_assert(
tshp.layout.ndim == 1,
"target shape of Broadcast expects ndim=1; got ndim=%lu actually",
tshp.layout.ndim);
size_t target_ndim = tshp.layout.shape[0];
out_shape.ndim = target_ndim;
auto* ptr = tshp.value.ptr<dt_int32>();
for (size_t i = 0; i < target_ndim; ++i) {
out_shape[i] = ptr[i];
if (nr_inp == 1) {
out_shape.ndim = op.shape.size();
for (size_t i = 0; i < out_shape.ndim; ++i) {
out_shape[i] = op.shape[i];
}
} else {
auto&& tshp = inputs[1];
if (tshp.layout.ndim == 0 || tshp.value.empty()) {
out_shape.ndim = 0;
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
false};
}
mgb_assert(
tshp.layout.ndim == 1,
"target shape of Broadcast expects ndim=1; got ndim=%lu actually",
tshp.layout.ndim);
size_t target_ndim = tshp.layout.shape[0];
out_shape.ndim = target_ndim;
auto* ptr = tshp.value.ptr<dt_int32>();
for (size_t i = 0; i < target_ndim; ++i) {
out_shape[i] = ptr[i];
}
}
mgb_assert(
valid_broadcast(src.layout, out_shape),
"the input shape %s can not be broadcasted to target shape %s",
src.layout.to_string().c_str(), out_shape.to_string().c_str());
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
def.cast_final_safe<Broadcast>();
auto&& op = def.cast_final_safe<Broadcast>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
TensorShape tshp;
auto&& src = inputs[0];
auto&& tshp_nd = inputs[1];
auto slayout = src->layout();
TensorShape tshp;
cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
if (nr_inp == 1) {
tshp.ndim = op.shape.size();
for (size_t i = 0; i < tshp.ndim; ++i) {
tshp[i] = op.shape[i];
}
} else {
auto&& tshp_nd = inputs[1];
cg::copy_tensor_value_to_shape(
tshp, tshp_nd->get_value().proxy_to_default_cpu());
}
TensorLayout tlayout = slayout.broadcast(tshp);
// memory forward
return {Tensor::make(src->blob(), src->offset(), tlayout)};
}
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
return layout_checker;
}
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.get_input_layout_constraint(get_input_layout_constraint)
.fallback();
} // namespace broadcast
......@@ -118,35 +136,49 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& op = def.cast_final_safe<Reshape>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto&& tshp = inputs[1];
TensorShape out_shape;
if (tshp.layout.ndim == 0 || tshp.value.empty()) {
out_shape.ndim = 0;
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
}
mgb_assert(
tshp.layout.ndim == 1,
"target shape of Reshape expects ndim=1; got ndim=%lu actually",
tshp.layout.ndim);
if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
}
size_t target_ndim = tshp.layout.shape[0];
out_shape.ndim = target_ndim;
auto* ptr = tshp.value.ptr<dt_int32>();
for (size_t i = 0; i < target_ndim; ++i) {
out_shape[i] = ptr[i];
}
if (src.layout.ndim == 0) {
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false};
if (nr_inp == 1) {
if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
false};
}
out_shape.ndim = op.shape.size();
for (size_t i = 0; i < out_shape.ndim; ++i) {
out_shape[i] = op.shape[i];
}
if (src.layout.ndim == 0) {
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
false};
}
} else {
auto&& tshp = inputs[1];
if (tshp.layout.ndim == 0 || tshp.value.empty()) {
out_shape.ndim = 0;
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
false};
}
mgb_assert(
tshp.layout.ndim == 1,
"target shape of Reshape expects ndim=1; got ndim=%lu actually",
tshp.layout.ndim);
if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) {
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
false};
}
size_t target_ndim = tshp.layout.shape[0];
out_shape.ndim = target_ndim;
auto* ptr = tshp.value.ptr<dt_int32>();
for (size_t i = 0; i < target_ndim; ++i) {
out_shape[i] = ptr[i];
}
if (src.layout.ndim == 0) {
return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}},
false};
}
}
if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(out_shape[op.axis] == -1);
out_shape[op.axis] = 1;
......@@ -167,19 +199,27 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op_def = def.cast_final_safe<Reshape>();
auto&& op = def.cast_final_safe<Reshape>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto&& tshp_nd = inputs[1];
auto slayout = src->layout();
TensorShape tshp;
cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(tshp[op_def.axis] == -1);
tshp[op_def.axis] = 1;
tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
if (nr_inp == 1) {
tshp.ndim = op.shape.size();
for (size_t i = 0; i < tshp.ndim; ++i) {
tshp[i] = op.shape[i];
}
} else {
auto&& tshp_nd = inputs[1];
cg::copy_tensor_value_to_shape(
tshp, tshp_nd->get_value().proxy_to_default_cpu());
}
if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(tshp[op.axis] == -1);
tshp[op.axis] = 1;
tshp[op.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
}
TensorLayout tlayout;
mgb_assert(slayout.try_reshape(tlayout, tshp));
......@@ -188,17 +228,24 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto&& op_def = def.cast_final_safe<Reshape>();
auto&& op = def.cast_final_safe<Reshape>();
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
layout_checker[0] = [&](const TensorLayout& layout) {
TensorShape tshp;
TensorLayout ret;
cg::copy_tensor_value_to_shape(
tshp, inputs[1]->get_value().proxy_to_default_cpu());
if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(tshp[op_def.axis] == -1);
tshp[op_def.axis] = 1;
tshp[op_def.axis] = layout.total_nr_elems() / tshp.total_nr_elems();
if (inputs.size() == 1) {
tshp.ndim = op.shape.size();
for (size_t i = 0; i < tshp.ndim; ++i) {
tshp[i] = op.shape[i];
}
} else {
cg::copy_tensor_value_to_shape(
tshp, inputs[1]->get_value().proxy_to_default_cpu());
}
if (op.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(tshp[op.axis] == -1);
tshp[op.axis] = 1;
tshp[op.axis] = layout.total_nr_elems() / tshp.total_nr_elems();
}
if (layout.try_reshape(ret, tshp)) {
return true;
......
......@@ -243,8 +243,10 @@ ValueRefList get_var_shape_rule(
ValueRefList reshape_rule(
const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
mgb_assert(inputs.size() == 2);
bool is_scalar = is_scalar_shape(inputs[1]);
mgb_assert(inputs.size() == 1 || inputs.size() == 2);
size_t nr_inp = inputs.size();
bool is_scalar = (nr_inp == 2 && is_scalar_shape(inputs[1])) ||
(nr_inp == 1 && reshape.shape.size() == 0);
if (is_scalar) {
return {scalar_type.make(imperative::apply(
reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
......@@ -256,8 +258,10 @@ ValueRefList reshape_rule(
ValueRefList broadcast_rule(
const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
mgb_assert(inputs.size() == 2);
bool is_scalar = is_scalar_shape(inputs[1]);
mgb_assert(inputs.size() == 1 || inputs.size() == 2);
size_t nr_inp = inputs.size();
bool is_scalar = (nr_inp == 2 && is_scalar_shape(inputs[1])) ||
(nr_inp == 1 && broadcast.shape.size() == 0);
if (is_scalar) {
return {scalar_type.make(imperative::apply(
broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
......
......@@ -250,7 +250,11 @@ def Concat: MgbHashableOp<"Concat", [AxisParam]> {
);
}
def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]>;
def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]> {
let extraArguments = (ins
MgbArrayAttr<MgbI32Attr>:$shape
);
}
def Identity: MgbHashableOp<"Identity">;
......@@ -318,7 +322,11 @@ def Dimshuffle: MgbHashableOp<"Dimshuffle"> {
let results = (outs AnyMemRef);
}
def Reshape: MgbHashableOp<"Reshape", [OptionalAxisV1Param]>;
def Reshape: MgbHashableOp<"Reshape", [OptionalAxisV1Param]> {
let extraArguments = (ins
MgbArrayAttr<MgbI32Attr>:$shape
);
}
// TODO: merge Add/Remove Axis into AxisAddRemove as megbrain?
def AddAxis: MgbHashableOp<"AddAxis"> {
......