提交 ce30d045 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mge/functional): add arange in tensor.py

GitOrigin-RevId: ad88a4c18ecdfb8e4bb3a8da7497ac9f5a7e7343
上级 23478a0d
......@@ -75,6 +75,7 @@ from .nn import (
from .sort import argsort, sort, top_k
from .tensor import (
add_axis,
arange,
broadcast_to,
concat,
dimshuffle,
......
......@@ -17,6 +17,7 @@ from megengine._internal import CompGraph, CompNode
from ..core import zeros
from ..core.graph import _use_default_if_none
from ..core.tensor import Tensor, wrap_io_tensor
from .elemwise import ceil
from .utils import _decide_comp_node_and_comp_graph
......@@ -553,6 +554,46 @@ def linspace(
return ret.astype(dtype)
def arange(
start: Union[int, float, Tensor],
end: Union[int, float, Tensor],
step: Union[int, float, Tensor] = 1,
dtype=np.float32,
device: Optional[CompNode] = None,
comp_graph: Optional[CompGraph] = None,
) -> Tensor:
r"""
Returns a Tensor with values from `start` to `end` with adjacent interval `step`
:param start: starting value of the squence, shoule be scalar
:param end: ending value of the squence, shoule be scalar
:param step: the gap between each pair of adjacent values. Default 1
:param dtype: result data type
:return: The generated tensor
Examples:
.. testcode::
import numpy as np
import megengine.functional as F
a = F.arange(1, 5, 1)
print(a.numpy())
.. testoutput::
[1. 2. 3. 4.]
"""
if dtype is not np.float32:
raise ValueError("arange is only implemented for float32")
num = ceil((end - start) / step)
stop = start + step * (num - 1)
ret = linspace(start, stop, num, device=device, comp_graph=comp_graph)
return ret
def zeros_like(inp: Tensor) -> Tensor:
r"""
Returns a zero tensor with the same shape as input tensor
......
......@@ -157,6 +157,38 @@ def test_broadcast_to():
opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
def test_arange():
cases = [
{"input": [1, 9, 1]},
{"input": [2, 10, 2]},
]
opr_test(
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
)
cases = [
{"input": [9, 1, -1]},
{"input": [10, 2, -2]},
]
opr_test(
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
)
cases = [
{"input": [9.3, 1.2, -0.5]},
{"input": [10.3, 2.1, -1.7]},
]
opr_test(
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
)
def test_add_update():
shape = (2, 3)
v = np.random.random(shape).astype(np.float32)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册