提交 fa4bf168 编写于 作者: M Megvii Engine Team

feat(mge/functional): add repeat and tile opr

GitOrigin-RevId: a20d4b6fb0684699175916385e78e3a49776efee
上级 c33a7173
......@@ -279,8 +279,8 @@ class GradManager:
tensor.grad = grad
else:
tensor.grad += grad
if tensor.isscalar() and tensor.grad is not None:
tensor.grad.setscalar()
if tensor._isscalar() and tensor.grad is not None:
tensor.grad._setscalar()
finally:
self.release()
backwarding_grad_manager = cache
......
......@@ -225,7 +225,7 @@ def getitem(tensor, index):
op = builtin.IndexingMultiAxisVec(items=items)
(result,) = apply(op, tensor, *tensors)
if ret_scalar:
result.setscalar()
result._setscalar()
return result
......
......@@ -51,10 +51,10 @@ def concatenate(inputs, axis=0, *, device=None):
def astype(x, dtype):
dtype = np.dtype(dtype)
if not is_dtype_equal(x.dtype, dtype):
isscalar = x.isscalar()
isscalar = x._isscalar()
(x,) = apply(builtin.TypeCvt(dtype=dtype), x)
if isscalar:
x.setscalar()
x._setscalar()
return x
......@@ -98,14 +98,14 @@ def result_type(*args):
def isscalar(x):
if isinstance(x, Tensor):
return x.isscalar()
return x._isscalar()
return np.isscalar(x)
def setscalar(x):
if isinstance(x, Tensor):
x.setscalar()
x._setscalar()
else:
raise NotImplementedError("Unsupport type {}".format(type(x)))
......
......@@ -67,7 +67,7 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
outputs = apply(op, inp)
for s, x in zip(shapes, outputs):
if not s:
x.setscalar()
x._setscalar()
return outputs
......
......@@ -12,8 +12,8 @@ from ..core.ops.builtin import InplaceAdd
def _inplace_add_(dest, delta, alpha, beta):
isscalar = dest.isscalar()
isscalar = dest._isscalar()
dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
if isscalar:
dest.setscalar()
dest._setscalar()
return dest
......@@ -44,11 +44,13 @@ __all__ = [
"linspace",
"ones",
"ones_like",
"repeat",
"reshape",
"split",
"squeeze",
"stack",
"scatter",
"tile",
"transpose",
"where",
"zeros",
......@@ -987,3 +989,144 @@ def arange(
if np.dtype(dtype) == np.int32:
return result.astype(dtype)
return result
def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
"""
Repeat elements of an array.
:param inp: input tensor.
:param repeats: the number of repetitions for each element.
:param axis: the axis along which to repeat values. By default, use the
flattened input array, and return a flat output array.
:return: output tensor.
Examples:
.. testcode::
import numpy as np
import megengine.functional as F
from megengine import tensor
x = tensor([[1, 2], [3, 4]], np.int32)
y = F.repeat(x, 2, axis=0)
print(y.numpy())
Outputs:
.. testoutput::
[[1 2]
[1 2]
[3 4]
[3 4]]
"""
if axis is None:
inp = inp.reshape(-1) # flatten
axis = 0
if inp._isscalar():
inp._unsetscalar()
shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
# assume inp.ndim is not changed during trace
max_axis = len(shape) - 1
assert axis >= 0 and axis <= max_axis
assert repeats >= 1
base_shape, bcast_shape, target_shape = [], [], []
if axis != 0:
target_shape.append(shape[:axis])
base_shape.extend([shape[: axis + 1], [1,]])
bcast_shape.extend([shape[: axis + 1], [repeats,]])
target_shape.extend(
[shape[axis] * repeats,]
)
if axis + 1 <= max_axis:
base_shape.append(shape[axis + 1 :])
bcast_shape.append(shape[axis + 1 :])
target_shape.append(shape[axis + 1 :])
out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
concat(target_shape)
)
return out
def _tile_one_dim(inp, rep, axis):
shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
# assume inp.ndim is not changed during trace
max_axis = len(shape) - 1
base_shape, bcast_shape, target_shape = [], [], []
if axis != 0:
base_shape.append(shape[:axis])
bcast_shape.append(shape[:axis])
target_shape.append(shape[:axis])
base_shape.extend([[1,], shape[axis:]])
bcast_shape.extend([rep, shape[axis:]])
target_shape.append(shape[axis] * rep)
if axis + 1 <= max_axis:
target_shape.append(shape[axis + 1 :])
out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
concat(target_shape)
)
return out
def tile(inp: Tensor, reps: Iterable[int]):
"""
Construct an array by repeating ``inp`` the number of times given by ``reps``. If reps has length d,
the result will have dimension of ``max(d, inp.ndim)``. It is required that ``d >= inp.dim``. If ``inp.ndim < d``,
``inp`` is promoted to be ``d``-dimensional by prepending new axis.
:param inp: input tensor.
:param reps: The number of repetitions of inp along each axis.
:return: output tensor.
Examples:
.. testcode::
import numpy as np
import megengine.functional as F
from megengine import tensor
x = tensor([[1, 2], [3, 4]], np.int32)
y = F.tile(x, (2,1))
print(y.numpy())
Outputs:
.. testoutput::
[[1 2]
[3 4]
[1 2]
[3 4]]
"""
shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
reps = astensor1d(reps, inp, dtype="int32", device=inp.device)
l_shape = len(shape)
l_reps = len(reps)
assert (
l_reps >= l_shape
), "Number of dimensions of tiled dims can not be smaller than number of dimensions of tensor"
for i in range(l_shape):
rep = reps[i + (l_reps - l_shape)]
inp = _tile_one_dim(inp, rep, i)
if l_reps > l_shape:
shape = inp.shape
extra = reps[:-l_shape]
extra_ones = ones_like(extra)
base_shape = concat([extra_ones, shape])
bcast_shape = concat([extra, shape])
target_shape = concat([extra, shape])
inp = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape)
return inp
......@@ -51,10 +51,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
cn = device._cn
if isinstance(data, _Tensor):
if dtype is not None:
logger.warning(
"dtype does not work when creating a new Tensor with another Tensor"
)
obj = _Tensor.__new__(cls, data)
else:
if isinstance(data, np.ndarray):
......
......@@ -557,6 +557,11 @@ void TensorWrapper::setscalar() {
}
void TensorWrapper::unsetscalar() {
m_tensor->m_flags &= ~Tensor::Flags::SCALAR;
}
struct TensorWeakRef {
std::weak_ptr<Tensor> wptr;
......@@ -794,8 +799,9 @@ void init_tensor(py::module m) {
.def_getset<&TensorWrapper::dtype>("dtype")
.def_getset<&TensorWrapper::device>("device")
.def<&TensorWrapper::reset>("_reset")
.def<&TensorWrapper::isscalar>("isscalar")
.def<&TensorWrapper::setscalar>("setscalar")
.def<&TensorWrapper::isscalar>("_isscalar")
.def<&TensorWrapper::setscalar>("_setscalar")
.def<&TensorWrapper::unsetscalar>("_unsetscalar")
.def<&TensorWrapper::detach>("detach")
.def<&TensorWrapper::_dev_tensor>("_dev_tensor")
.def<&TensorWrapper::_swap_out>("_swap_out")
......
......@@ -153,6 +153,7 @@ struct TensorWrapper {
PyObject* detach();
PyObject* isscalar();
void setscalar();
void unsetscalar();
PyObject* _dev_tensor();
void _swap_in();
void _swap_out();
......
......@@ -406,3 +406,53 @@ def test_copy_d2h():
def test_copy_d2d():
copy_test("gpu0", "gpu1")
copy_test("gpu0:0", "gpu0:1")
@pytest.mark.parametrize(
"shape, repeats, axis",
[
((2,), 2, 0),
((2, 3, 4, 5), 3, 0),
((2, 3, 4, 5), 4, 3),
((2,), 2, None),
((2, 3, 4, 5), 3, None),
((), 1, None),
((), 10, None),
],
)
def test_repeat(shape, repeats, axis):
def repeat_func(inp):
return F.repeat(inp=inp, repeats=repeats, axis=axis)
if shape != ():
cases = [
{"input": np.random.randn(*shape).astype("float32")},
]
else:
cases = [{"input": np.array(1.23)}]
opr_test(
cases, repeat_func, ref_fn=lambda inp: np.repeat(inp, repeats, axis),
)
@pytest.mark.parametrize(
"shape, reps",
[
((2,), (2,)),
((2, 3, 4, 5), (1, 1, 1, 1)),
((2, 3, 4, 5), (1, 2, 3, 4)),
((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
],
)
def test_tile(shape, reps):
def tile_func(inp):
return F.tile(inp=inp, reps=reps)
cases = [
{"input": np.random.randn(*shape).astype("float32")},
]
opr_test(
cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps),
)
......@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io
import itertools
from tempfile import mkstemp
import numpy as np
......@@ -359,7 +360,7 @@ def test_trace_warp_perspective():
np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
return out
for i in range(1):
for i in range(3):
f(x, M)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册