提交 f551d44e 编写于 作者: M Megvii Engine Team

feat(mge): rename arange paramter end -> stop

GitOrigin-RevId: ea9064e351985e3b8a43d620be5986b1bd5830d1
上级 9c5ea689
......@@ -929,15 +929,15 @@ def linspace(
def arange(
start: Union[int, float, Tensor] = 0,
end: Optional[Union[int, float, Tensor]] = None,
stop: Optional[Union[int, float, Tensor]] = None,
step: Union[int, float, Tensor] = 1,
dtype="float32",
device: Optional[CompNode] = None,
) -> Tensor:
r"""Returns a tensor with values from start to end with adjacent interval step.
r"""Returns a tensor with values from start to stop with adjacent interval step.
:param start: starting value of the squence, shoule be scalar.
:param end: ending value of the squence, shoule be scalar.
:param stop: ending value of the squence, shoule be scalar.
:param step: gap between each pair of adjacent values. Default: 1
:param dtype: result data type.
:return: generated tensor.
......@@ -961,16 +961,16 @@ def arange(
[0. 1. 2. 3. 4.]
"""
if end is None:
start, end = 0, start
if stop is None:
start, stop = 0, start
if isinstance(start, Tensor):
start = start.astype("float32")
if isinstance(end, Tensor):
end = end.astype("float32")
if isinstance(stop, Tensor):
stop = stop.astype("float32")
if isinstance(step, Tensor):
step = step.astype("float32")
num = ceil(Tensor((end - start) / step, device=device))
num = ceil(Tensor((stop - start) / step, device=device))
stop = start + step * (num - 1)
result = linspace(start, stop, num, device=device)
if np.dtype(dtype) == np.int32:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册