Commit fa4bf168 authored by Megvii Engine Team

feat(mge/functional): add repeat and tile opr

GitOrigin-RevId: a20d4b6fb0684699175916385e78e3a49776efee
Parent c33a7173
@@ -279,8 +279,8 @@ class GradManager:
                     tensor.grad = grad
                 else:
                     tensor.grad += grad
-                if tensor.isscalar() and tensor.grad is not None:
-                    tensor.grad.setscalar()
+                if tensor._isscalar() and tensor.grad is not None:
+                    tensor.grad._setscalar()
         finally:
             self.release()
             backwarding_grad_manager = cache
...
@@ -225,7 +225,7 @@ def getitem(tensor, index):
     op = builtin.IndexingMultiAxisVec(items=items)
     (result,) = apply(op, tensor, *tensors)
     if ret_scalar:
-        result.setscalar()
+        result._setscalar()
     return result
...
@@ -51,10 +51,10 @@ def concatenate(inputs, axis=0, *, device=None):
 def astype(x, dtype):
     dtype = np.dtype(dtype)
     if not is_dtype_equal(x.dtype, dtype):
-        isscalar = x.isscalar()
+        isscalar = x._isscalar()
         (x,) = apply(builtin.TypeCvt(dtype=dtype), x)
         if isscalar:
-            x.setscalar()
+            x._setscalar()
     return x
@@ -98,14 +98,14 @@ def result_type(*args):
 def isscalar(x):
     if isinstance(x, Tensor):
-        return x.isscalar()
+        return x._isscalar()
     return np.isscalar(x)

 def setscalar(x):
     if isinstance(x, Tensor):
-        x.setscalar()
+        x._setscalar()
     else:
         raise NotImplementedError("Unsupported type {}".format(type(x)))
...
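These module-level helpers wrap the tensor methods that this commit makes private (`_isscalar`/`_setscalar`). A minimal self-contained sketch of the pattern they encode, using a stand-in class rather than MegEngine's actual Tensor: scalars are stored as 1-element tensors plus a flag, so `isscalar` consults the flag for tensors and falls back to NumPy's rule for plain Python values.

    import numpy as np

    # Stand-in for megengine.Tensor; illustrates only the scalar flag.
    class FakeScalarTensor:
        def __init__(self, data):
            self.data = np.atleast_1d(np.asarray(data))
            self._scalar = np.ndim(data) == 0  # remember 0-dim inputs

        def _isscalar(self):
            return self._scalar

        def _setscalar(self):
            self._scalar = True

    def isscalar(x):
        # mirrors the helper above: Tensor -> flag, otherwise NumPy's rule
        if isinstance(x, FakeScalarTensor):
            return x._isscalar()
        return np.isscalar(x)

    assert isscalar(FakeScalarTensor(1.0)) and not isscalar(FakeScalarTensor([1.0]))
    assert isscalar(3) and not isscalar([3])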
@@ -67,7 +67,7 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
     outputs = apply(op, inp)
     for s, x in zip(shapes, outputs):
         if not s:
-            x.setscalar()
+            x._setscalar()
     return outputs
...
@@ -12,8 +12,8 @@ from ..core.ops.builtin import InplaceAdd
 def _inplace_add_(dest, delta, alpha, beta):
-    isscalar = dest.isscalar()
+    isscalar = dest._isscalar()
     dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
     if isscalar:
-        dest.setscalar()
+        dest._setscalar()
     return dest
@@ -44,11 +44,13 @@ __all__ = [
     "linspace",
     "ones",
     "ones_like",
+    "repeat",
     "reshape",
     "split",
     "squeeze",
     "stack",
     "scatter",
+    "tile",
     "transpose",
     "where",
     "zeros",
@@ -987,3 +989,144 @@ def arange(
     if np.dtype(dtype) == np.int32:
         return result.astype(dtype)
     return result
+def repeat(inp: Tensor, repeats: int, axis: Optional[int] = None):
+    """
+    Repeat elements of an array.
+
+    :param inp: input tensor.
+    :param repeats: the number of repetitions for each element.
+    :param axis: the axis along which to repeat values. By default, use the
+        flattened input array, and return a flat output array.
+    :return: output tensor.
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        import megengine.functional as F
+        from megengine import tensor
+
+        x = tensor([[1, 2], [3, 4]], np.int32)
+        y = F.repeat(x, 2, axis=0)
+        print(y.numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [[1 2]
+         [1 2]
+         [3 4]
+         [3 4]]
+
+    """
+    if axis is None:
+        inp = inp.reshape(-1)  # flatten
+        axis = 0
+        if inp._isscalar():
+            inp._unsetscalar()
+    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
+    # assume inp.ndim is not changed during trace
+    max_axis = len(shape) - 1
+    assert axis >= 0 and axis <= max_axis
+    assert repeats >= 1
+
+    base_shape, bcast_shape, target_shape = [], [], []
+    if axis != 0:
+        target_shape.append(shape[:axis])
+    base_shape.extend([shape[: axis + 1], [1,]])
+    bcast_shape.extend([shape[: axis + 1], [repeats,]])
+    target_shape.extend(
+        [shape[axis] * repeats,]
+    )
+    if axis + 1 <= max_axis:
+        base_shape.append(shape[axis + 1 :])
+        bcast_shape.append(shape[axis + 1 :])
+        target_shape.append(shape[axis + 1 :])
+
+    out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
+        concat(target_shape)
+    )
+    return out
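`repeat()` expresses element repetition as reshape -> broadcast -> reshape so it lowers to existing ops: insert a length-1 axis right after `axis`, broadcast it to `repeats`, then fold it back in. A NumPy sketch of the same shape arithmetic (`repeat_via_broadcast` is a hypothetical name used only for illustration):

    import numpy as np

    def repeat_via_broadcast(x, repeats, axis):
        # length-1 axis right after `axis`, blown up to `repeats`,
        # then folded back into `axis`
        base = x.shape[: axis + 1] + (1,) + x.shape[axis + 1 :]
        bcast = x.shape[: axis + 1] + (repeats,) + x.shape[axis + 1 :]
        target = x.shape[:axis] + (x.shape[axis] * repeats,) + x.shape[axis + 1 :]
        return np.broadcast_to(x.reshape(base), bcast).reshape(target)

    x = np.array([[1, 2], [3, 4]])
    np.testing.assert_array_equal(
        repeat_via_broadcast(x, 2, axis=0), np.repeat(x, 2, axis=0)
    )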
+def _tile_one_dim(inp, rep, axis):
+    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
+    # assume inp.ndim is not changed during trace
+    max_axis = len(shape) - 1
+
+    base_shape, bcast_shape, target_shape = [], [], []
+    if axis != 0:
+        base_shape.append(shape[:axis])
+        bcast_shape.append(shape[:axis])
+        target_shape.append(shape[:axis])
+    base_shape.extend([[1,], shape[axis:]])
+    bcast_shape.extend([rep, shape[axis:]])
+    target_shape.append(shape[axis] * rep)
+    if axis + 1 <= max_axis:
+        target_shape.append(shape[axis + 1 :])
+
+    out = broadcast_to(inp.reshape(concat(base_shape)), concat(bcast_shape)).reshape(
+        concat(target_shape)
+    )
+    return out
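`_tile_one_dim()` is the mirror image of `repeat()`: the length-1 axis goes *before* `axis`, so whole blocks along that axis are copied rather than individual elements. The same trick in plain NumPy (`tile_one_dim` is a hypothetical name for illustration):

    import numpy as np

    def tile_one_dim(x, rep, axis):
        # length-1 axis before `axis`: copies blocks, not elements
        base = x.shape[:axis] + (1,) + x.shape[axis:]
        bcast = x.shape[:axis] + (rep,) + x.shape[axis:]
        target = x.shape[:axis] + (x.shape[axis] * rep,) + x.shape[axis + 1 :]
        return np.broadcast_to(x.reshape(base), bcast).reshape(target)

    x = np.array([[1, 2], [3, 4]])
    np.testing.assert_array_equal(tile_one_dim(x, 2, axis=0), np.tile(x, (2, 1)))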
+def tile(inp: Tensor, reps: Iterable[int]):
+    """
+    Construct an array by repeating ``inp`` the number of times given by ``reps``.
+    If ``reps`` has length ``d``, the result will have ``max(d, inp.ndim)`` dimensions.
+    It is required that ``d >= inp.ndim``; if ``inp.ndim < d``, ``inp`` is promoted
+    to be ``d``-dimensional by prepending new axes.
+
+    :param inp: input tensor.
+    :param reps: the number of repetitions of ``inp`` along each axis.
+    :return: output tensor.
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        import megengine.functional as F
+        from megengine import tensor
+
+        x = tensor([[1, 2], [3, 4]], np.int32)
+        y = F.tile(x, (2, 1))
+        print(y.numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [[1 2]
+         [3 4]
+         [1 2]
+         [3 4]]
+
+    """
+    shape = astensor1d(inp.shape, inp, dtype="int32", device=inp.device)
+    reps = astensor1d(reps, inp, dtype="int32", device=inp.device)
+    l_shape = len(shape)
+    l_reps = len(reps)
+    assert (
+        l_reps >= l_shape
+    ), "Number of dimensions of tiled dims cannot be smaller than number of dimensions of tensor"
+
+    for i in range(l_shape):
+        rep = reps[i + (l_reps - l_shape)]
+        inp = _tile_one_dim(inp, rep, i)
+
+    if l_reps > l_shape:
+        shape = inp.shape
+        extra = reps[:-l_shape]
+        extra_ones = ones_like(extra)
+        base_shape = concat([extra_ones, shape])
+        bcast_shape = concat([extra, shape])
+        target_shape = concat([extra, shape])
+        inp = broadcast_to(inp.reshape(base_shape), bcast_shape).reshape(target_shape)
+
+    return inp
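When `len(reps)` exceeds `inp.ndim`, the loop first tiles each existing axis and the final broadcast handles the extra leading reps. The result matches `np.tile`, as the tests below confirm; for example, a (3, 4) input tiled with `reps=(2, 1, 2)` yields shape (2, 3, 8):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x = tensor(np.arange(12).reshape(3, 4).astype("float32"))
    y = F.tile(x, (2, 1, 2))            # reps longer than x.ndim
    assert y.numpy().shape == (2, 3, 8)
    np.testing.assert_array_equal(y.numpy(), np.tile(x.numpy(), (2, 1, 2)))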
@@ -51,10 +51,6 @@ class Tensor(_Tensor, ArrayMethodMixin):
         cn = device._cn
         if isinstance(data, _Tensor):
-            if dtype is not None:
-                logger.warning(
-                    "dtype does not work when creating a new Tensor with another Tensor"
-                )
             obj = _Tensor.__new__(cls, data)
         else:
             if isinstance(data, np.ndarray):
...
@@ -557,6 +557,11 @@ void TensorWrapper::setscalar() {
 }
+
+void TensorWrapper::unsetscalar() {
+    m_tensor->m_flags &= ~Tensor::Flags::SCALAR;
+}
+
 struct TensorWeakRef {
     std::weak_ptr<Tensor> wptr;
@@ -794,8 +799,9 @@ void init_tensor(py::module m) {
         .def_getset<&TensorWrapper::dtype>("dtype")
         .def_getset<&TensorWrapper::device>("device")
         .def<&TensorWrapper::reset>("_reset")
-        .def<&TensorWrapper::isscalar>("isscalar")
-        .def<&TensorWrapper::setscalar>("setscalar")
+        .def<&TensorWrapper::isscalar>("_isscalar")
+        .def<&TensorWrapper::setscalar>("_setscalar")
+        .def<&TensorWrapper::unsetscalar>("_unsetscalar")
         .def<&TensorWrapper::detach>("detach")
         .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
         .def<&TensorWrapper::_swap_out>("_swap_out")
...
@@ -153,6 +153,7 @@ struct TensorWrapper {
     PyObject* detach();
     PyObject* isscalar();
     void setscalar();
+    void unsetscalar();
     PyObject* _dev_tensor();
     void _swap_in();
     void _swap_out();
...
@@ -406,3 +406,53 @@ def test_copy_d2h():
 def test_copy_d2d():
     copy_test("gpu0", "gpu1")
     copy_test("gpu0:0", "gpu0:1")
+
+
+@pytest.mark.parametrize(
+    "shape, repeats, axis",
+    [
+        ((2,), 2, 0),
+        ((2, 3, 4, 5), 3, 0),
+        ((2, 3, 4, 5), 4, 3),
+        ((2,), 2, None),
+        ((2, 3, 4, 5), 3, None),
+        ((), 1, None),
+        ((), 10, None),
+    ],
+)
+def test_repeat(shape, repeats, axis):
+    def repeat_func(inp):
+        return F.repeat(inp=inp, repeats=repeats, axis=axis)
+
+    if shape != ():
+        cases = [
+            {"input": np.random.randn(*shape).astype("float32")},
+        ]
+    else:
+        cases = [{"input": np.array(1.23)}]
+
+    opr_test(
+        cases, repeat_func, ref_fn=lambda inp: np.repeat(inp, repeats, axis),
+    )
+
+
+@pytest.mark.parametrize(
+    "shape, reps",
+    [
+        ((2,), (2,)),
+        ((2, 3, 4, 5), (1, 1, 1, 1)),
+        ((2, 3, 4, 5), (1, 2, 3, 4)),
+        ((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
+    ],
+)
+def test_tile(shape, reps):
+    def tile_func(inp):
+        return F.tile(inp=inp, reps=reps)
+
+    cases = [
+        {"input": np.random.randn(*shape).astype("float32")},
+    ]
+
+    opr_test(
+        cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps),
+    )
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import io
+import itertools
 from tempfile import mkstemp
 import numpy as np
@@ -359,7 +360,7 @@ def test_trace_warp_perspective():
         np.testing.assert_equal(out.shape.numpy(), np.array([1, 1, 2, 2]))
         return out

-    for i in range(1):
+    for i in range(3):
         f(x, M)
...